ArtStudies/M2/Reinforcement Learning/Lab 4 - Monte Carlo Control in Blackjack game.ipynb
2026-01-13 10:36:09 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "2846680b-b04d-41ab-8891-85b6db569a29",
"metadata": {},
"source": [
"# Lab 4 - Monte Carlo Control in Blackjack game\n",
"\n",
"## 1. Objectives\n",
"\n",
"In this lab, we will:\n",
"\n",
"1. Understand how Monte Carlo methods can be used to estimate value functions from experience, without knowing the transition probabilities of the environment.\n",
"\n",
"2. Model the Blackjack game as an episodic Markov Decision Process (MDP) with terminal rewards.\n",
"\n",
"3. Implement first-visit Monte Carlo policy evaluation to estimate state-value and action-value functions.\n",
"\n",
"4. Apply Monte Carlo control to learn an optimal policy for Blackjack.\n",
"\n",
"5. Address the exploration problem by using exploring starts and $\\varepsilon$-soft policies.\n",
"\n",
"6. Study off-policy Monte Carlo learning and understand how a target policy can be evaluated using data generated by a different behavior policy. Implement importance sampling (ordinary and weighted) to correct the distribution mismatch in off-policy learning.\n",
"\n",
"-------------------"
]
},
{
"cell_type": "markdown",
"id": "3c9bbf3c-3974-4996-b1c9-805dd9935905",
"metadata": {},
"source": [
"## 2. Blackjack game\n",
"\n",
"Blackjack is a card game played between a **player** and a **dealer**. The goal is to obtain a hand value as close as possible to 21, without exceeding it.\n",
"\n",
"- Card values\n",
"    - Number cards (2 to 10): face value\n",
"    - Face cards (J, Q, K): value 10\n",
"    - Ace (A): value 1 or 11\n",
"        - The Ace (A) is a special card in Blackjack because it can take two possible values: 1 or 11.\n",
"        - When an Ace is counted as 11 without making the hand exceed 21, we say the hand has a usable ace.\n",
"        - If counting the Ace as 11 would cause the total to exceed 21, the Ace is automatically counted as 1.\n",
"        - This choice is not an action taken by the player; it is an automatic rule of the game that always favors avoiding a bust.\n",
"\n",
"- Game flow\n",
"    - The game starts by giving the player an initial hand and showing one visible card for the dealer.\n",
"    - The player plays first and repeatedly chooses between:\n",
"        - **Hit**: draw one additional card\n",
"        - **Stick**: stop drawing cards\n",
"    - If the player's total exceeds 21 at any time, the episode ends immediately with a loss.\n",
"    - After the player sticks, the dealer plays:\n",
"        - The dealer must hit until the hand value is at least 17\n",
"        - Then the dealer sticks\n",
"\n",
"- Outcome and rewards\n",
"    - The player wins if their final hand value is higher than the dealer's without busting, or if the dealer busts.\n",
"    - The player loses if they bust or if the dealer has a higher hand.\n",
"    - A draw (tie) gives zero reward.\n",
"\n",
"**Next, we will model the Blackjack game as an episodic Markov Decision Process (MDP) with terminal rewards.**"
]
},
{
"cell_type": "markdown",
"id": "52633063-0087-4290-9b04-ed23dd8ab04d",
"metadata": {},
"source": [
"### 2.0 Imports and global settings"
]
},
{
"cell_type": "code",
"execution_count": 218,
"id": "ae5c04c2-e544-4307-a82e-63d920383bb5",
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Reproducibility\n",
"rng = np.random.default_rng(0)"
]
},
{
"cell_type": "markdown",
"id": "d3bca55f-6db6-4400-9b8f-c21cfb0c14e5",
"metadata": {},
"source": [
"### 2.1 Blackjack environment (State, Action, Episode, Rewards)\n",
"\n",
"1. The state is a summary of the game:\n",
"$$\n",
"s=(p,d,u)\n",
"$$\n",
"where\n",
"    - $p =$ player's current sum\n",
"    - $d =$ dealer's showing card (1 to 10)\n",
"    - $u =$ usable ace indicator (0/1)\n",
"\n",
"2. Player's actions:\n",
"    - STICK = 0\n",
"    - HIT = 1\n",
"\n",
"3. The episode ends when:\n",
"    - the player busts (sum > 21), or\n",
"    - the player sticks and the dealer finishes their fixed policy (hit until 17).\n",
"\n",
"4. Rewards:\n",
"    - player bust: $-1$\n",
"    - otherwise, at the terminal step after the dealer plays: win $+1$, lose $-1$, draw $0$\n",
"    - Intermediate rewards are 0; only the terminal reward matters.\n",
"\n",
"5. Uniform card draws: in this lab, we assume cards are drawn uniformly to simplify the environment and preserve the Markov property.\n",
"    - With uniform draws, future cards are independent of the past, so the state $(p,d,u)$ is sufficient.\n",
"    - With a real deck, card probabilities depend on the history of drawn cards, so the state $(p,d,u)$ is no longer Markov."
]
},
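To make the state definition above concrete, here is a short sketch (not part of the lab's code) that enumerates the decision states. It assumes, as Exercise 1 does later, that only player sums $p = 12, \dots, 21$ require a decision.

```python
from itertools import product

# Ranges taken from the state definition s = (p, d, u).
# Assumption: sums below 12 are forced to HIT, so they are not decision states.
PLAYER_SUMS = range(12, 22)   # p = 12, ..., 21
DEALER_CARDS = range(1, 11)   # d = 1 (ace), ..., 10
USABLE_ACE = (0, 1)           # u = 0 or 1

# Every decision state is one (p, d, u) tuple.
states = list(product(PLAYER_SUMS, DEALER_CARDS, USABLE_ACE))
print(len(states))  # 10 * 10 * 2 = 200
```

This is small enough for tabular methods: a value table only needs one entry per tuple, and an action-value table needs two (HIT and STICK).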
{
"cell_type": "code",
"execution_count": 219,
"id": "5b5110be-3795-495b-aacc-67b6d4798109",
"metadata": {},
"outputs": [],
"source": [
"# In Blackjack the player has two possible decisions:\n",
"STICK = 0 # STICK: stop drawing cards (end the player's turn)\n",
"HIT = 1 # HIT: draw one more card\n",
"ACTIONS = [STICK, HIT] # the list of all valid actions"
]
},
{
"cell_type": "markdown",
"id": "bd019f08-fb38-4086-aac6-5cbead812007",
"metadata": {},
"source": [
"We now define a function that draws a card with a value in $\\{1,2,\\cdots,10\\}$.\n",
" \n",
"1. The value 1 represents an Ace.\n",
"2. The value 10 represents any 10, Jack, Queen, or King, all of which are treated as having value 10.\n",
"\n",
"**Important.** To simplify the environment and make it suitable for modeling as a Markov Decision Process (MDP), we assume that cards are drawn uniformly from this set. This removes the need to track the remaining deck composition and allows the transition dynamics to depend only on the current state."
]
},
{
"cell_type": "code",
"execution_count": 220,
"id": "006775e4-940d-4cc2-a6d7-bd0481609f4f",
"metadata": {},
"outputs": [],
"source": [
"def draw_card(rng: np.random.Generator) -> int:\n",
"    \"\"\"Draw a card value in {1,...,10}. Face cards have value 10.\"\"\"\n",
"    return int(rng.integers(1, 11))\n"
]
},
{
"cell_type": "markdown",
"id": "6e2740fd-5be3-49c8-827d-d75d35aa441f",
"metadata": {},
"source": [
"Now we define a function `usable_ace(hand)` that detects whether an Ace in the hand can be treated as “11”, i.e., whether it is a usable ace.\n",
"\n",
"Recall that a “usable ace” means: the hand contains an Ace (1 in hand) and we can treat that Ace as 11 instead of 1 without going over 21. \n",
"\n",
"Examples:\n",
"- hand = [1, 7] $\\rightarrow$ sum(hand)=8 $\\rightarrow$ 8+10=18 <= 21 $\\rightarrow$ usable ace = True\n",
"- hand = [1, 7, 5] $\\rightarrow$ sum=13 $\\rightarrow$ 13+10=23 > 21 $\\rightarrow$ usable ace = False (Here ace must be counted as 1)\n"
]
},
{
"cell_type": "code",
"execution_count": 221,
"id": "76573eb5-73a7-4250-908c-682e17ac9a67",
"metadata": {},
"outputs": [],
"source": [
"def usable_ace(hand: list[int]) -> bool:\n",
"    \"\"\"Return True if hand has an ace (value 1) that can be counted as 11 without busting.\"\"\"\n",
"    return (1 in hand) and (sum(hand) + 10 <= 21)  # noqa: PLR2004\n"
]
},
{
"cell_type": "markdown",
"id": "2d4ae74e-3b11-4c8b-8b28-86deab846830",
"metadata": {},
"source": [
"This `sum_hand(hand)` function returns the hand's effective score under Blackjack rules:\n",
"\n",
"1. If there is a usable ace, count one ace as 11 (add 10).\n",
"2. Otherwise, just return the raw sum.\n",
"\n",
"Why the “best” sum? Because a Blackjack player always prefers the highest total that is $\\leq 21$, and counting one ace as 11 is allowed only when it keeps the total $\\leq 21$.\n"
]
},
{
"cell_type": "code",
"execution_count": 222,
"id": "e520924e-dd82-45ff-b039-fde966da7a65",
"metadata": {},
"outputs": [],
"source": [
"def sum_hand(hand: list[int]) -> int:\n",
"    \"\"\"Return best (highest) sum <=21 if possible by counting usable ace as 11. Otherwise return the raw sum.\"\"\"\n",
"    s = sum(hand)\n",
"    if usable_ace(hand):\n",
"        return s + 10\n",
"    return s"
]
},
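Here is a quick sanity check of `usable_ace` and `sum_hand` on the examples above. This standalone sketch copies the two helpers so that it runs on its own. It also adds a two-ace hand to show why 10 is added at most once.

```python
def usable_ace(hand: list[int]) -> bool:
    """Copy of the lab's rule: an ace (1) can count as 11 without exceeding 21."""
    return (1 in hand) and (sum(hand) + 10 <= 21)


def sum_hand(hand: list[int]) -> int:
    """Copy of the lab's rule: add 10 (at most once) when there is a usable ace."""
    return sum(hand) + 10 if usable_ace(hand) else sum(hand)


# Examples from the text above.
print(sum_hand([1, 7]), usable_ace([1, 7]))        # 18 True
print(sum_hand([1, 7, 5]), usable_ace([1, 7, 5]))  # 13 False
# Two aces: counting both as 11 would give 22, so only one of them can be 11.
print(sum_hand([1, 1]), usable_ace([1, 1]))        # 12 True
print(sum_hand([1, 1, 9]))                         # 21 (11 + 1 + 9)
```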
{
"cell_type": "markdown",
"id": "3d8b22e4-c828-4a2b-8228-d34bfbe36e0c",
"metadata": {},
"source": [
"Now we define the `is_bust(hand)` function, which checks whether a hand exceeds 21 (a bust)."
]
},
{
"cell_type": "code",
"execution_count": 223,
"id": "83fa2134-a552-4a83-991a-ae8dd46eb139",
"metadata": {},
"outputs": [],
"source": [
"def is_bust(hand: list[int]) -> bool:\n",
"    \"\"\"Return True if hand sum exceeds 21.\"\"\"\n",
"    return sum_hand(hand) > 21  # noqa: PLR2004\n"
]
},
{
"cell_type": "markdown",
"id": "4d78d770-a3fb-445e-906d-ca5dc0f78861",
"metadata": {},
"source": [
"Now we define `score(hand)` for the terminal “value” of a hand. It returns a “hand score” used at the end of a game:\n",
"\n",
"1. If the hand is bust: score = 0 (a bust hand can never win the comparison)\n",
"2. Otherwise: score = effective sum, i.e. `sum_hand(hand)`"
]
},
{
"cell_type": "code",
"execution_count": 224,
"id": "7ab1c557-d03b-45e3-b821-6d7eb3879f67",
"metadata": {},
"outputs": [],
"source": [
"def score(hand: list[int]) -> int:\n",
"    \"\"\"Terminal score is 0 if bust otherwise sum_hand.\"\"\"\n",
"    return 0 if is_bust(hand) else sum_hand(hand)"
]
},
{
"cell_type": "markdown",
"id": "b3d443e6-3f15-4c6f-b2ef-1c2ceffc9ea1",
"metadata": {},
"source": [
"Now we define `initial_hands(rng)` to start a new episode (Environment reset).\n",
"\n",
"It creates the initial state of a Blackjack episode:\n",
"\n",
"- player starts with 2 cards\n",
"\n",
"- dealer starts with 2 cards\n",
"\n",
"The environment internally keeps both dealer cards, but the agent will only observe one (typical Blackjack rule: one dealer card is face-up)."
]
},
{
"cell_type": "code",
"execution_count": 225,
"id": "a7debbe9-6d8f-44f4-b751-e4ccfd10ace8",
"metadata": {},
"outputs": [],
"source": [
"def initial_hands(rng: np.random.Generator) -> tuple[list[int], list[int]]:\n",
"    \"\"\"Start of a Blackjack episode.\n",
"\n",
"    - Player draws 2 cards\n",
"    - Dealer draws 2 cards, but we only \"observe\" dealer[0]\n",
"    \"\"\"\n",
"    player = [draw_card(rng), draw_card(rng)]\n",
"    dealer = [draw_card(rng), draw_card(rng)]\n",
"    return player, dealer"
]
},
{
"cell_type": "markdown",
"id": "2197aa35-6734-4367-961a-b048802f8594",
"metadata": {},
"source": [
"Now we define the state representation for the player, which is the RL agent.\n",
"\n",
"We convert the full hands (lists of cards) into a state tuple $(p, d, u)$, where\n",
"\n",
"1. `p = sum_hand(player)`: the player's current total under Blackjack rules (Ace possibly counted as 11).\n",
"2. `d = dealer[0]`: the dealer's visible (face-up) card.\n",
"3. `u = int(usable_ace(player))`: a binary feature, 1 if the player has a usable ace and 0 otherwise.\n"
]
},
{
"cell_type": "code",
"execution_count": 226,
"id": "0f8eab84-564e-44a6-a174-7bf8d0c411dc",
"metadata": {},
"outputs": [],
"source": [
"def obs(player: list[int], dealer: list[int]) -> tuple[int, int, int]:\n",
"    \"\"\"Convert internal hands into tabular state (p, d, u).\n",
"\n",
"    Args:\n",
"        player: list of player's cards\n",
"        dealer: list of dealer's cards\n",
"\n",
"    Returns:\n",
"        (p, d, u): p = player's sum, d = dealer's visible card,\n",
"        u = 1 if the player has a usable ace, else 0\n",
"\n",
"    \"\"\"\n",
"    p = sum_hand(player)\n",
"    d = dealer[0]\n",
"    u = int(usable_ace(player))\n",
"    return (p, d, u)"
]
},
{
"cell_type": "markdown",
"id": "a23f83c4-91ee-4374-9510-e42cae3d4751",
"metadata": {},
"source": [
"### 2.2 Blackjack step function\n",
"\n",
"In reinforcement learning, an environment step function typically implements\n",
"$$\n",
"(s_{t+1}, r_{t+1}, \\text{done})=\\text{step}(s_t, a_t)\n",
"$$\n",
"where\n",
"1. the inputs are the current state (here: `(player, dealer)`) and an action (`HIT` or `STICK`);\n",
"2. the outputs are the updated state `(player, dealer)`, the reward for this transition, and a boolean `done` indicating whether the episode has ended.\n",
"\n",
"In Blackjack, \n",
"1. If the player hits, the game may continue (unless the player busts).\n",
"2. If the player sticks, the dealer finishes playing, and the episode ends immediately."
]
},
{
"cell_type": "code",
"execution_count": 227,
"id": "ba254a3f-6ce3-44a3-81ff-8df9e41a11e2",
"metadata": {},
"outputs": [],
"source": [
"def step(\n",
"    player: list[int],\n",
"    dealer: list[int],\n",
"    action: int,\n",
"    rng: np.random.Generator,\n",
") -> tuple[tuple[list[int], list[int]], int, bool]:\n",
"    \"\"\"Take action from current state.\n",
"\n",
"    Args:\n",
"        player: list of player's cards\n",
"        dealer: list of dealer's cards\n",
"        action: HIT or STICK\n",
"        rng: random generator used to draw cards\n",
"\n",
"    Returns:\n",
"        next_state: (player's cards, dealer's cards)\n",
"        reward: -1 (lose), 0 (draw), +1 (win)\n",
"        done: True if episode ended\n",
"\n",
"    \"\"\"\n",
"    if action == HIT:\n",
"        player.append(\n",
"            draw_card(rng),\n",
"        )\n",
"\n",
"        if is_bust(player):\n",
"            return (\n",
"                (player, dealer),\n",
"                -1,\n",
"                True,\n",
"            )\n",
"        return (\n",
"            (player, dealer),\n",
"            0,\n",
"            False,\n",
"        )\n",
"\n",
"    while sum_hand(dealer) < 17:  # noqa: PLR2004\n",
"        dealer.append(draw_card(rng))\n",
"\n",
"    ps = score(player)\n",
"    ds = score(dealer)\n",
"\n",
"    if ds == 0 or ps > ds:\n",
"        reward = 1\n",
"    elif ps < ds:\n",
"        reward = -1\n",
"    else:\n",
"        reward = 0\n",
"\n",
"    return (player, dealer), reward, True"
]
},
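The end of `step` can be illustrated with a standalone sketch built on two hypothetical helpers, neither of which is part of the lab:

- `dealer_play` applies the same dealer rule, but draws from a fixed list of cards instead of `rng`.
- `terminal_reward` repeats the win/lose/draw comparison on scores as defined by `score` (0 means bust). It is only reached after the player sticks, so the player's score is positive there.

The dealer loop compares `sum_hand` against 17. As a result, the dealer in this environment stands on a soft 17 (for example, Ace + 6).

```python
def usable_ace(hand: list[int]) -> bool:
    """Copy of the lab's rule: an ace (1) can count as 11 without exceeding 21."""
    return (1 in hand) and (sum(hand) + 10 <= 21)


def sum_hand(hand: list[int]) -> int:
    """Copy of the lab's rule: best total, counting one usable ace as 11."""
    return sum(hand) + 10 if usable_ace(hand) else sum(hand)


def dealer_play(dealer: list[int], next_cards: list[int]) -> list[int]:
    """Hypothetical helper: the dealer rule from `step`, drawing from `next_cards`."""
    hand = list(dealer)
    cards = iter(next_cards)
    while sum_hand(hand) < 17:  # hit until the best total is at least 17
        hand.append(next(cards))
    return hand


def terminal_reward(player_score: int, dealer_score: int) -> int:
    """Hypothetical helper: the comparison at the end of `step` (0 means bust)."""
    if dealer_score == 0 or player_score > dealer_score:
        return 1  # dealer busts, or player has the higher total
    if player_score < dealer_score:
        return -1  # dealer has the higher total
    return 0  # equal totals: draw


print(dealer_play([10, 2], [3, 4]))  # [10, 2, 3, 4]: 12 -> 15 -> 19, then stop
print(dealer_play([1, 6], [10]))     # [1, 6]: soft 17, the dealer stands
print(terminal_reward(19, 18), terminal_reward(17, 0),
      terminal_reward(17, 20), terminal_reward(20, 20))  # 1 1 -1 0
```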
{
"cell_type": "markdown",
"id": "3b25f5f4-1c7d-49cc-ac32-cfe9849a0eb2",
"metadata": {},
"source": [
"-------------------\n",
"\n",
"## 3. First-visit Monte Carlo method to estimate state-value functions"
]
},
{
"cell_type": "markdown",
"id": "0efc1fe8-f848-4bdd-8568-5087438d1b9c",
"metadata": {},
"source": [
"### 3.1 Episode generation (core for Monte Carlo)\n",
"\n",
"Monte Carlo methods learn from complete episodes.\n",
"So we need a function that:\n",
"\n",
"1. starts a new game,\n",
"\n",
"2. follows a policy to generate $(S_t, A_t)$ pairs until terminal,\n",
"\n",
"3. returns the whole episode + terminal reward.\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "6c3ae6dd-98bb-4b92-b662-680b15adb185",
"metadata": {},
"source": [
"**Exercise 1.** Write a function `generate_episode` that simulates one complete Blackjack game (one episode) using an input policy.\n",
"The arguments of `generate_episode` are\n",
"1. `policy_fn` the input policy, which is a function that takes a state $s$ and returns an action: `HIT` (1) or `STICK` (0)\n",
"2. `rng`, which is a random generator for drawing cards (so results are reproducible).\n",
"\n",
"The outputs of this function `generate_episode` are\n",
"1. `episode`: which is a list like $[(S_0, A_0), (S_1, A_1), ..., (S_{T-1}, A_{T-1})]$\n",
"2. `G`, which is the terminal return (final reward) in $\\{-1, 0,1\\}$ (lose, draw, win) since in our simplified Blackjack: intermediate rewards are 0, only the final outcome matters. \n",
"\n",
"*Remark.* In Blackjack, not all states are equally important for decision-making. When the player's sum is below 12, hitting is always safe: such a hand cannot hold a usable ace (a usable ace already makes the sum at least 12), so its raw total is at most 11, and even the largest possible card (value 10) cannot make the player bust.\n",
"\n",
"As a result, there is no real choice to make in these states, and the player should always hit. For this reason, we simplify the problem by forcing the action HIT when the player's sum satisfies $p<12$, and only learning decisions for states with $p=12,\\dots,21$.\n",
"\n",
"This reduces the number of states we need to consider, speeds up learning, and does not change the optimal behavior."
]
},
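The sketch below double-checks the "hitting below 12 is safe" claim by brute force. It reuses copies of the lab's `usable_ace` and `sum_hand`. The enumeration is limited to hands of 2 to 5 cards, a finite check chosen to keep it fast.

```python
from itertools import product


def usable_ace(hand: list[int]) -> bool:
    """Copy of the lab's rule: an ace (1) can count as 11 without exceeding 21."""
    return (1 in hand) and (sum(hand) + 10 <= 21)


def sum_hand(hand: list[int]) -> int:
    """Copy of the lab's rule: add 10 (at most once) when there is a usable ace."""
    return sum(hand) + 10 if usable_ace(hand) else sum(hand)


# Check every hand of 2 to 5 cards whose best total is below 12:
# one more card (any value 1..10) must never push the best total above 21.
unsafe = []
for n_cards in range(2, 6):
    for hand in product(range(1, 11), repeat=n_cards):
        if sum_hand(list(hand)) < 12:
            for card in range(1, 11):
                if sum_hand([*hand, card]) > 21:
                    unsafe.append((hand, card))

print(len(unsafe))  # 0
```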
{
"cell_type": "code",
"execution_count": 228,
"id": "cf8415f3-97c6-4d83-b42d-4743c770abbf",
"metadata": {},
"outputs": [],
"source": [
"from collections.abc import Callable\n",
"\n",
"\n",
"def generate_episode(\n",
"    policy_fn: Callable[[tuple[int, int, int]], int],\n",
"    rng: np.random.Generator,\n",
") -> tuple[list[tuple[tuple[int, int, int], int]], int]:\n",
"    \"\"\"Generate one full episode following a given policy.\n",
"\n",
"    Hint:\n",
"        - An episode is a complete Blackjack game.\n",
"        - You should repeatedly:\n",
"            1) observe the current state,\n",
"            2) choose an action using the policy,\n",
"            3) apply the action to the environment,\n",
"            4) stop when the game ends.\n",
"\n",
"    Returns:\n",
"        episode: list of (state, action) pairs\n",
"        G: terminal return (final reward), in {-1, 0, 1}\n",
"\n",
"    \"\"\"\n",
"    player, dealer = initial_hands(rng)\n",
"\n",
"    episode = []\n",
"\n",
"    while True:\n",
"        s = obs(player, dealer)\n",
"\n",
"        a = HIT if s[0] < 12 else policy_fn(s)  # noqa: PLR2004\n",
"\n",
"        episode.append((s, a))\n",
"        (player, dealer), r, done = step(player, dealer, a, rng)\n",
"\n",
"        if done:\n",
"            return episode, r"
]
},
{
"cell_type": "markdown",
"id": "ddfb1a92-e71a-4a33-a225-695568da4e7b",
"metadata": {},
"source": [
"Test your `generate_episode` function with the following `always_hit_policy` :)."
]
},
{
"cell_type": "code",
"execution_count": 229,
"id": "4c12bb7a-76ae-4659-9064-8745ca89eef9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode length: 3\n",
"Terminal return G: -1\n",
"First 3 steps: [((16, 6, 0), 1), ((20, 6, 0), 1), ((21, 6, 0), 1)]\n"
]
}
],
"source": [
"def always_hit_policy(_state: tuple[int, int, int]) -> int:\n",
"    \"\"\"Ignore the state and always return HIT.\"\"\"\n",
"    return HIT\n",
"\n",
"\n",
"episode, G = generate_episode(always_hit_policy, rng)\n",
"\n",
"print(\"Episode length:\", len(episode))\n",
"print(\"Terminal return G:\", G)\n",
"print(\"First 3 steps:\", episode[:3])\n"
]
},
{
"cell_type": "markdown",
"id": "d4072bae-d975-4637-88cc-6a466a49cb76",
"metadata": {},
"source": [
"### 3.2 Estimating the state-value function of a fixed policy $\\pi$ with Monte Carlo\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "643165a2-c569-4506-9dc5-43355e599c92",
"metadata": {},
"source": [
"In this section, we estimate the **state-value function** of a **fixed policy** using **first-visit Monte Carlo methods**.\n",
"\n",
"The goal is to approximate, for each state $s$,\n",
"$$\n",
"V_\\pi(s) = \\mathbb{E}_\\pi[G \\mid S_t = s],\n",
"$$\n",
"that is, the expected final reward obtained when starting from state $s$ and following the policy $\\pi$.\n",
"\n",
"**Key ideas**\n",
"\n",
"- We generate many **complete episodes** (full Blackjack games) by following a fixed policy.\n",
"- Each episode ends with a **terminal reward** $G \\in \\{-1, 0, 1\\}$.\n",
"- For a given state $s$, each observed return $G$ following a visit to $s$ is a **sample** whose expectation is $V_\\pi(s)$.\n",
"- In **first-visit Monte Carlo**, each state is updated **only the first time it appears in an episode**.\n",
"- The value $V_{\\pi}(s)$ is estimated by the **average of all observed returns** from first visits to $s$.\n",
"\n",
"**Important remarks** \n",
"\n",
"- This method does not require transition probabilities.\n",
"- Updates are performed **only after an episode terminates**. \n",
"- In our Blackjack setting, all intermediate rewards are zero, so the return from every time step equals the terminal reward. The usual backward pass, which accumulates the return from the end of the episode, therefore brings no benefit here; a simpler forward pass is equivalent.\n"
]
},
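Both remarks can be illustrated with a minimal sketch on toy data that does not come from the Blackjack environment. The two function names are hypothetical:

- `returns_backward` computes $G_t = R_{t+1} + \gamma G_{t+1}$ starting from the end of an episode. Here $\gamma = 1$ (no discounting), an assumption that matches the lab's use of the terminal reward as the return.
- `first_visit_average` shows that a state contributes at most one return per episode.

```python
from collections import defaultdict


def returns_backward(rewards: list[float], gamma: float = 1.0) -> list[float]:
    """G_t = r_{t+1} + gamma * G_{t+1}, computed from the last step backward."""
    G = 0.0
    out = []
    for r in reversed(rewards):
        G = r + gamma * G
        out.append(G)
    return out[::-1]


# Blackjack-like reward sequence: 0, 0, then +1 at the end.
print(returns_backward([0, 0, 1]))  # [1.0, 1.0, 1.0]: every step gets the terminal reward


def first_visit_average(episodes: list[tuple[list[str], float]]) -> dict[str, float]:
    """Average returns, counting each state at most once per episode (first visit)."""
    total, count = defaultdict(float), defaultdict(int)
    for states, G in episodes:
        seen = set()
        for s in states:
            if s in seen:
                continue  # later visits in the same episode are ignored
            seen.add(s)
            total[s] += G
            count[s] += 1
    return {s: total[s] / count[s] for s in total}


# Toy states "A" and "B" (not Blackjack states), with "A" visited twice in episode 1.
toy = [(["A", "B", "A"], 1), (["A"], -1)]
print(first_visit_average(toy))  # {'A': 0.0, 'B': 1.0}
```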
{
"cell_type": "markdown",
"id": "0b782f88-b21e-4bd7-a4a8-ad5e57923a38",
"metadata": {},
"source": [
"**Exercise 2.** Define a `mc_prediction_first_visit()` function that estimates the state-value function of a given fixed policy $\\pi$.\n",
"\n",
"Recall that in the simplified Blackjack setting used in this lab, rewards are 0 during the game and only the terminal reward is nonzero, so the return $G$ is just the final outcome in $\\{-1, 0, 1\\}$. Monte Carlo prediction approximates the expectation $V_\\pi(s)$ by averaging returns from many episodes.\n",
"\n",
"The Args of `mc_prediction_first_visit()` are \n",
"1. `num_episodes`: Number of complete Blackjack games (episodes) to simulate. Each episode is one full game, from the initial deal until the game ends. Monte Carlo methods estimate expectations by averaging over many episodes.\n",
"   A larger `num_episodes` gives more accurate estimates of the value function, at a higher computational cost.\n",
"\n",
"2. `rng`: random number generator, which is the source of randomness used to simulate the environment.\n",
"\n",
"The output of `mc_prediction_first_visit()` is \n",
"\n",
"`V` a dictionary mapping each state $s$ to an estimate of $V_{\\pi}(s)$. \n",
"\n",
"An example of a `dict`:"
]
},
{
"cell_type": "code",
"execution_count": 230,
"id": "14d6496d-b6fc-4130-ae59-de35b305b0a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15\n"
]
}
],
"source": [
"grades = {\"Alice\": 15, \"Bob\": 12, \"Charlie\": 18}\n",
"print(grades[\"Alice\"]) # returns 15"
]
},
{
"cell_type": "code",
"execution_count": 231,
"id": "8438b2e6",
"metadata": {},
"outputs": [],
"source": [
"def target_policy(state: tuple[int, int, int]) -> int:\n",
"    \"\"\"Deterministic policy: HIT if player's sum < 20 else STICK.\"\"\"\n",
"    p, _d, _u = state\n",
"    return HIT if p < 20 else STICK  # noqa: PLR2004\n"
]
},
{
"cell_type": "code",
"execution_count": 232,
"id": "522a0b55-972e-410e-87dd-8a9fee7cf8be",
"metadata": {},
"outputs": [],
"source": [
"def mc_prediction_first_visit(\n",
"    num_episodes: int = 200_000,\n",
"    rng: np.random.Generator = rng,\n",
") -> dict:\n",
"    \"\"\"First-visit Monte Carlo prediction for a fixed policy.\n",
"\n",
"    Goal:\n",
"        - Estimate the state-value function V(s) of the fixed policy `target_policy`.\n",
"        - Use first-visit Monte Carlo: each state is updated at most once per episode.\n",
"\n",
"    Args:\n",
"        num_episodes: number of episodes to sample\n",
"        rng: random generator used to draw cards\n",
"\n",
"    Returns:\n",
"        V: dict mapping state -> estimated value\n",
"\n",
"    \"\"\"\n",
"    V = defaultdict(float)\n",
"    returns_sum = defaultdict(float)\n",
"    returns_count = defaultdict(int)\n",
"\n",
"    def pi(s: tuple[int, int, int]) -> int:\n",
"        return target_policy(s)\n",
"\n",
"    for _ in range(num_episodes):\n",
"        episode, G = generate_episode(pi, rng)\n",
"        visited_states = set()\n",
"\n",
"        for s, _a in episode:\n",
"            p, _d, _u = s\n",
"\n",
"            if p < 12:  # noqa: PLR2004\n",
"                continue\n",
"\n",
"            if s not in visited_states:\n",
"                visited_states.add(s)\n",
"                returns_sum[s] += G\n",
"                returns_count[s] += 1\n",
"                V[s] = returns_sum[s] / returns_count[s]\n",
"    return V\n"
]
},
{
"cell_type": "markdown",
"id": "23327cd4-88b2-4f84-90d7-80a45308f99e",
"metadata": {},
"source": [
"Now we check if `mc_prediction_first_visit` is correctly defined. "
]
},
{
"cell_type": "code",
"execution_count": 233,
"id": "d81a2dd1-27ab-4b62-8b52-e0dc156b96c5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'collections.defaultdict'> 200\n",
"0.6234567901234568\n"
]
}
],
"source": [
"V = mc_prediction_first_visit(num_episodes=10000, rng=rng)\n",
"print(type(V), len(V))\n",
"\n",
"s_test = (20, 10, 0)\n",
"\n",
"print(V[s_test])"
]
},
{
"cell_type": "markdown",
"id": "b0ea21a2-29e0-4657-a502-ef7791bfa8cb",
"metadata": {},
"source": [
"--------------------------------\n",
"\n",
"## 4. Apply Monte Carlo control to learn an optimal policy for Blackjack (with exploring starts)."
]
},
{
"cell_type": "markdown",
"id": "f0c93491-1b67-4599-84d8-ae6cffd3a126",
"metadata": {},
"source": [
"In this section, we move from **policy evaluation** to **policy control**.\n",
"\n",
"Our goal is no longer to evaluate a fixed policy, but to **learn an (approximately) optimal policy** for Blackjack directly from experience. To do this, we use **Monte Carlo control with exploring starts**.\n",
"\n",
"### Key ideas\n",
"\n",
"- We estimate the **action-value function** $Q(s,a)$, which represents the expected final reward when taking action $a$ in state $s$ and then following the current policy.\n",
"- We generate many complete episodes (full Blackjack games).\n",
"- After each episode, we update $Q(s,a)$ using the observed return.\n",
"- We then **improve the policy greedily** by choosing, in each state, the action with the highest estimated $Q$-value.\n",
"- **Exploring starts** ensure that different state-action pairs are tried, by forcing the first action of each episode to be random.\n",
"\n",
"This approach alternates between **evaluation** (updating $Q$) and **improvement** (updating the policy), gradually converging toward an optimal strategy.\n",
"\n",
"> **This exercise is more challenging than the previous ones.** \n",
"> It combines several ideas at once (episodes, first-visit Monte Carlo, action-values, and policy improvement). \n",
"> **Do not be discouraged if it does not work immediately**: take it step by step, follow the hints, and focus on understanding the logic rather than writing everything at once.\n"
]
},
{
"cell_type": "markdown",
"id": "bd3b5114-1299-4aad-a73f-0ca608c58d00",
"metadata": {},
"source": [
"**Exercise 3 ($\\star$).** Write a `mc_control_exploring_starts` function that learns an (approximately) optimal policy for Blackjack by iterating:\n",
"\n",
"1. Generate an episode using the current policy (with exploring starts)\n",
"2. Compute the return $G$ (here it's just the terminal reward)\n",
"3. Update the action-value function $Q(s,a)$ using sample averages\n",
"4. Improve the policy to be greedy w.r.t. the updated $Q$.\n",
"\n",
"Remark. This is Monte Carlo control (not just prediction), because we are improving the policy while learning.\n",
"\n",
"The Args of this `mc_control_exploring_starts` function are \n",
"1. `num_episodes`: how many games to simulate (more = better learning, slower)\n",
"2. `rng`: random generator for reproducibility\n",
" \n",
"The outputs of this `mc_control_exploring_starts` function are \n",
"1. `Q`: mapping from (state, action) to estimated action-value\n",
" $$\n",
" Q(s,a) \\simeq \\mathbb{E}[G | S_t=s, A_t=a]\n",
" $$\n",
"2. `policy`: mapping from state to greedy action\n",
" $$\n",
" \\pi(s)=\\text{argmax}_a Q(s,a).\n",
" $$"
]
},
{
"cell_type": "markdown",
"id": "c4a77754-5e3e-4f3a-8a5c-ff1d316ad591",
"metadata": {},
"source": [
"**Hints** In the following code, \n",
"- `policy.get(s, HIT)` means: return `policy[s]` if it exists, otherwise return `HIT`.\n",
"- `visited_sa` enforces *first-visit*: each (state, action) pair is updated only once per episode.\n",
"- Sample-average update: (increment update)\n",
" $$ Q(s,a) \\leftarrow Q(s,a) + \\frac{1}{N(s,a)}(G - Q(s,a)) $$\n",
"- Greedy improvement: \n",
" $$ \\pi(s) = \\arg\\max_{a \\in \\{HIT,STICK\\}} Q(s,a) $$\n"
]
},
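The incremental update in the hints computes exactly the sample average of the returns observed for a pair $(s,a)$. This sketch checks that numerically. It uses synthetic returns drawn from $\{-1, 0, 1\}$, which are made-up data rather than Blackjack episodes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic terminal returns for a single (s, a) pair (made-up data).
returns = rng.choice([-1, 0, 1], size=1_000)

q, n = 0.0, 0
for G in returns:
    n += 1
    q += (G - q) / n  # Q <- Q + (G - Q) / N(s, a)

print(q, returns.mean())  # equal up to floating-point rounding
```

The incremental form avoids storing a separate running sum per pair. Only the count `N(s, a)` and the current estimate are kept, which is what `Nsa` and `Q` do in the exercise.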
{
"cell_type": "code",
"execution_count": 234,
"id": "89bdd3ec",
"metadata": {},
"outputs": [],
"source": [
"def mc_control_exploring_starts(\n",
"    num_episodes: int = 500_000,\n",
"    rng: np.random.Generator = rng,\n",
") -> tuple[dict, dict]:\n",
"    \"\"\"Monte Carlo Control with Exploring Starts (first-visit).\n",
"\n",
"    Goal:\n",
"        - Learn an (approximately) optimal deterministic policy for Blackjack.\n",
"        - We learn Q(s,a) by averaging returns from complete episodes.\n",
"        - We improve the policy by making it greedy w.r.t. Q.\n",
"\n",
"    Exploring starts:\n",
"        - Each episode starts from a random initial deal (random state).\n",
"        - The FIRST action is chosen randomly (HIT or STICK) so that both actions\n",
"          can be tried from many starting states.\n",
"\n",
"    Returns:\n",
"        Q: dict mapping (state, action) -> action-value estimate\n",
"        policy: dict mapping state -> greedy action\n",
"\n",
"    \"\"\"\n",
"    Q = defaultdict(float)\n",
"    Nsa = defaultdict(int)\n",
"    policy = {}\n",
"\n",
"    for _ in range(num_episodes):\n",
"        player, dealer = initial_hands(rng)\n",
"\n",
"        episode = []\n",
"\n",
"        s0 = obs(player, dealer)\n",
"        a0 = int(rng.integers(0, 2))\n",
"\n",
"        episode.append((s0, a0))\n",
"\n",
"        (player, dealer), r, done = step(player, dealer, a0, rng)\n",
"\n",
"        if done:\n",
"            G = r\n",
"        else:\n",
"            while True:\n",
"                s = obs(player, dealer)\n",
"\n",
"                a = HIT if s[0] < 12 else policy.get(s, HIT)  # noqa: PLR2004\n",
"\n",
"                episode.append((s, a))\n",
"                (player, dealer), r, done = step(player, dealer, a, rng)\n",
"\n",
"                if done:\n",
"                    G = r\n",
"                    break\n",
"\n",
"        visited_sa = set()\n",
"\n",
"        for s, a in episode:\n",
"            if s[0] < 12:  # noqa: PLR2004\n",
"                continue\n",
"\n",
"            if (s, a) in visited_sa:\n",
"                continue\n",
"\n",
"            visited_sa.add((s, a))\n",
"            Nsa[(s, a)] += 1\n",
"            Q[(s, a)] += (G - Q[(s, a)]) / Nsa[(s, a)]\n",
"\n",
"            q_stick = Q[(s, STICK)]\n",
"            q_hit = Q[(s, HIT)]\n",
"            policy[s] = HIT if q_hit >= q_stick else STICK\n",
"\n",
"    return Q, policy\n"
]
},
{
"cell_type": "code",
"execution_count": 235,
"id": "9d68d504-7596-47e4-9d72-99ea29eb3950",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"1\n",
"3\n",
"4\n"
]
}
],
"source": [
"# A \"continue\" example:\n",
"\n",
"for i in range(5):\n",
"    if i == 2:  # noqa: PLR2004\n",
"        continue\n",
"    print(i)"
]
},
{
"cell_type": "code",
"execution_count": 236,
"id": "c381ec6a-61bd-4e51-bd95-ed7b68829e52",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20, 10, 0) -> STICK\n",
"(13, 2, 0) -> HIT\n",
"(18, 6, 1) -> HIT\n"
]
}
],
"source": [
"# Check your mc_control_exploring_starts function.\n",
"\n",
"Q, pi_star = mc_control_exploring_starts(num_episodes=50000, rng=rng)\n",
"\n",
"# Check a few states\n",
"test_states = [(20, 10, 0), (13, 2, 0), (18, 6, 1)]\n",
"for s in test_states:\n",
"    a = pi_star.get(s, None)  # pi_star is a dictionary mapping states to actions.\n",
"    # .get(key, default) returns: pi_star[key] if key exists, otherwise default, here it is None.\n",
"    print(s, \"->\", \"HIT\" if a == HIT else \"STICK\" if a == STICK else \"unseen\")\n"
]
},
{
"cell_type": "markdown",
"id": "70fd45c8-d006-42b7-8c06-383b6fba369e",
"metadata": {},
"source": [
"-----------------------------------------------\n",
"\n",
"## 5. Address the exploration problem with $\\varepsilon$-soft policies (instead of exploring starts).\n"
]
},
{
"cell_type": "markdown",
"id": "ac79dafc-3c95-4a52-b786-41cdd2704cf1",
"metadata": {},
"source": [
"\n",
"In the previous section, we used **exploring starts** to ensure sufficient exploration when learning an optimal policy with Monte Carlo control. While exploring starts are useful for theoretical analysis, they rely on the ability to **randomly choose the initial state-action pair**, which is often unrealistic in practice.\n",
"\n",
"In this section, we address the exploration problem in a more practical way by using **$\\varepsilon$-soft ($\\varepsilon$-greedy) policies**.\n",
"\n",
"**Key ideas**\n",
"\n",
"- An **$\\varepsilon$-soft policy** assigns a nonzero probability to all actions in every state.\n",
"- With probability $1 - \\varepsilon$, the policy chooses the **greedy action** (exploitation).\n",
"- With probability $\\varepsilon$, the policy chooses a **random action** (exploration).\n",
"- (Optional) By gradually decreasing $\\varepsilon$, the policy shifts from exploration to exploitation as learning progresses.\n",
"\n",
"This approach allows the agent to explore **throughout the episode**, without requiring control over the initial state or action.\n",
"\n",
"**What you will implement**\n",
"\n",
"In the following exercises, you will:\n",
"\n",
"- Implement an **$\\varepsilon$-greedy action selection rule**.\n",
"- Use it to construct an **on-policy Monte Carlo control algorithm**.\n",
"\n",
"\n",
"*This section is conceptually challenging. Do not be discouraged if it takes time to understand how exploration and learning interact.*\n",
"\n",
"\n",
"By the end of this section, you will have implemented a fully **on-policy Monte Carlo control method** that does not rely on exploring starts and is closer to what is used in real-world reinforcement learning systems.\n",
"\n",
"\n",
"\n"
]
},
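The optional decay of $\varepsilon$ is not specified further in this lab. One possible choice, shown here as a hypothetical sketch, is a geometric decay with a floor `eps_min`. The floor keeps every action's probability strictly positive, so the policy remains $\varepsilon$-soft. Schedules that decay all the way to 0, such as $\varepsilon_k = 1/k$, are another option. The function name and default values are assumptions.

```python
def epsilon_schedule(
    k: int,
    eps_start: float = 1.0,
    eps_min: float = 0.05,
    decay: float = 0.9999,
) -> float:
    """Hypothetical schedule: epsilon after k episodes, decayed geometrically, floored at eps_min."""
    return max(eps_min, eps_start * decay**k)


# Exploration is high at the start and settles at the floor eps_min later on.
for k in [0, 1_000, 10_000, 100_000]:
    print(k, round(epsilon_schedule(k), 4))
```

In a control loop, the value returned for the current episode index would be passed as `epsilon` to the action-selection function defined in Exercise 4.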
{
"cell_type": "markdown",
"id": "02b0b0e1-9922-4732-b2af-02da3c7fa3a2",
"metadata": {},
"source": [
"**Exercise 4.** In this exercise, you will define the **$\\varepsilon$-greedy policy** used for on-policy Monte Carlo control.\n",
"\n",
"Recall that the $\\varepsilon$-greedy strategy balances:\n",
"- **exploration**: trying random actions to discover new information,\n",
"- **exploitation**: choosing the action that currently looks best w.r.t. `Q`.\n",
"\n",
"More precisely, given a state $s$:\n",
"\n",
"- with probability $\\varepsilon$, choose a random action (HIT or STICK);\n",
"- with probability $1-\\varepsilon$, choose the greedy action $\\arg\\max_a Q(s,a)$.\n",
"\n",
"**Tasks**\n",
"\n",
"1. Complete the function `epsilon_greedy(Q, s, epsilon, rng)`.\n",
"2. Use `rng.random()` to decide whether to explore or exploit.\n",
"3. When exploiting, compare `Q[(s, HIT)]` and `Q[(s, STICK)]` and choose the action with the larger value.\n",
"\n",
"*Hint*\n",
"\n",
"- In Blackjack, the action space contains only two actions: `HIT` and `STICK`.\n",
"- `Q` is a `defaultdict(float)`, so state–action pairs that have not been visited yet default to value 0.\n"
]
},
{
"cell_type": "code",
"execution_count": 237,
"id": "4e876128-1c6e-4c28-874d-8cb68f422d21",
"metadata": {},
"outputs": [],
"source": [
"def epsilon_greedy(Q: dict, s: tuple, epsilon: float, rng: np.random.Generator) -> int:\n",
"    \"\"\"Epsilon-greedy action selection for Blackjack.\n",
"\n",
"    With probability epsilon:\n",
"        - choose a random action (exploration)\n",
"    With probability 1 - epsilon:\n",
"        - choose the greedy action argmax_a Q(s,a) (exploitation)\n",
"\n",
"    Args:\n",
"        Q: dict mapping (state, action) -> action-value estimate\n",
"        s: current state (p, d, u)\n",
"        epsilon: exploration probability in [0,1]\n",
"        rng: random number generator\n",
"\n",
"    Returns:\n",
"        action: HIT or STICK\n",
"\n",
"    \"\"\"\n",
"    # Explore: draw one of the two actions uniformly at random\n",
"    # (relies on HIT and STICK being encoded as the integers 0 and 1).\n",
"    if rng.random() < epsilon:\n",
"        return int(rng.integers(0, 2))\n",
"\n",
"    # Exploit: greedy action w.r.t. Q; ties are broken in favour of HIT.\n",
"    q_stick = Q[(s, STICK)]\n",
"    q_hit = Q[(s, HIT)]\n",
"    return HIT if q_hit >= q_stick else STICK"
]
},
{
"cell_type": "markdown",
"id": "ec7a405e-36c7-4d2e-940d-3412ec68ea7c",
"metadata": {},
"source": [
"**Exercise 5.** Define a function `mc_control_epsilon_soft` that learns an (approximately) optimal Blackjack policy by repeating:\n",
"\n",
"1. Generate an episode using an $\\varepsilon$-greedy policy that keeps exploring.\n",
"2. Update `Q(s,a)` using the episode return `G` with first-visit MC.\n",
"3. As learning progresses, $\\varepsilon$ decreases so the policy becomes more greedy.\n",
"\n",
"*Remark.* Unlike exploring starts, this approach does not require forcing the initial $(s,a)$. Exploration comes from the $\\varepsilon$-greedy actions taken throughout the episode. This is **on-policy** control: the policy that is evaluated and improved is the same policy that generates the data.\n",
"\n",
"The arguments of `mc_control_epsilon_soft` are:\n",
"1. `num_episodes`: how many complete games we simulate (more = better learning, slower)\n",
"2. `N0`: a constant controlling how much exploration we do early on and how fast it decays (see the *Hint.* below).\n",
"3. `rng`: random generator for reproducibility\n",
"\n",
"The outputs of `mc_control_epsilon_soft` are:\n",
"1. `Q`: mapping from (state, action) to estimated action-value\n",
"   $$\n",
"   Q(s,a) \\approx \\mathbb{E}_\\pi[G \\mid S_t=s, A_t=a]\n",
"   $$\n",
"2. `policy`: the final greedy policy derived from `Q`."
]
},
{
"cell_type": "markdown",
"id": "62819c64-d022-4c98-a0f4-ac7ae9a22dd6",
"metadata": {},
"source": [
"*Hint.* We use the following classical exploration schedule $\\varepsilon(s)$ for a state $s$:\n",
"$$\n",
"\\varepsilon(s) = \\frac{N_0}{N_0 + N(s)},\n",
"$$\n",
"where:\n",
"- $N_0 > 0$ is a fixed constant (it is a parameter),\n",
"- $N(s)$ is the number of times state $s$ has been visited so far.\n",
"\n",
"Why this definition? This choice of $\\varepsilon(s)$ has several important properties:\n",
"\n",
"- *Pure exploration at first*: If a state has never been visited, $N(s) = 0$, then $\\varepsilon(s) = \\frac{N_0}{N_0} = 1,$ so the agent chooses actions completely at random.\n",
"\n",
"- *Gradual reduction of exploration*: As $N(s)$ increases, $\\varepsilon(s)$ decreases towards 0, and the policy becomes increasingly greedy.\n",
"\n",
"- *Controlled exploration*: The parameter $N_0$ determines how long the agent keeps exploring:\n",
"  - a larger $N_0$ means **more exploration for longer**,\n",
"  - a smaller $N_0$ means the policy becomes greedy **more quickly**.\n"
]
},
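{
"cell_type": "markdown",
"id": "sketch-epsilon-schedule",
"metadata": {},
"source": [
"*Illustration (not part of the exercises).* This short sketch evaluates the schedule $\\varepsilon(s) = N_0 / (N_0 + N(s))$ for a few visit counts. `epsilon_schedule` is a hypothetical helper name, and the counts are arbitrary examples.\n",
"\n",
"```python\n",
"def epsilon_schedule(n_visits: int, N0: float = 100) -> float:\n",
"    \"\"\"Exploration rate epsilon(s) = N0 / (N0 + N(s)) after n_visits visits to s.\"\"\"\n",
"    return N0 / (N0 + n_visits)\n",
"\n",
"\n",
"for n in [0, 100, 900, 9_900]:\n",
"    print(n, epsilon_schedule(n))\n",
"# 0 1.0\n",
"# 100 0.5\n",
"# 900 0.1\n",
"# 9900 0.01\n",
"\n",
"# A smaller N0 reaches the same exploration rate much sooner.\n",
"print(epsilon_schedule(90, N0=10))  # 0.1\n",
"```\n",
"\n",
"With $N_0 = 100$, exploration falls to 10% only after 900 visits to a state. With $N_0 = 10$, it reaches 10% after 90 visits."
]
},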
{
"cell_type": "code",
"execution_count": 238,
"id": "7521a88d-eed5-4415-8f61-ac4c395aff5e",
"metadata": {},
"outputs": [],
"source": [
"def mc_control_epsilon_soft(\n",
"    num_episodes: int = 500_000,\n",
"    N0: int = 100,\n",
"    rng: np.random.Generator = rng,\n",
") -> tuple[dict, dict]:\n",
"    \"\"\"On-policy first-visit Monte Carlo control using epsilon-greedy exploration.\n",
"\n",
"    Args:\n",
"        num_episodes: number of simulated episodes (games)\n",
"        N0: controls exploration via epsilon(s) = N0 / (N0 + N(s))\n",
"        rng: random generator\n",
"\n",
"    Returns:\n",
"        Q: dict mapping (state, action) -> action-value estimate\n",
"        policy: dict mapping state -> greedy action (HIT or STICK)\n",
"\n",
"    \"\"\"\n",
"    Q = defaultdict(float)\n",
"    Nsa = defaultdict(int)  # first-visit counts N(s, a)\n",
"\n",
"    # Q and Nsa are updated in place, so the behavior policy can be defined\n",
"    # once, outside the episode loop, and always sees the latest estimates.\n",
"    def behavior(s: tuple) -> int:\n",
"        p, _d, _u = s\n",
"\n",
"        # With a sum below 12, another card cannot bust the player: always HIT.\n",
"        if p < 12:  # noqa: PLR2004\n",
"            return HIT\n",
"\n",
"        # N(s) = sum_a N(s, a), then epsilon(s) = N0 / (N0 + N(s)).\n",
"        Ns = sum(Nsa[(s, a)] for a in ACTIONS)\n",
"        epsilon = N0 / (N0 + Ns)\n",
"\n",
"        return epsilon_greedy(Q, s, epsilon, rng)\n",
"\n",
"    for _ in range(num_episodes):\n",
"        episode, G = generate_episode(behavior, rng)\n",
"\n",
"        visited_sa = set()\n",
"\n",
"        for s, a in episode:\n",
"            p, _d, _u = s\n",
"\n",
"            if p < 12:  # noqa: PLR2004\n",
"                continue\n",
"\n",
"            # First-visit MC: update each (s, a) at most once per episode.\n",
"            if (s, a) in visited_sa:\n",
"                continue\n",
"\n",
"            visited_sa.add((s, a))\n",
"\n",
"            # Incremental sample mean of the returns observed after (s, a).\n",
"            Nsa[(s, a)] += 1\n",
"            Q[(s, a)] += (G - Q[(s, a)]) / Nsa[(s, a)]\n",
"\n",
"    # Final greedy policy w.r.t. Q over all states with player sum 12..21.\n",
"    policy = {}\n",
"\n",
"    for p in range(12, 22):\n",
"        for d in range(1, 11):\n",
"            for u in [0, 1]:\n",
"                s = (p, d, u)\n",
"\n",
"                q_stick = Q[(s, STICK)]\n",
"                q_hit = Q[(s, HIT)]\n",
"\n",
"                policy[s] = STICK if q_stick > q_hit else HIT\n",
"\n",
"    return Q, policy\n"
]
},
{
"cell_type": "code",
"execution_count": 239,
"id": "fed938c9-3dbf-4f7c-865d-2a21206c6ff8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Q range: [min, max] = (-1.0, 0.9687500000000001)\n"
]
}
],
"source": [
"# Quick checks for mc_control_epsilon_soft\n",
"\n",
"# Run a small training to check the function executes without errors\n",
"Q, pi_star = mc_control_epsilon_soft(num_episodes=20_000, N0=100, rng=rng)\n",
"\n",
"# Check that Q values are within the valid range [-1, 1]\n",
"# Explanation:\n",
"# In this Blackjack environment, the terminal reward is in {-1,0,1}.\n",
"# Therefore, any expected return Q(s,a) must lie between -1 and 1.\n",
"q_vals = list(Q.values())\n",
"print(\"Q range: [min, max] =\", (min(q_vals), max(q_vals)))\n",
"\n",
"assert min(q_vals) >= -1.01 and max(q_vals) <= 1.01, \"Q-values out of expected range!\" # noqa: PLR2004, PT018, S101\n"
]
},
{
"cell_type": "markdown",
"id": "f3fd5d7d-bd4c-4557-bf4f-09710844c5d1",
"metadata": {},
"source": [
"--------------------\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "7137801f-5d5d-44f5-9418-0996bbd7410e",
"metadata": {},
"source": [
"## 6. Off-Policy Monte Carlo Learning and Importance Sampling\n",
"\n",
"So far, all the algorithms we have studied were **on-policy**: the policy used to generate data (the *behavior policy*) was the same as the policy being evaluated or improved.\n",
"\n",
"In many real-world situations, however, this assumption is unrealistic. We often want to:\n",
"- evaluate or improve a **target policy**,\n",
"- using data generated by a **different behavior policy**.\n",
"\n",
"This setting is known as **off-policy learning**.\n",
"\n",
"**Key idea: separating behavior and target policies**\n",
"\n",
"- The **behavior policy** determines how data (episodes) are collected.\n",
"- The **target policy** is the policy we want to evaluate or optimize.\n",
"- In off-policy learning, these two policies are not the same.\n",
"\n",
"Because the data are generated under a different distribution than the one induced by the target policy, we must correct this mismatch.\n",
"\n",
"**Importance sampling**\n",
"\n",
"To address this issue, we use **importance sampling**, a classical statistical technique that allows expectations under one distribution to be estimated using samples from another distribution.\n",
"\n",
"In this section, you will study:\n",
"- **Ordinary importance sampling**, which is unbiased but can have high variance;\n",
"- **Weighted importance sampling**, which reduces variance at the cost of a small bias that vanishes as the number of episodes grows.\n",
"\n",
"Both methods play a central role in off-policy Monte Carlo learning.\n",
"\n",
"**What you will learn**\n",
"\n",
"In the exercises that follow, you will:\n",
"- distinguish clearly between **on-policy** and **off-policy** learning,\n",
"- evaluate a target policy using data generated by a different behavior policy,\n",
"- implement **ordinary** and **weighted importance sampling**,\n",
"- observe the bias–variance trade-off inherent in off-policy Monte Carlo methods.\n",
"\n",
"> Off-policy learning requires careful reasoning about probability distributions and can be subtle at first. \n",
"> Take your time, follow the derivations step by step, and focus on understanding *why* importance sampling is needed before worrying about implementation details.\n"
]
},
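{
"cell_type": "markdown",
"id": "sketch-importance-sampling-toy",
"metadata": {},
"source": [
"*Illustration (not part of the exercises).* Before applying importance sampling to Blackjack, here is a minimal sketch on a toy discrete random variable. The values `x` and the distributions `p` and `q` are hypothetical. Samples are drawn from the behavior distribution $q$, and each sample is reweighted by $p(x)/q(x)$ to estimate the expectation under the target distribution $p$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng_demo = np.random.default_rng(0)\n",
"\n",
"x = np.array([-1.0, 0.0, 1.0])       # possible outcomes\n",
"p = np.array([0.2, 0.3, 0.5])        # target distribution: E_p[X] = 0.3\n",
"q = np.array([1 / 3, 1 / 3, 1 / 3])  # behavior distribution (uniform)\n",
"\n",
"idx = rng_demo.choice(len(x), size=100_000, p=q)  # sample outcome indices from q\n",
"w = p[idx] / q[idx]                                # importance ratios p/q\n",
"g = x[idx]                                         # observed values\n",
"\n",
"ordinary = np.mean(w * g)             # (1/N) * sum_i w_i g_i\n",
"weighted = np.sum(w * g) / np.sum(w)  # sum_i w_i g_i / sum_i w_i\n",
"print(ordinary, weighted)  # both should be close to 0.3\n",
"```\n",
"\n",
"Both estimators recover $\\mathbb{E}_p[X] = 0.3$ even though no sample was drawn from $p$. The rest of this section applies the same idea to whole Blackjack trajectories."
]
},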
{
"cell_type": "markdown",
"id": "26b1ad91-d9dd-494b-be85-865752a9649f",
"metadata": {},
"source": [
"**Exercise 6.** (Target Policy and Behavior Policy)\n",
"\n",
"In off-policy learning, we distinguish between two different policies:\n",
"\n",
"- the **target policy** $\\pi$, which is the policy we want to evaluate or improve;\n",
"- the **behavior policy** $b$, which is the policy used to generate data (episodes).\n",
"\n",
"In this exercise, we explicitly define both policies and study their roles.\n",
"\n",
"**Target policy**: The target policy is **deterministic** and represents the strategy we want to evaluate. It is defined as:\n",
"\n",
"- **STICK** if the player's sum $p \\ge 20$;\n",
"- **HIT** otherwise.\n",
"\n",
"**Behavior policy**: The behavior policy is $\\varepsilon$-soft and is used to generate episodes. It is defined as follows:\n",
"\n",
"- with probability $1-\\varepsilon$, follow the target policy;\n",
"- with probability $\\varepsilon$, choose a random action (HIT or STICK).\n"
]
},
{
"cell_type": "code",
"execution_count": 240,
"id": "a006d3be-3ed9-4ef2-955d-065e36fb15a3",
"metadata": {},
"outputs": [],
"source": [
"def behavior_policy(\n",
"    s: tuple,\n",
"    epsilon: float = 0.2,\n",
"    rng: np.random.Generator = rng,\n",
") -> int:\n",
"    \"\"\"ε-soft BEHAVIOR policy b.\n",
"\n",
"    Rule:\n",
"        - With probability (1 - epsilon), follow the target policy π\n",
"        - With probability epsilon, choose a random action (HIT or STICK)\n",
"\n",
"    This ensures that all actions have nonzero probability, which is\n",
"    required for off-policy learning.\n",
"\n",
"    Args:\n",
"        s: state tuple (p, d, u)\n",
"        epsilon: exploration probability\n",
"        rng: random number generator\n",
"\n",
"    Returns:\n",
"        action: HIT or STICK\n",
"\n",
"    \"\"\"\n",
"    r = rng.random()\n",
"\n",
"    if r < 1 - epsilon:\n",
"        return target_policy(s)\n",
"    return int(rng.integers(0, 2))\n"
]
},
{
"cell_type": "markdown",
"id": "8bb5136e-01f2-427d-97b9-9d3516b33c84",
"metadata": {},
"source": [
"Now we study **off-policy Monte Carlo evaluation**, where we want to estimate the value of a **target policy** using data generated by a **different behavior policy**.\n",
"\n",
"**Reminder: principle of importance sampling (IS)**\n",
"\n",
"Let:\n",
"- $\\pi$ be the **target policy** (the policy we want to evaluate),\n",
"- $b$ be the **behavior policy** (the policy used to generate data),\n",
"- $G$ be the return of an episode.\n",
"\n",
"Because episodes are generated under $b$ instead of $\\pi$, we must correct for the distribution mismatch using **importance sampling**.\n",
"\n",
"For a trajectory $(S_0,A_0,\\dots,S_{T-1},A_{T-1})$, the **importance sampling ratio** is defined as:\n",
"$$\n",
"W=\\prod_{t=0}^{T-1}\\frac{\\pi(A_t \\mid S_t)}{b(A_t \\mid S_t)}.\n",
"$$\n",
"\n",
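"This ratio measures how much more likely the trajectory is under the target policy than under the behavior policy. The environment's transition probabilities appear in both the numerator and the denominator of the two trajectory probabilities and cancel out, so $W$ depends only on the two policies and not on the (unknown) dynamics of the game.\n",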
"\n",
"**Ordinary vs. weighted importance sampling**\n",
"\n",
"Using the importance weights $W$, we can build two estimators of\n",
"$$\n",
"V^\\pi(s_0) = \\mathbb{E}_\\pi[G \\mid S_0 = s_0].\n",
"$$\n",
"\n",
"- **Ordinary importance sampling**\n",
"$$\n",
"\\hat V_{\\text{ordinary}}=\\frac{1}{N}\\sum_{i=1}^N W^{(i)} G^{(i)}.\n",
"$$\n",
"\n",
"- **Weighted importance sampling**\n",
"$$\n",
"\\hat V_{\\text{weighted}}=\\frac{\\sum_{i=1}^N W^{(i)} G^{(i)}}{\\sum_{i=1}^N W^{(i)}}.\n",
"$$\n",
"\n",
"The ordinary estimator is unbiased but may have high variance, while the weighted estimator reduces variance at the cost of a small bias.\n",
"\n",
"**Setting of the following exercise**\n",
"\n",
"- The **target policy** $\\pi$ is deterministic.\n",
"- The **behavior policy** $b$ is $\\varepsilon$-soft:\n",
" - with probability $1-\\varepsilon$, it follows the target policy,\n",
" - with probability $\\varepsilon$, it chooses a random action.\n",
"- See **Exercise 6.**\n"
]
},
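{
"cell_type": "markdown",
"id": "sketch-is-estimators",
"metadata": {},
"source": [
"*Illustration (not part of the exercises).* The two estimators above can be computed directly from per-episode weights and returns, as this sketch shows. Each hypothetical episode is described only by the probabilities that $\\pi$ and $b$ assign to the actions actually taken. `is_weight` and `is_estimates` are illustrative helpers and are not defined elsewhere in the lab.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def is_weight(pi_probs: list[float], b_probs: list[float]) -> float:\n",
"    \"\"\"W = prod_t pi(A_t|S_t) / b(A_t|S_t) for the actions taken in one episode.\"\"\"\n",
"    return float(np.prod(np.asarray(pi_probs) / np.asarray(b_probs)))\n",
"\n",
"\n",
"def is_estimates(weights: list[float], returns: list[float]) -> tuple[float, float]:\n",
"    \"\"\"Ordinary and weighted importance-sampling estimates from N episodes.\"\"\"\n",
"    w = np.asarray(weights, dtype=float)\n",
"    g = np.asarray(returns, dtype=float)\n",
"    ordinary = np.mean(w * g)\n",
"    weighted = np.sum(w * g) / np.sum(w) if np.sum(w) > 0 else 0.0\n",
"    return float(ordinary), float(weighted)\n",
"\n",
"\n",
"# Hypothetical episodes: (pi probs, b probs, return G) of the actions taken.\n",
"episodes = [\n",
"    ([1.0, 1.0], [0.9, 0.9], 1),   # both actions agree with pi\n",
"    ([1.0, 0.0], [0.9, 0.1], -1),  # second action deviates from pi, so W = 0\n",
"    ([1.0], [0.9], -1),            # one action, agrees with pi\n",
"]\n",
"weights = [is_weight(pi_p, b_p) for pi_p, b_p, _ in episodes]\n",
"returns = [G for _, _, G in episodes]\n",
"print(weights)                         # approximately [1.2346, 0.0, 1.1111]\n",
"print(is_estimates(weights, returns))  # approximately (0.0412, 0.0526)\n",
"```\n",
"\n",
"The ordinary estimator still counts the episode with $W = 0$ in its denominator $N$. The weighted estimator drops that episode entirely, which is one source of the difference between the two."
]
},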
{
"cell_type": "markdown",
"id": "5c2186ad-9432-44fa-9a11-147b54d7ffb4",
"metadata": {},
"source": [
"**Exercise 7.** (Off-Policy Monte Carlo Evaluation with Importance Sampling)\n",
"\n",
"**Part 1.** Recall that the importance sampling (IS) weight of a trajectory is\n",
"$$\n",
"W = \\prod_{t=0}^{T-1}\\frac{\\pi(A_t\\mid S_t)}{b(A_t\\mid S_t)}.\n",
"$$\n",
"\n",
"Because the target policy $\\pi$ is deterministic:\n",
"$$\n",
"\\pi(a\\mid s)=\n",
"\\begin{cases}\n",
"1 & \\text{if } a=\\pi(s),\\\\\n",
"0 & \\text{otherwise.}\n",
"\\end{cases}\n",
"$$\n",
"\n",
"**Question.**\n",
"1. Explain why $W=0$ as soon as the behavior policy takes an action different from the target action.\n",
"2. Show that under our behavior policy in **Exercise 6.**,\n",
"$$\n",
"b(\\pi(s)\\mid s)=(1-\\varepsilon)+\\varepsilon/2.\n",
"$$\n",
"\n",
"\n",
"**Part 2.** Off-Policy Monte Carlo Evaluation (Ordinary vs Weighted IS)\n",
"\n",
"We want to evaluate the target policy $\\pi$ using episodes generated by $b$.\n",
"\n",
"- **Ordinary IS estimator**:\n",
"$$\n",
"\\hat V_{\\text{ordinary}}=\\frac{1}{N}\\sum_{i=1}^N W^{(i)}G^{(i)}.\n",
"$$\n",
"- **Weighted IS estimator**:\n",
"$$\n",
"\\hat V_{\\text{weighted}}=\\frac{\\sum_{i=1}^N W^{(i)}G^{(i)}}{\\sum_{i=1}^N W^{(i)}}.\n",
"$$\n",
"\n",
"Here, each episode starts from the environment's initial deal defined by `initial_hands` (see the `generate_episode` function).\n"
]
},
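{
"cell_type": "markdown",
"id": "sketch-is-weight-simulation",
"metadata": {},
"source": [
"*Illustration (not part of the exercises).* This small simulation shows the weights that arise in Part 1. It makes the simplifying assumption that every episode contains exactly `k = 3` decisions, although real Blackjack episodes vary in length. The behavior policy picks the target action with probability $b(\\pi(s)\\mid s) = 1 - \\varepsilon/2$. An episode therefore has weight $(1-\\varepsilon/2)^{-k}$ if all its actions match $\\pi$, and weight $0$ otherwise.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng_demo = np.random.default_rng(1)\n",
"epsilon, k, n_episodes = 0.2, 3, 200_000\n",
"p_follow = (1 - epsilon) + epsilon / 2  # b(pi(s)|s) = 0.9\n",
"\n",
"# follows[i, t] is True when episode i takes the target action at decision t.\n",
"follows = rng_demo.random((n_episodes, k)) < p_follow\n",
"all_follow = follows.all(axis=1)\n",
"\n",
"# W = (1 / p_follow)^k if every action matched pi, otherwise W = 0.\n",
"W = np.where(all_follow, p_follow ** (-k), 0.0)\n",
"print(all_follow.mean())  # close to 0.9**3 = 0.729\n",
"print(W.mean())           # close to 1.0\n",
"```\n",
"\n",
"About 27% of the episodes get weight $0$, and the rest are scaled up by $0.9^{-3} \\approx 1.37$. On average the weight is $\\mathbb{E}_b[W] = 1$: the surviving episodes are scaled up just enough to compensate for the discarded ones."
]
},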
{
"cell_type": "markdown",
"id": "81f64de2-3bc7-4491-b88c-056638912b17",
"metadata": {},
"source": [
"Define a function `off_policy_mc_eval_is` that estimates $V^\\pi(s_0) = \\mathbb{E}_\\pi[G \\mid S_0 = s_0]$, where:\n",
"\n",
"- $s_0$ is the start state defined by `initial_hands` in the `generate_episode` function;\n",
"- $\\pi$ is `target_policy`, the policy we want to evaluate;\n",
"- episodes are **not** generated with $\\pi$: they are generated with $b$ = `behavior_policy`, and the mismatch is corrected with importance sampling.\n",
"\n",
"The arguments of `off_policy_mc_eval_is` are:\n",
"- `num_episodes`: number of episodes to simulate;\n",
"- `epsilon`: exploration parameter of the behavior policy $b$;\n",
"- `rng`: random generator.\n",
"\n",
"The outputs of `off_policy_mc_eval_is` are:\n",
"\n",
"- `V_ordinary`: the ordinary importance sampling estimate of $V^\\pi(s_0)$;\n",
"- `V_weighted`: the weighted importance sampling estimate of $V^\\pi(s_0)$.\n"
]
},
{
"cell_type": "code",
"execution_count": 241,
"id": "6dadbb5c-1f52-4dc3-95f0-332d66876284",
"metadata": {},
"outputs": [],
"source": [
"def off_policy_mc_eval_is(\n",
"    num_episodes: int = 50_000,\n",
"    epsilon: float = 0.2,\n",
"    rng: np.random.Generator = rng,\n",
") -> tuple[float, float]:\n",
"    \"\"\"Off-policy Monte Carlo evaluation of a TARGET policy π using episodes generated by a different BEHAVIOR policy b.\n",
"\n",
"    We estimate two quantities:\n",
"        - V_ordinary : (1/N) * Σ_i W_i * G_i\n",
"        - V_weighted : (Σ_i W_i * G_i) / (Σ_i W_i)\n",
"\n",
"    where:\n",
"        - G_i is the return of episode i\n",
"        - W_i is the importance sampling weight:\n",
"              W_i = Π_t π(A_t|S_t) / b(A_t|S_t)\n",
"\n",
"    In this lab:\n",
"        - π is deterministic (target_policy)\n",
"        - b is epsilon-soft around π (behavior_policy)\n",
"\n",
"    Args:\n",
"        num_episodes (int): Number of simulated episodes.\n",
"        epsilon (float): Exploration parameter for the behavior policy.\n",
"        rng (np.random.Generator): Random number generator.\n",
"\n",
"    Returns:\n",
"        V_ordinary, V_weighted : floats\n",
"\n",
"    \"\"\"\n",
"    ordinary_sum = 0.0\n",
"    ordinary_n = 0\n",
"\n",
"    weighted_num = 0.0\n",
"    weighted_den = 0.0\n",
"\n",
"    # b(pi(s)|s): follow pi with prob. 1 - epsilon, or draw pi(s) at random with prob. epsilon / 2.\n",
"    b_prob = (1 - epsilon) + (epsilon / 2)\n",
"\n",
"    for _ in range(num_episodes):\n",
"        episode, G = generate_episode(\n",
"            lambda s: behavior_policy(s, epsilon=epsilon, rng=rng),\n",
"            rng,\n",
"        )\n",
"        W = 1.0\n",
"\n",
"        for s, a in episode:\n",
"            p, _d, _u = s\n",
"\n",
"            if p < 12:  # noqa: PLR2004\n",
"                continue\n",
"\n",
"            a_star = target_policy(s)\n",
"\n",
"            # pi(a|s) = 0 for a non-target action, so the whole product is 0.\n",
"            if a != a_star:\n",
"                W = 0.0\n",
"                break\n",
"\n",
"            # pi(a|s) = 1 here, so the per-step ratio is 1 / b(a|s).\n",
"            W *= 1 / b_prob\n",
"\n",
"        # Every episode counts in N for ordinary IS, even when W = 0.\n",
"        ordinary_sum += W * G\n",
"        ordinary_n += 1\n",
"\n",
"        weighted_num += W * G\n",
"        weighted_den += W\n",
"\n",
"    V_ordinary = ordinary_sum / ordinary_n\n",
"    V_weighted = weighted_num / weighted_den if weighted_den > 0 else 0.0\n",
"\n",
"    return V_ordinary, V_weighted\n"
]
},
{
"cell_type": "markdown",
"id": "f5ed08d2-1bce-4f26-bcc1-4cfb1c5e401d",
"metadata": {},
"source": [
"**Exercise 8.** (Variance Experiment (Ordinary vs Weighted IS))\n",
"\n",
"In this exercise, we empirically compare the variance of the ordinary and weighted IS estimators.\n",
"\n",
"**Tasks.**\n",
"1. Fix $\\varepsilon=0.2$ and `num_episodes=10_000`.\n",
"2. Repeat the evaluation $K=20$ times with different random seeds.\n",
"3. Record the $K$ values of $V_{\\text{ordinary}}$ and $V_{\\text{weighted}}$.\n",
"4. Compute the empirical mean and variance of each estimator.\n",
"5. Which estimator has lower variance? \n"
]
},
{
"cell_type": "code",
"execution_count": 242,
"id": "4d3ff71b-ed58-44a5-aae5-fc5fa38a10c1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ordinary IS: mean = -0.30517438127233065 var = 0.0002877876907253712\n",
"Weighted IS: mean = -0.3037184822501776 var = 0.0002511580932966114\n"
]
}
],
"source": [
"def variance_experiment(\n",
"    K: int = 20,\n",
"    num_episodes: int = 10_000,\n",
"    epsilon: float = 0.2,\n",
") -> tuple[np.ndarray, np.ndarray]:\n",
"    \"\"\"Run K independent off-policy MC evaluations and report variance of estimates.\n",
"\n",
"    Args:\n",
"        K: number of independent runs\n",
"        num_episodes: number of episodes per run\n",
"        epsilon: exploration parameter for behavior policy\n",
"\n",
"    Returns:\n",
"        ord_vals: array of ordinary IS estimates\n",
"        wgt_vals: array of weighted IS estimates\n",
"\n",
"    \"\"\"\n",
"    ord_vals = []\n",
"    wgt_vals = []\n",
"\n",
"    for k in range(K):\n",
"        rng_k = np.random.default_rng(k)\n",
"        V_ord, V_wgt = off_policy_mc_eval_is(\n",
"            num_episodes=num_episodes,\n",
"            epsilon=epsilon,\n",
"            rng=rng_k,\n",
"        )\n",
"        ord_vals.append(V_ord)\n",
"        wgt_vals.append(V_wgt)\n",
"\n",
"    ord_vals = np.array(ord_vals)\n",
"    wgt_vals = np.array(wgt_vals)\n",
"\n",
"    print(\"Ordinary IS: mean =\", np.mean(ord_vals), \" var =\", np.var(ord_vals))\n",
"    print(\"Weighted IS: mean =\", np.mean(wgt_vals), \" var =\", np.var(wgt_vals))\n",
"\n",
"    return ord_vals, wgt_vals\n",
"\n",
"\n",
"ord_vals, wgt_vals = variance_experiment(K=10, num_episodes=5_000, epsilon=0.2)\n"
]
},
{
"cell_type": "markdown",
"id": "0c605809-4d2b-454e-9e4f-4e548571e4dd",
"metadata": {},
"source": [
"Now we want to learn an (approximately) optimal policy for Blackjack:\n",
"\n",
"- Learn the optimal action-value function $q_*(s,a)$ (approximated by `Q`)\n",
"\n",
"- Then take the greedy policy w.r.t. `Q` (stored in `policy`)\n",
"\n",
"But episodes are generated by a different policy:\n",
"\n",
"- **Behavior policy $b$**: $\\varepsilon$-soft (explores)\n",
"\n",
"- **Target policy $\\pi$**: greedy w.r.t. $Q$ (deterministic)\n",
"\n",
"Because data come from $b$ but we want to learn values under $\\pi$, we use importance sampling."
]
},
{
"cell_type": "markdown",
"id": "bac5048a-581c-40fc-8f99-0410bfab5d4a",
"metadata": {},
"source": [
"-----------------------\n",
"\n",
"\n",
"**Exercise 9.($\\star\\star\\star$)** (Off-Policy Monte Carlo Control (Weighted Importance Sampling))\n",
"\n",
"The goal of this exercise is to learn an approximately optimal **deterministic policy** using **off-policy Monte Carlo control** with **weighted importance sampling**.\n",
"\n",
"Let:\n",
"- **π** be the **target policy**, deterministic and greedy with respect to the action-value function `Q`,\n",
"- **b** be an **$\\varepsilon$-soft behavior policy**,\n",
"- $G \\in \\{-1, 0, 1\\}$ be the terminal return of an episode.\n",
"\n",
"Episodes are generated using the behavior policy **b**, but learning targets the policy **π**.\n",
"\n",
"**Part A — Recall Importance Sampling Weights**\n",
"\n",
"Consider one episode:\n",
"$$\n",
"(s_0,a_0), (s_1,a_1), \\dots, (s_{T-1},a_{T-1})\n",
"$$\n",
"with terminal return $G$.\n",
"\n",
"1. Recall the definition of the importance sampling ratio:\n",
"$$\n",
"W_t = \\prod_{k=t}^{T-1} \\frac{\\pi(a_k \\mid s_k)}{b(a_k \\mid s_k)}.\n",
"$$\n",
"\n",
"Explain why, when the target policy $\\pi$ is **deterministic**, each factor \n",
"$\\pi(a_k \\mid s_k)$ is either **1 or 0**, and consequently the full weight \n",
"$W_t$ is either **strictly positive or equal to zero**.\n",
"\n",
"**Part B — Backward Processing of Episodes**\n",
"\n",
"In the algorithm, the episode is processed **backward**, from time $T-1$ to 0.\n",
"\n",
"2. Explain why backward processing allows an **early stopping rule** when\n",
"$$\n",
"a_t \\neq \\pi(s_t).\n",
"$$\n",
"\n",
"What happens to the importance sampling weight $W_t$ after this point?\n",
"\n",
"**Part C — Weighted Incremental Update**\n",
"\n",
"For each state–action pair $(s,a)$, the algorithm maintains:\n",
"- a cumulative weight $C(s,a)$,\n",
"- an action-value estimate $Q(s,a)$.\n",
"\n",
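"The update rules are as follows. Here $W$ is the weight of the actions taken *after* time $t$, i.e. $W_{t+1}$ (with $W_T = 1$). The action $a_t$ is given when estimating $Q(s_t,a_t)$, so its own ratio is not included:\n",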
"$$\n",
"\\begin{aligned}\n",
"C(s,a) &\\leftarrow C(s,a) + W, \\\\\n",
"Q(s,a) &\\leftarrow Q(s,a) + \\frac{W}{C(s,a)}\\bigl(G - Q(s,a)\\bigr).\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"3. Show that this update is an **incremental form** of the weighted average:\n",
"$$\n",
"Q(s,a) = \\frac{\\sum_i W_i G_i}{\\sum_i W_i}.\n",
"$$\n",
"\n",
"**Part D — Policy Improvement**\n",
"\n",
"After updating $Q(s,a)$, the target policy is updated as:\n",
"$$\n",
"\\pi(s) = \\arg\\max_a Q(s,a).\n",
"$$\n",
"\n",
"**Part E — Behavior Policy Probability**\n",
"\n",
"Assume that:\n",
"- the greedy action is $a^* = \\pi(s)$,\n",
"- the behavior policy selects:\n",
"  - $a^*$ with probability $1 - \\varepsilon$,\n",
"  - the other action with probability $\\varepsilon$.\n",
"\n",
"5. Compute explicitly:\n",
"$$\n",
"\\frac{\\pi(a^* \\mid s)}{b(a^* \\mid s)}\n",
"\\quad \\text{and} \\quad\n",
"\\frac{\\pi(a \\mid s)}{b(a \\mid s)} \\quad \\text{for } a \\neq a^*.\n",
"$$\n",
"\n",
"Explain why, whenever $a = a^*$, the importance weight update becomes:\n",
"$$\n",
"W \\leftarrow \\frac{W}{1 - \\varepsilon}.\n",
"$$\n",
"\n",
"\n"
]
},
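{
"cell_type": "markdown",
"id": "sketch-weighted-incremental-update",
"metadata": {},
"source": [
"*Illustration (not part of the exercises).* This is a quick numerical check of the claim in Part C, not a proof. Applying the incremental update to a stream of hypothetical weights $W_i$ and returns $G_i$ gives the same number as the batch weighted average $\\sum_i W_i G_i / \\sum_i W_i$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng_demo = np.random.default_rng(2)\n",
"weights = rng_demo.uniform(0.0, 2.0, size=50)         # hypothetical weights W_i\n",
"returns = rng_demo.choice([-1.0, 0.0, 1.0], size=50)  # hypothetical returns G_i\n",
"\n",
"Q, C = 0.0, 0.0\n",
"for W, G in zip(weights, returns):\n",
"    if W == 0.0:\n",
"        continue  # a zero weight would leave C and Q unchanged (and avoids 0/0)\n",
"    C += W                  # C(s,a) <- C(s,a) + W\n",
"    Q += (W / C) * (G - Q)  # Q(s,a) <- Q(s,a) + W/C * (G - Q(s,a))\n",
"\n",
"batch = np.sum(weights * returns) / np.sum(weights)\n",
"print(Q, batch)  # equal up to floating-point rounding\n",
"```\n",
"\n",
"Because the two numbers agree for any sequence of positive weights, the algorithm never needs to store past returns: $C(s,a)$ and $Q(s,a)$ are sufficient."
]
},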
{
"cell_type": "code",
"execution_count": 243,
"id": "f4850733-218a-41c7-b21b-ca49234d8b4c",
"metadata": {},
"outputs": [],
"source": [
"def mc_control_off_policy(\n",
"    num_episodes: int = 500_000,\n",
"    epsilon: float = 0.2,\n",
"    rng: np.random.Generator = rng,\n",
") -> tuple[dict, dict]:\n",
"    \"\"\"Off-policy Monte Carlo control using WEIGHTED importance sampling.\n",
"\n",
"    Behavior policy b:\n",
"        epsilon-soft around the current target policy\n",
"    Target policy pi:\n",
"        deterministic, greedy w.r.t. Q\n",
"\n",
"    Returns:\n",
"        Q : dict mapping (state, action) -> action-value estimate\n",
"        policy : dict mapping state -> greedy action\n",
"\n",
"    \"\"\"\n",
"    Q = defaultdict(float)\n",
"    C = defaultdict(float)  # cumulative importance weights C(s, a)\n",
"    policy = {}\n",
"\n",
"    def greedy_action(s: tuple) -> int:\n",
"        \"\"\"Pick the greedy action based on current Q.\"\"\"\n",
"        return HIT if Q[(s, HIT)] >= Q[(s, STICK)] else STICK\n",
"\n",
"    def behavior_policy(s: tuple) -> int:\n",
"        \"\"\"Epsilon-soft behavior policy around current target policy.\"\"\"\n",
"        # b(a*|s) = 1 - epsilon for the greedy action a*, epsilon for the other action.\n",
"        p, _d, _u = s\n",
"        if p < 12:  # noqa: PLR2004\n",
"            return HIT\n",
"\n",
"        a_star = policy.get(s, greedy_action(s))\n",
"\n",
"        if rng.random() < 1 - epsilon:\n",
"            return a_star\n",
"        return STICK if a_star == HIT else HIT\n",
"\n",
"    for _ in range(num_episodes):\n",
"        episode, G = generate_episode(behavior_policy, rng)\n",
"        W = 1.0\n",
"\n",
"        # Process the episode backward; W is the weight of the actions after time t.\n",
"        for s, a in reversed(episode):\n",
"            if s[0] < 12:  # noqa: PLR2004\n",
"                continue\n",
"\n",
"            # Weighted-IS incremental update of Q(s, a).\n",
"            C[(s, a)] += W\n",
"            Q[(s, a)] += (W / C[(s, a)]) * (G - Q[(s, a)])\n",
"\n",
"            # Policy improvement: make pi greedy w.r.t. the updated Q.\n",
"            a_star = greedy_action(s)\n",
"            policy[s] = a_star\n",
"\n",
"            # pi(a|s) = 0 for a non-greedy action, so all earlier weights are 0.\n",
"            if a != a_star:\n",
"                break\n",
"\n",
"            # pi(a*|s) / b(a*|s) = 1 / (1 - epsilon)\n",
"            W *= 1 / (1 - epsilon)\n",
"\n",
"    return Q, policy\n"
]
},
{
"cell_type": "markdown",
"id": "cc2a308d-2197-46ab-a902-b1f65fbc9deb",
"metadata": {},
"source": [
"----------------------\n",
"\n",
"**Exercise 10.** (Compare off-policy MC control vs on-policy $\\varepsilon$-soft MC control)\n",
"\n",
"What will be compared?\n",
"\n",
"1. Learned policies: Do they agree on most states?\n",
"\n",
"2. Action-value functions: Are the greedy actions the same?\n",
"\n",
"3. Stability: Which method is noisier for a fixed number of episodes?\n",
"\n",
"**Step 1 — Train both algorithms**"
]
},
{
"cell_type": "code",
"execution_count": 244,
"id": "095a2741-adfc-4599-bd2a-8edc7a5174fe",
"metadata": {},
"outputs": [],
"source": [
"# Train on-policy epsilon-soft MC control\n",
"Q_on, pi_on = mc_control_epsilon_soft(\n",
"    num_episodes=200_000,\n",
"    N0=100,\n",
"    rng=np.random.default_rng(0),\n",
")\n",
"\n",
"# Train off-policy MC control\n",
"Q_off, pi_off = mc_control_off_policy(\n",
"    num_episodes=200_000,\n",
"    epsilon=0.2,\n",
"    rng=np.random.default_rng(0),\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "468de25e-1928-4fdf-9032-2b979c2ddefa",
"metadata": {},
"source": [
"**Step 2 — Compare greedy actions on a grid of states**"
]
},
{
"cell_type": "code",
"execution_count": 245,
"id": "e5d4f000-66a7-450b-a2aa-762375e9cbd5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Policy agreement rate: 0.955\n"
]
}
],
"source": [
"def policy_agreement(pi1: dict, pi2: dict) -> float:\n",
"    \"\"\"Compute the agreement rate between two policies.\n",
"\n",
"    Args:\n",
"        pi1: first policy dict mapping state -> action\n",
"        pi2: second policy dict mapping state -> action\n",
"\n",
"    Returns:\n",
"        agreement_rate: float in [0, 1] representing the fraction of states\n",
"            where both policies agree on the action.\n",
"\n",
"    \"\"\"\n",
"    total = 0\n",
"    agree = 0\n",
"\n",
"    for p in range(12, 22):\n",
"        for d in range(1, 11):\n",
"            for u in [0, 1]:\n",
"                s = (p, d, u)\n",
"                if s in pi1 and s in pi2:\n",
"                    total += 1\n",
"                    if pi1[s] == pi2[s]:\n",
"                        agree += 1\n",
"\n",
"    return agree / total if total > 0 else 0.0\n",
"\n",
"\n",
"agreement = policy_agreement(pi_on, pi_off)\n",
"print(\"Policy agreement rate:\", agreement)\n"
]
},
{
"cell_type": "markdown",
"id": "e49d6c3f-3226-4ce4-baa8-6e7ffdc970ab",
"metadata": {},
"source": [
"**Step 3 — Inspect specific states**"
]
},
{
"cell_type": "code",
"execution_count": 246,
"id": "8632c062-a36d-4b39-aa79-f525b2d03648",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20, 10, 0) on-policy: STICK | off-policy: STICK\n",
"(13, 2, 0) on-policy: HIT | off-policy: HIT\n",
"(18, 6, 1) on-policy: STICK | off-policy: HIT\n"
]
}
],
"source": [
"test_states = [(20, 10, 0), (13, 2, 0), (18, 6, 1)]\n",
"\n",
"for s in test_states:\n",
"    print(\n",
"        s,\n",
"        \"on-policy:\",\n",
"        \"HIT\" if pi_on.get(s) == HIT else \"STICK\",\n",
"        \"| off-policy:\",\n",
"        \"HIT\" if pi_off.get(s) == HIT else \"STICK\",\n",
"    )\n"
]
},
{
"cell_type": "markdown",
"id": "375f68d9-f0fc-4615-b66f-3a32721c1d84",
"metadata": {},
"source": [
"**Step 4 (very slow) - Compare the stability of off-policy MC control vs on-policy $\\varepsilon$-soft MC control**"
]
},
{
"cell_type": "code",
"execution_count": 247,
"id": "b5efb0ca-8de5-432a-8aea-645820ec4961",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_policy(\n",
"    pi: dict,\n",
"    num_eval_episodes: int = 20_000,\n",
"    rng: np.random.Generator | None = None,\n",
") -> float:\n",
"    \"\"\"Estimate expected return of a policy by simulation. We generate episodes following pi and average terminal rewards.\"\"\"\n",
"    # Use a dedicated, reproducible generator when none is supplied.\n",
"    if rng is None:\n",
"        rng = np.random.default_rng(999)\n",
"\n",
"    total = 0.0\n",
"    for _ in range(num_eval_episodes):\n",
"        # States missing from pi (e.g. player sum below 12) default to HIT.\n",
"        _episode, G = generate_episode(policy_fn=lambda s: pi.get(s, HIT), rng=rng)\n",
"        total += G\n",
"    return total / num_eval_episodes\n"
]
},
{
"cell_type": "code",
"execution_count": 248,
"id": "13b72b63-375d-4150-9b83-687dc2bcb6f6",
"metadata": {},
"outputs": [],
"source": [
"def noise_experiment_performance(\n",
"    K: int = 20,\n",
"    train_episodes: int = 50_000,\n",
"    eval_episodes: int = 20_000,\n",
"    epsilon: float = 0.2,\n",
"    N0: int = 100,\n",
") -> tuple[np.ndarray, np.ndarray]:\n",
"    \"\"\"Compare on-policy vs off-policy MC control performance over K runs.\n",
"\n",
"    Args:\n",
"        K: number of independent runs\n",
"        train_episodes: number of training episodes per run\n",
"        eval_episodes: number of evaluation episodes per run\n",
"        epsilon: exploration parameter for off-policy MC control\n",
"        N0: parameter controlling exploration for on-policy MC control\n",
"\n",
"    Returns:\n",
"        on_perf: array of on-policy performance estimates\n",
"        off_perf: array of off-policy performance estimates\n",
"\n",
"    \"\"\"\n",
"    on_perf = []\n",
"    off_perf = []\n",
"\n",
"    for k in range(K):\n",
"        rng_train = np.random.default_rng(k)\n",
"        rng_eval = np.random.default_rng(10_000 + k)\n",
"\n",
"        _Q_on, pi_on = mc_control_epsilon_soft(\n",
"            num_episodes=train_episodes,\n",
"            N0=N0,\n",
"            rng=rng_train,\n",
"        )\n",
"        on_perf.append(\n",
"            evaluate_policy(pi_on, num_eval_episodes=eval_episodes, rng=rng_eval),\n",
"        )\n",
"\n",
"        _Q_off, pi_off = mc_control_off_policy(\n",
"            num_episodes=train_episodes,\n",
"            epsilon=epsilon,\n",
"            rng=rng_train,\n",
"        )\n",
"        off_perf.append(\n",
"            evaluate_policy(pi_off, num_eval_episodes=eval_episodes, rng=rng_eval),\n",
"        )\n",
"\n",
"    on_perf = np.array(on_perf)\n",
"    off_perf = np.array(off_perf)\n",
"\n",
"    print(\n",
"        \"=== Estimated return: Higher mean, better policy w.r.t. the return, Low variance, better algorithm stability ===\",\n",
"    )\n",
"    print(f\"On-policy : mean={on_perf.mean():.4f}, var={on_perf.var():.6f}\")\n",
"    print(f\"Off-policy : mean={off_perf.mean():.4f}, var={off_perf.var():.6f}\")\n",
"\n",
"    return on_perf, off_perf\n"
]
},
{
"cell_type": "code",
"execution_count": 249,
"id": "8c3bbc38-926c-4af6-8f31-d7465093a84d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Estimated return: Higher mean, better policy w.r.t. the return, Low variance, better algorithm stability ===\n",
"On-policy : mean=-0.0534, var=0.000115\n",
"Off-policy : mean=-0.0360, var=0.000137\n"
]
}
],
"source": [
"on_perf, off_perf = noise_experiment_performance(\n",
"    K=10,\n",
"    train_episodes=50_000,\n",
"    eval_episodes=10_000,\n",
"    epsilon=0.2,\n",
"    N0=100,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "4b31bf12-b93b-4e93-8b5b-86a5d94d798f",
"metadata": {},
"source": [
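"In summary, on-policy $\\varepsilon$-soft MC control:\n",
"\n",
"- learns while exploring,\n",
"- is simpler,\n",
"- is usually more stable,\n",
"- but the policy that generates the data never becomes fully greedy: $\\varepsilon(s)$ only tends to 0 as $N(s)$ grows.\n",
"\n",
"Off-policy MC control with importance sampling:\n",
"\n",
"- learns a deterministic greedy target policy,\n",
"- is more flexible: any behavior policy that gives nonzero probability to the target policy's actions can generate the data,\n",
"- but:\n",
"  - importance weights can grow very large,\n",
"  - only the part of each episode after its last non-greedy action contributes to learning,\n",
"  - its estimates have higher variance."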
]
},
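{
"cell_type": "markdown",
"id": "5e0c2a71-9d4b-4c6e-8a13-2f7b9d0e4c11",
"metadata": {},
"source": [
"The drawbacks of importance sampling listed above can be seen on a single episode. The sketch below is a minimal illustration under stated assumptions, not the notebook's `mc_control_off_policy`: it assumes a deterministic greedy target policy and an ε-greedy behavior policy over the two actions, and walks one episode backward as in the standard weighted-importance-sampling control loop. The helpers `behavior_prob` and `importance_weights` are hypothetical.\n",
"\n",
"```python\n",
"STICK, HIT = 0, 1\n",
"N_ACTIONS = 2\n",
"\n",
"\n",
"def behavior_prob(action, greedy_action, epsilon):\n",
"    # epsilon-greedy behavior policy b(a|s) over the two actions\n",
"    if action == greedy_action:\n",
"        return 1.0 - epsilon + epsilon / N_ACTIONS\n",
"    return epsilon / N_ACTIONS\n",
"\n",
"\n",
"def importance_weights(episode, greedy, epsilon):\n",
"    # episode: list of (state, action) pairs in time order\n",
"    # greedy: dict mapping state -> action of the greedy target policy\n",
"    # Returns the weight W used for each update, latest step first.\n",
"    weights = []\n",
"    w = 1.0\n",
"    for state, action in reversed(episode):\n",
"        weights.append(w)  # weight for the update of (state, action)\n",
"        if action != greedy[state]:\n",
"            break  # target probability is 0: earlier steps get no update\n",
"        w /= behavior_prob(action, greedy[state], epsilon)  # W <- W / b(a|s)\n",
"    return weights\n",
"```\n",
"\n",
"With ε = 0.2, every greedy step multiplies $W$ by $1/0.9 \\approx 1.11$. Blackjack episodes are short, so the weights stay moderate here, but the product grows geometrically with episode length; and whenever the last action is exploratory, the episode updates only its final state-action pair.\n"
]
},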
{
"cell_type": "markdown",
"id": "cb1cd119-ffca-4edb-9b46-6c6cf130a970",
"metadata": {},
"source": [
"-------------------------------\n",
"\n",
"\n",
"## 7. Plot the optimal policies and value functions "
]
},
{
"cell_type": "markdown",
"id": "9904a6af-7c92-4721-978b-114b14305957",
"metadata": {},
"source": [
"\n",
"In the Blackjack environment, states are represented by \n",
"$$\n",
"s = (\\text{player sum},\\ \\text{dealer showing},\\ \\text{usable ace}),\n",
"$$\n",
"where we consider:\n",
"- the player sum ranges from **12 to 21**,\n",
"- the dealer's visible card ranges from **1 to 10** (an ace is shown as 1),\n",
"- the usable ace indicator is either **0 (no usable ace)** or **1 (usable ace)**.\n",
"\n",
"For visualization purposes, we convert value functions and policies (stored as Python dictionaries indexed by states) into **2-dimensional grids**:\n",
"\n",
"- **Rows** correspond to the player sum $p = 12,\\dots,21$,\n",
"- **Columns** correspond to the dealer showing card $d = 1,\\dots,10$.\n",
"\n",
"This allows us to:\n",
"- visualize **state-value functions** $V(s)$ as heatmaps,\n",
"- visualize **policies** as action maps (HIT or STICK),\n",
"- compare the policies and value functions learned by different Monte Carlo control methods.\n",
"\n",
"The following helper functions perform:\n",
"1. Conversion from a dictionary-based representation to a grid representation.\n",
"2. Visualization of value functions and policies using heatmaps.\n"
]
},
{
"cell_type": "code",
"execution_count": 250,
"id": "6bb5b2b6-8b7d-409b-9af1-6c6270eeaeb9",
"metadata": {},
"outputs": [],
"source": [
"def to_grid_from_V(V: dict, usable_ace_flag: int) -> np.ndarray:\n",
" \"\"\"Convert V dict into a (10 x 10) grid for plotting.\n",
"\n",
" Rows: player sum 12..21 (10 values)\n",
" Cols: dealer showing 1..10 (10 values)\n",
"\n",
" Args:\n",
" V: dict mapping state -> value\n",
" usable_ace_flag: 0 or 1 indicating whether to consider states with usable aces or not\n",
"\n",
" Returns:\n",
"        grid: (10 x 10) numpy array of values (0.0 for states missing from V)\n",
"\n",
" \"\"\"\n",
" grid = np.zeros((10, 10))\n",
" for i, p in enumerate(range(12, 22)):\n",
" for j, d in enumerate(range(1, 11)):\n",
" s = (p, d, usable_ace_flag)\n",
" grid[i, j] = V.get(s, 0.0)\n",
" return grid\n",
"\n",
"\n",
"def to_grid_from_policy(policy: dict, usable_ace_flag: int) -> np.ndarray:\n",
" \"\"\"Convert policy dict into a (10 x 10) grid of actions with HIT=1, STICK=0.\n",
"\n",
" Rows: player sum 12..21 (10 values)\n",
" Cols: dealer showing 1..10 (10 values)\n",
"\n",
" Args:\n",
" policy: dict mapping state -> action\n",
" usable_ace_flag: 0 or 1 indicating whether to consider states with usable aces or not\n",
"\n",
" Returns:\n",
"        grid: (10 x 10) numpy array of actions (HIT=1, STICK=0);\n",
"            states missing from the policy default to HIT\n",
"\n",
"    \"\"\"\n",
" grid = np.zeros((10, 10))\n",
" for i, p in enumerate(range(12, 22)):\n",
" for j, d in enumerate(range(1, 11)):\n",
" s = (p, d, usable_ace_flag)\n",
" grid[i, j] = policy.get(s, HIT)\n",
" return grid\n",
"\n",
"\n",
"def plot_value_heatmap(grid: np.ndarray, title: str = \"Value\") -> None:\n",
"    \"\"\"Plot a (10 x 10) value grid as a heatmap.\n",
"\n",
" Axes:\n",
" y-axis: player sum (12..21)\n",
" x-axis: dealer showing (1..10)\n",
"\n",
" Args:\n",
" grid: (10 x 10) numpy array of values\n",
" title: title of the plot\n",
"\n",
" \"\"\"\n",
" plt.figure(figsize=(7, 5))\n",
" plt.imshow(grid, origin=\"lower\", aspect=\"auto\")\n",
" plt.colorbar()\n",
" plt.xticks(ticks=np.arange(10), labels=[str(i) for i in range(1, 11)])\n",
" plt.yticks(ticks=np.arange(10), labels=[str(i) for i in range(12, 22)])\n",
" plt.xlabel(\"Dealer showing (d)\")\n",
" plt.ylabel(\"Player sum (p)\")\n",
" plt.title(title)\n",
" plt.show()\n",
"\n",
"\n",
"def plot_policy_map(grid: np.ndarray, title: str = \"Policy (1=HIT, 0=STICK)\") -> None:\n",
" \"\"\"Plot policy as a binary heatmap.\n",
"\n",
" Axes:\n",
" y-axis: player sum (12..21)\n",
" x-axis: dealer showing (1..10)\n",
"\n",
" Args:\n",
" grid: (10 x 10) numpy array of actions (HIT=1, STICK=0)\n",
" title: title of the plot\n",
"\n",
" \"\"\"\n",
" plt.figure(figsize=(7, 5))\n",
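"    plt.imshow(grid, origin=\"lower\", aspect=\"auto\", cmap=\"coolwarm\", vmin=0, vmax=1)\n",
"    cbar = plt.colorbar(ticks=[0, 1])  # discrete legend for the two actions\n",
"    cbar.set_ticklabels([\"STICK (0)\", \"HIT (1)\"])\n",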
" plt.xticks(ticks=np.arange(10), labels=[str(i) for i in range(1, 11)])\n",
" plt.yticks(ticks=np.arange(10), labels=[str(i) for i in range(12, 22)])\n",
" plt.xlabel(\"Dealer showing (d)\")\n",
" plt.ylabel(\"Player sum (p)\")\n",
" plt.title(title)\n",
" plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "2f0dad49-8af6-47ff-b3fb-1f0ab58f838d",
"metadata": {},
"source": [
"We now train **three different Monte Carlo control algorithms** on the Blackjack environment. \n",
"Each algorithm aims to learn an **(approximately) optimal policy** by estimating an action-value\n",
"function $Q(s,a)$ and acting greedily with respect to it.\n",
"\n",
"To ensure a fair comparison:\n",
"- All methods use a single random number generator with a fixed seed, so the experiment is reproducible (the generator is shared sequentially, so each method still sees different random draws),\n",
"- The same number of episodes (500,000) is used for each method,\n",
"- Each method ultimately produces a **greedy policy** derived from its learned $Q$-function.\n",
"\n",
"The three control strategies considered are:\n",
"\n",
"- **Monte Carlo Control with Exploring Starts** \n",
"  Guarantees sufficient exploration by starting each episode from a randomly chosen state-action pair, so every pair has a nonzero probability of being tried.\n",
"\n",
"- **On-policy Monte Carlo Control (ε-soft)** \n",
" Learns while following a stochastic ε-soft policy that balances exploration and exploitation.\n",
"\n",
"- **Off-policy Monte Carlo Control (Importance Sampling)** \n",
" Learns an optimal greedy policy while following a different exploratory behavior policy.\n",
"\n",
"The outputs of each method are:\n",
"- an estimated action-value function $Q(s,a)$,\n",
"- a corresponding greedy policy $\\hat{\\pi}(s) = \\arg\\max_a Q(s,a)$.\n"
]
},
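{
"cell_type": "markdown",
"id": "8b3f6d20-4e7a-4f19-b2c5-6a0d1e9f7c32",
"metadata": {},
"source": [
"Both the greedy policy $\\hat{\\pi}(s) = \\arg\\max_a Q(s,a)$ and an ε-soft policy can be read off the same action-value table. The sketch below assumes, for illustration only, that `Q` is a dict mapping each state to a list of two action values indexed by `STICK = 0` and `HIT = 1`; the notebook's own `Q` may be stored differently, and `greedy_policy` and `epsilon_greedy_probs` are hypothetical helpers. It uses a fixed ε for simplicity.\n",
"\n",
"```python\n",
"STICK, HIT = 0, 1\n",
"ACTIONS = (STICK, HIT)\n",
"\n",
"\n",
"def greedy_policy(Q):\n",
"    # pi_hat(s) = argmax_a Q(s, a); max() keeps the first maximizer,\n",
"    # so ties are broken in favor of STICK\n",
"    return {s: max(ACTIONS, key=lambda a: q[a]) for s, q in Q.items()}\n",
"\n",
"\n",
"def epsilon_greedy_probs(q_values, epsilon):\n",
"    # epsilon-soft policy: each action gets epsilon / |A| and the\n",
"    # greedy action additionally receives 1 - epsilon\n",
"    best = max(ACTIONS, key=lambda a: q_values[a])\n",
"    probs = [epsilon / len(ACTIONS)] * len(ACTIONS)\n",
"    probs[best] += 1.0 - epsilon\n",
"    return probs\n",
"```\n",
"\n",
"For example, with toy values `[0.45, -0.85]` (STICK looks better) and ε = 0.2, the ε-soft policy sticks with probability 0.9 and hits with probability 0.1, while the greedy policy always sticks.\n"
]
},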
{
"cell_type": "code",
"execution_count": 251,
"id": "695ee8b0-7d84-42f8-81ff-c1d5f968f381",
"metadata": {},
"outputs": [],
"source": [
"rng_train = np.random.default_rng(0)\n",
"\n",
"Q_es, pi_es = mc_control_exploring_starts(num_episodes=500_000, rng=rng_train)\n",
"Q_on, pi_on = mc_control_epsilon_soft(num_episodes=500_000, N0=100, rng=rng_train)\n",
"Q_off, pi_off = mc_control_off_policy(num_episodes=500_000, epsilon=0.2, rng=rng_train)\n",
"\n",
"policies = {\n",
" \"Exploring Starts\": pi_es,\n",
" \"On-policy ε-soft\": pi_on,\n",
" \"Off-policy IS\": pi_off,\n",
"}\n"
]
},
{
"cell_type": "markdown",
"id": "7eacd54a-73e8-480f-8f7f-a36f69e694b2",
"metadata": {},
"source": [
"**Visual Comparison of Learned Policies**\n",
"\n",
"We now visualize the **greedy policies learned by each Monte Carlo control method**.\n",
"\n",
"For each method, we plot a **policy map** showing the action selected in every state:\n",
"- **HIT (1)** or **STICK (0)**,\n",
"- for all player sums $p = 12,\\dots,21$,\n",
"- and all dealer showing cards $d = 1,\\dots,10$.\n",
"\n",
"Since the presence of a **usable ace** significantly affects the optimal decision in Blackjack,\n",
"we generate separate policy maps for:\n",
"- states **without a usable ace** (`usable_ace` = 0),\n",
"- states **with a usable ace** (`usable_ace` = 1).\n"
]
},
{
"cell_type": "code",
"execution_count": 252,
"id": "ba33e086-60b1-485f-b410-faa805457ac8",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for usable in [0, 1]:\n",
" for name, pi in policies.items():\n",
" grid_pi = to_grid_from_policy(pi, usable)\n",
" plot_policy_map(grid_pi, title=f\"{name} — policy map (usable_ace={usable})\")\n"
]
},
{
"cell_type": "markdown",
"id": "cd81c661-dfe9-4a53-95fe-189846907656",
"metadata": {},
"source": [
"We now perform a **quantitative comparison** of the greedy policies\n",
"learned by the different Monte Carlo control methods.\n",
"\n",
"For two policies $\\pi_1$ and $\\pi_2$, we define the **policy agreement rate** as the\n",
"fraction of states in which both policies choose the **same action**:\n",
"$$\n",
"\\text{Agreement}(\\pi_1,\\pi_2)\n",
"= \\frac{1}{|\\mathcal S|}\n",
"\\sum_{s \\in \\mathcal S}\n",
"\\mathbf{1}\\{\\pi_1(s) = \\pi_2(s)\\}.\n",
"$$\n",
"\n",
"The state space considered here includes all Blackjack states with:\n",
"- player sums $p = 12,\\dots,21$,\n",
"- dealer showing cards $d = 1,\\dots,10$,\n",
"- both cases of usable ace ($u \\in \\{0,1\\}$).\n",
"\n",
"An agreement rate close to **1** indicates that two methods have learned **very similar policies**,\n",
"while lower values reveal systematic differences in decision-making.\n",
"\n"
]
},
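{
"cell_type": "markdown",
"id": "d41a7e93-2c6b-4b8f-9e05-7f3a1c2b6d54",
"metadata": {},
"source": [
"To get a feel for the scale of this metric: the state space contains $10 \\times 10 \\times 2 = 200$ states, so each state on which two policies disagree lowers the agreement by $1/200 = 0.005$. The toy sketch below builds two hypothetical policies `pi_a` and `pi_b` (simple threshold rules, not learned ones) that differ in exactly 6 states.\n",
"\n",
"```python\n",
"STICK, HIT = 0, 1\n",
"\n",
"# The 200 states (player sum, dealer card, usable ace) used above\n",
"states = [(p, d, u) for u in (0, 1) for p in range(12, 22) for d in range(1, 11)]\n",
"\n",
"# Hypothetical policies: pi_a sticks from 17 upward; pi_b also sticks on 16\n",
"# against dealer cards 2, 3, 4, in both usable-ace cases (6 states in total)\n",
"pi_a = {s: STICK if s[0] >= 17 else HIT for s in states}\n",
"pi_b = dict(pi_a)\n",
"for u in (0, 1):\n",
"    for d in (2, 3, 4):\n",
"        pi_b[(16, d, u)] = STICK\n",
"\n",
"# Agreement(pi_a, pi_b) = (1/|S|) * sum over s of 1{pi_a(s) == pi_b(s)}\n",
"agreement = sum(pi_a[s] == pi_b[s] for s in states) / len(states)\n",
"n_disagree = sum(pi_a[s] != pi_b[s] for s in states)\n",
"```\n",
"\n",
"Here `n_disagree` is 6 and `agreement` is $194/200 = 0.97$. Read the same way, agreement rates of 0.970 and 0.980 correspond to 6 and 4 disagreeing states out of 200.\n"
]
},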
{
"cell_type": "code",
"execution_count": 253,
"id": "e6d2c364-d042-42f4-a0ed-ec48b20f9c24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Agreement Exploring Starts vs On-policy ε-soft: 0.970\n",
"Agreement Exploring Starts vs Off-policy IS: 0.970\n",
"Agreement On-policy ε-soft vs Off-policy IS: 0.980\n"
]
}
],
"source": [
"def policy_agreement(pi_a: dict, pi_b: dict) -> float:\n",
" \"\"\"Compute agreement rate between two policies over all states.\n",
"\n",
" Args:\n",
" pi_a: first policy dict mapping state -> action\n",
" pi_b: second policy dict mapping state -> action\n",
"\n",
" Returns:\n",
" agreement_rate: float in [0, 1] representing the fraction of states\n",
" where both policies agree on the action.\n",
"\n",
" \"\"\"\n",
" states = [(p, d, u) for u in [0, 1] for p in range(12, 22) for d in range(1, 11)]\n",
" agree = 0\n",
" for s in states:\n",
" a = pi_a.get(s, HIT)\n",
" b = pi_b.get(s, HIT)\n",
" agree += int(a == b)\n",
" return agree / len(states)\n",
"\n",
"\n",
"names = list(policies.keys())\n",
"for i in range(len(names)):\n",
" for j in range(i + 1, len(names)):\n",
" a, b = names[i], names[j]\n",
" print(f\"Agreement {a} vs {b}: {policy_agreement(policies[a], policies[b]):.3f}\")\n"
]
},
{
"cell_type": "markdown",
"id": "f4f6ea4b-b3bf-4560-af0c-abd9d6f866c6",
"metadata": {},
"source": [
"**Evaluating the Value Function of the Learned Policies**\n",
"\n",
"The Monte Carlo control algorithms studied so far directly estimate **action-value functions**\n",
"$Q(s,a)$ and output a **greedy policy** $\\hat{\\pi}$.\n",
"However, to compare the *quality* of these learned policies, we must evaluate **how good each policy actually is**.\n",
"\n",
"To do this, we perform **Monte Carlo policy evaluation** for each learned policy $\\hat{\\pi}$.\n",
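"More precisely, for a fixed policy $\\pi$, we estimate its **state-value function**:\n",
"$$\n",
"V^{\\pi}(s) = \\mathbb{E}_{\\pi}[G_t \\mid S_t = s],\n",
"$$\n",
"where $G_t$ is the return from time $t$. Since Blackjack is undiscounted and the only nonzero reward arrives at the end of the episode, $G_t$ equals the terminal reward $G$ of the episode for every $t$; this is why every state of an episode is credited with the same $G$.\n",
"\n",
"We use **first-visit Monte Carlo evaluation** (as in Exercise 2, but now taking the policy to evaluate as an input). In Blackjack a state never recurs within an episode, so first-visit and every-visit MC give identical estimates here.\n"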
]
},
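{
"cell_type": "markdown",
"id": "f2b8c4d1-6a3e-4d72-8c19-3e5f0a7b9c65",
"metadata": {},
"source": [
"The estimator is a sample mean: $V(s)$ averages the returns observed after the first visit to $s$ in each episode. The sketch below shows the equivalent incremental form $V(s) \\leftarrow V(s) + \\frac{1}{N(s)}\\big(G - V(s)\\big)$, which avoids storing a running sum of returns. `first_visit_update` and the toy episodes are hypothetical; this illustrates the estimator and is not a replacement for `mc_policy_evaluation_V`.\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"\n",
"def first_visit_update(V, N, episode_states, G):\n",
"    # One first-visit MC update for an undiscounted episode with terminal return G\n",
"    seen = set()\n",
"    for s in episode_states:\n",
"        if s in seen:\n",
"            continue  # only the first visit of s in this episode counts\n",
"        seen.add(s)\n",
"        N[s] += 1\n",
"        V[s] += (G - V[s]) / N[s]  # incremental sample mean\n",
"\n",
"\n",
"V = defaultdict(float)\n",
"N = defaultdict(int)\n",
"# Toy episodes: (states visited, terminal return G)\n",
"toy_episodes = [\n",
"    ([(13, 10, 0), (18, 10, 0)], -1.0),\n",
"    ([(13, 10, 0), (20, 10, 0)], 1.0),\n",
"    ([(13, 10, 0)], -1.0),\n",
"]\n",
"for episode_states, G in toy_episodes:\n",
"    first_visit_update(V, N, episode_states, G)\n",
"```\n",
"\n",
"After the three toy episodes, $N(s) = 3$ and $V(s) = (-1 + 1 - 1)/3 = -1/3$ for $s = (13, 10, 0)$, which is exactly the batch average `returns_sum[s] / returns_count[s]` that `mc_policy_evaluation_V` computes.\n"
]
},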
{
"cell_type": "code",
"execution_count": 254,
"id": "ba15bc9c-a5ba-4241-a206-8cbe2bf63811",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 700x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def mc_policy_evaluation_V(\n",
"    policy: dict, num_episodes: int = 200_000, rng: np.random.Generator | None = None,\n",
") -> dict:\n",
"    \"\"\"First-visit Monte Carlo policy evaluation that estimates V^pi(s) for all visited states s with player sum >= 12.\n",
"\n",
"    Args:\n",
"        policy: dict mapping state -> action\n",
"        num_episodes: number of episodes to sample\n",
"        rng: random generator used to draw cards (a generator seeded with 123 is created if None)\n",
"\n",
" Returns:\n",
" V: dict mapping state -> estimated value\n",
"\n",
" \"\"\"\n",
" if rng is None:\n",
" rng = np.random.default_rng(123)\n",
"\n",
" returns_sum = defaultdict(float)\n",
" returns_count = defaultdict(int)\n",
"\n",
" def pi(s: tuple[int, int, int]) -> int:\n",
" \"\"\"Policy function using the provided policy dict.\"\"\"\n",
" return policy.get(s, HIT)\n",
"\n",
" for _ in range(num_episodes):\n",
" episode, G = generate_episode(pi, rng)\n",
" visited = set()\n",
"\n",
" for s, _a in episode:\n",
" p, _d, _u = s\n",
" if p < 12: # noqa: PLR2004\n",
" continue\n",
"\n",
" if s not in visited:\n",
" visited.add(s)\n",
" returns_sum[s] += G\n",
" returns_count[s] += 1\n",
"\n",
" return {s: returns_sum[s] / returns_count[s] for s in returns_sum}\n",
"\n",
"\n",
"rng_eval = np.random.default_rng(999)\n",
"\n",
"V_of_policy = {}\n",
"for name, pi in policies.items():\n",
" V_of_policy[name] = mc_policy_evaluation_V(pi, num_episodes=300_000, rng=rng_eval)\n",
"\n",
"for usable in [0, 1]:\n",
" for name, V in V_of_policy.items():\n",
" grid_V = to_grid_from_V(V, usable) # noqa: N816\n",
" plot_value_heatmap(grid_V, title=f\"{name} — V^π (usable_ace={usable})\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "studies",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}