mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-01-14 15:54:13 +01:00
892 lines
210 KiB
Plaintext
892 lines
210 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "44b75d44",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Lab 2 - Second maze\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"id": "100d1e0d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
"np.set_printoptions(precision=3, suppress=True)  # cosmetic only: 3 decimals, fixed-point instead of scientific notation\n",
"\n",
"# One seeded generator shared by the whole notebook, so every random draw is reproducible.\n",
"rng = np.random.default_rng(seed=42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "1018deab",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. Maze definition and MDP formulation\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "ca4fa301-c14f-44ec-b04f-b01ca42d979a",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 2.1 Define the maze "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"id": "f91cda05",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Legend: '#' wall, '.' free cell, 'S' start, 'G' goal, 'X' trap.\n",
"# maze_str is a list of row strings, so maze_str[i][j] is the cell at row i, column j.\n",
"maze_str = \"\"\"\n",
"############\n",
"#S#.......X#\n",
"#.#.###.#.##\n",
"#.....#X#..#\n",
"#.###.####.#\n",
"#...#X#X...#\n",
"###.######X#\n",
"#.....X...##\n",
"#.###.#.#..#\n",
"#...#...X#.#\n",
"#X#.X#X##G.#\n",
"############\n",
"\"\"\".split()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"id": "564cb757-eefe-4be6-9b6f-bb77ace42a97",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"n_rows, n_cols = len(maze_str), len(maze_str[0])  # grid height and width (in cells)\n",
"\n",
"# Figure size in inches: 8 per side, except that a strictly longer side gets half an inch per cell.\n",
"figsize = (\n",
"    n_cols / 2 if n_cols > n_rows else 8,\n",
"    n_rows / 2 if n_rows > n_cols else 8,\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "adc49d58-2730-41d8-96fb-ca7c9cb4fcdf",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 2.2 Map each walkable cell (not a wall '#') to a state index\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"id": "7116044b-c134-43de-9f30-01ab62325300",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Cell characters the agent may occupy, i.e. everything except walls '#':\n",
"# '.' free, 'S' start, 'G' goal, 'X' trap.\n",
"FREE = set(\".SGX\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "1c9ad05e-9c6c-4e00-918c-44b858f45298",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Dictionaries to convert between grid and state index**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"id": "a1258de4",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Number of states (non-wall cells): 62\n",
|
||
"Start state: 0 at (1, 1)\n",
|
||
"Goal states: [60] at (10, 9)\n",
|
||
"Trap states: [8, 18, 27, 28, 33, 39, 54, 56, 58, 59] at (1, 10)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"state_to_pos = {}  # s -> (i, j)\n",
"pos_to_state = {}  # (i, j) -> s\n",
"\n",
"start_state = None  # state index of the start cell 'S'\n",
"goal_states = []  # state indices of the goal cells 'G' (a list, in case there are several)\n",
"trap_states = []  # state indices of the trap cells 'X' (a list, in case there are several)\n",
"\n",
"s = 0\n",
"for i in range(n_rows):  # i = row index\n",
"    for j in range(n_cols):  # j = column index\n",
"        cell = maze_str[i][j]  # character at that position ('S', '.', '#', 'X' or 'G')\n",
"        if cell not in FREE:\n",
"            continue  # walls '#' are not MDP states\n",
"\n",
"        state_to_pos[s] = (i, j)\n",
"        pos_to_state[(i, j)] = s\n",
"        if cell == \"S\":\n",
"            start_state = s\n",
"        elif cell == \"G\":\n",
"            goal_states.append(s)\n",
"        elif cell == \"X\":\n",
"            trap_states.append(s)\n",
"        s += 1\n",
"\n",
"n_states = s\n",
"\n",
"# Fail fast with a clear message instead of a KeyError/IndexError in the prints below.\n",
"assert start_state is not None, \"the maze has no start cell 'S'\"\n",
"assert goal_states, \"the maze has no goal cell 'G'\"\n",
"\n",
"print(\"Number of states (non-wall cells):\", n_states)\n",
"print(\"Start state:\", start_state, \"at\", state_to_pos[start_state])\n",
"# Report the position of every goal/trap (not only the first), and allow a maze without traps.\n",
"print(\"Goal states:\", goal_states, \"at\", [state_to_pos[g] for g in goal_states])\n",
"print(\"Trap states:\", trap_states, \"at\", [state_to_pos[t] for t in trap_states])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"id": "fc61ceef-217c-47f4-8eba-0353369210db",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x800 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
"def plot_maze_with_states() -> None:\n",
"    \"\"\"Draw the maze and write the state index of every non-wall cell.\n",
"\n",
"    Walls are drawn dark and other cells light. Start, goal and trap cells also show\n",
"    their letter (S, G, X) and are coloured green, blue and red respectively.\n",
"    \"\"\"\n",
"    # 0.0 on walls, 1.0 on every other cell.\n",
"    grid = np.array(\n",
"        [[0.0 if maze_str[i][j] == \"#\" else 1.0 for j in range(n_cols)] for i in range(n_rows)],\n",
"    )\n",
"\n",
"    _fig, ax = plt.subplots(figsize=figsize)\n",
"    ax.imshow(grid, cmap=\"gray\", alpha=0.7)\n",
"\n",
"    special_colors = {\"S\": \"green\", \"G\": \"blue\", \"X\": \"red\"}\n",
"    for s, (i, j) in state_to_pos.items():\n",
"        cell = maze_str[i][j]\n",
"        if cell in special_colors:\n",
"            label, color = f\"{cell}\\n{s}\", special_colors[cell]\n",
"        else:\n",
"            label, color = str(s), \"black\"\n",
"        # Attention: matplotlib's text(x, y, ...) takes (column, row), hence (j, i).\n",
"        ax.text(j, i, label, ha=\"center\", va=\"center\", fontsize=10, fontweight=\"bold\", color=color)\n",
"\n",
"    # Numeric axes carry no information here.\n",
"    ax.set_xticks([])\n",
"    ax.set_yticks([])\n",
"    ax.set_title(\"Maze with state indices\")\n",
"\n",
"    plt.show()\n",
"\n",
"\n",
"plot_maze_with_states()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "db078d86",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 2.4 Actions and deterministic movement"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"id": "f7f0b8e4-1f48-4d03-9e5f-a47e59c3e827",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Actions are encoded as 0..3, clockwise starting from 'up'.\n",
"A_UP, A_RIGHT, A_DOWN, A_LEFT = range(4)\n",
"ACTIONS = [A_UP, A_RIGHT, A_DOWN, A_LEFT]\n",
"\n",
"# Arrow glyph used to display each action (up, right, down, left).\n",
"action_names = dict(zip(ACTIONS, \"\\u2191\\u2192\\u2193\\u2190\"))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"id": "4b06da5e-bc63-48e5-a336-37bce952443d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def move_deterministic(i: int, j: int, a: int) -> tuple[int, int]:\n",
"    \"\"\"Apply action ``a`` from cell (i, j) without any noise.\n",
"\n",
"    A move that would leave the grid or enter a wall '#' is cancelled, so the agent\n",
"    stays at (i, j). An action code other than A_UP, A_RIGHT, A_DOWN, A_LEFT also\n",
"    leaves the agent in place.\n",
"\n",
"    Args:\n",
"        i (int): current row index\n",
"        j (int): current column index\n",
"        a (int): action to take (A_UP, A_DOWN, A_LEFT, A_RIGHT)\n",
"\n",
"    Returns:\n",
"        (tuple[int, int]): new (row, column) position after taking action a\n",
"\n",
"    \"\"\"\n",
"    # (row offset, column offset) of each action; row indices grow downwards.\n",
"    offsets = {A_UP: (-1, 0), A_DOWN: (1, 0), A_LEFT: (0, -1), A_RIGHT: (0, 1)}\n",
"    di, dj = offsets.get(a, (0, 0))\n",
"    ni, nj = i + di, j + dj\n",
"\n",
"    inside_grid = 0 <= ni < n_rows and 0 <= nj < n_cols\n",
"    if inside_grid and maze_str[ni][nj] != \"#\":\n",
"        return ni, nj  # the move succeeds\n",
"    return i, j  # blocked by the boundary or a wall: stay in place"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c9e620e6",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 2.5 Transition probabilities and reward function"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"id": "610253e7-f3f7-4a30-be3e-2ec5a1e2ed04",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"gamma = 0.95  # discount factor applied to future rewards\n",
"p_error = 0.1  # total probability of slipping into one of the 3 unintended directions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"id": "7a51f242-fe4e-4e74-8a1f-a8df32b194b8",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Transition tensor P[a, s, s_next] = Pr(s_next | s, a), filled in below.\n",
"P = np.zeros(shape=(len(ACTIONS), n_states, n_states))\n",
"# State reward vector R[s], filled in below.\n",
"R = np.zeros(shape=n_states)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"id": "49d54d1f-dc29-45b6-ad31-ad0e848f920d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Reward R[s] collected in each kind of state (assigned to R in the next cell).\n",
"step_penalty, goal_reward, trap_reward = (\n",
"    -0.01,  # ordinary cells ('.' and 'S')\n",
"    1.0,  # goal cell 'G'\n",
"    -1.0,  # trap cells 'X'\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"id": "b9b7495a-c233-425c-99c0-5bddaf6c3225",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Every state pays the step penalty by default; traps and goals override it.\n",
"R[:] = step_penalty\n",
"R[trap_states] = trap_reward\n",
"R[goal_states] = goal_reward  # assigned last, so a goal takes precedence if a state were listed as both"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"id": "eca4c571-39c7-468b-af86-0bab9489415e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Goals and traps are the absorbing states of the MDP.\n",
"terminal_states = {*goal_states, *trap_states}\n",
"\n",
"\n",
"def is_terminal(s: int) -> bool:\n",
"    \"\"\"Return True when state ``s`` is a goal or a trap state, False otherwise.\"\"\"\n",
"    return s in terminal_states"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"id": "2d03276b-e206-4d1f-9024-f6948ca61523",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Start from an all-zero P so that this cell is idempotent: the updates below use +=,\n",
"# so re-running it on an already filled P would double every probability.\n",
"P[:] = 0.0\n",
"\n",
"for s in range(n_states):  # We loop over all states s.\n",
"    i, j = state_to_pos[s]  # coordinates (i, j) of state s in the maze\n",
"\n",
"    # Goal and trap states are absorbing: whatever action is chosen,\n",
"    # the agent stays in the same state with probability 1.\n",
"    if is_terminal(s):\n",
"        # This entry is a probability, not a reward: it must be 1.0 whatever goal_reward is.\n",
"        P[:, s, s] = 1.0\n",
"        continue\n",
"\n",
"    # Non-terminal state s with intended action a:\n",
"    #   - with probability 1 - p_error the agent moves in direction a;\n",
"    #   - with probability p_error it slips into one of the 3 other directions,\n",
"    #     each with probability p_error / 3.\n",
"    # Blocked moves keep the agent in place (see move_deterministic), so several\n",
"    # outcomes can land on the same s_next; hence the probabilities are accumulated.\n",
"    for a in ACTIONS:\n",
"        # intended move\n",
"        main_i, main_j = move_deterministic(i, j, a)\n",
"        s_main = pos_to_state[(main_i, main_j)]\n",
"        P[a, s, s_main] += 1 - p_error\n",
"\n",
"        # accidental moves: the 3 actions different from a\n",
"        other_actions = [a2 for a2 in ACTIONS if a2 != a]\n",
"        for a2 in other_actions:\n",
"            error_i, error_j = move_deterministic(i, j, a2)\n",
"            s_error = pos_to_state[(error_i, error_j)]\n",
"            P[a, s, s_error] += p_error / len(other_actions)\n",
"\n",
"# So for each (s, a), probabilities over all s_next sum to 1."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"id": "341fe630-8f87-4773-84ad-92d3516e53e2",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Action ↑: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||
"Action →: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||
"Action ↓: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||
"Action ←: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# P[a].sum(axis=1) sums over s_next, giving one total per state s for action a;\n",
"# for a valid transition model every total is 1.\n",
"for a in ACTIONS:\n",
"    probs = P[a].sum(axis=1)\n",
"    print(f\"Action {action_names[a]}:\", probs)\n",
"\n",
"# Fail loudly on Restart & Run All instead of relying on eyeballing the printout.\n",
"assert np.allclose(P.sum(axis=2), 1.0), \"some rows of P do not sum to 1\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "46d23991",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. Policy evaluation\n",
|
||
"\n",
|
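"### 3.1 Bellman expectation equation\n",
"\n",
"For a deterministic policy $\\pi$, the value of every state satisfies\n",
"\n",
"$$V^{\\pi}(s) = R(s) + \\gamma \\sum_{s'} P(s' \\mid s, \\pi(s))\\, V^{\\pi}(s').$$\n",
"\n",
"`policy_evaluation` below solves this system iteratively: starting from $V = 0$, it applies the right-hand side to all states until the largest change between two sweeps is below $\\theta$."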
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"id": "2fffe0b7",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def policy_evaluation(  # noqa: PLR0913\n",
"    policy: np.ndarray,\n",
"    P: np.ndarray,\n",
"    R: np.ndarray,\n",
"    gamma: float,\n",
"    theta: float = 1e-6,\n",
"    max_iter: int = 10_000,\n",
") -> np.ndarray:\n",
"    \"\"\"Evaluate a deterministic policy by iterating the Bellman expectation backup.\n",
"\n",
"    Starting from V = 0, every sweep computes, for all states,\n",
"    V_new(s) = R(s) + gamma * sum_{s'} P[policy[s], s, s'] * V(s'),\n",
"    and stops as soon as max_s |V_new(s) - V(s)| < theta, or after max_iter sweeps.\n",
"\n",
"    Args:\n",
"        policy: array of shape (n_states,), with values in {0,1,2,3}\n",
"        P: array of shape (n_actions, n_states, n_states)\n",
"        R: array of shape (n_states,)\n",
"        gamma: discount factor\n",
"        theta: convergence threshold on the largest change between two sweeps\n",
"        max_iter: maximum number of sweeps\n",
"\n",
"    Returns:\n",
"        Array of shape (n_states,): the estimate of V^pi(s) for every state s.\n",
"\n",
"    \"\"\"\n",
"    n_states = len(R)\n",
"    V = np.zeros(n_states)\n",
"\n",
"    for _ in range(max_iter):\n",
"        # Synchronous update: every new value is computed from the previous sweep's V.\n",
"        V_new = np.array(\n",
"            [R[s] + gamma * np.sum(P[policy[s], s, :] * V) for s in range(n_states)],\n",
"        )\n",
"        converged = np.max(np.abs(V_new - V)) < theta\n",
"        V = V_new\n",
"        if converged:\n",
"            break\n",
"\n",
"    return V"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"id": "4c428327",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def plot_values(V: np.ndarray, title=\"Value function\") -> None:\n",
"    \"\"\"Show V as a heatmap over the maze, with each state's value written in its cell.\n",
"\n",
"    Wall cells are left as NaN, so they receive no value from the colormap.\n",
"    \"\"\"\n",
"    states = list(state_to_pos)\n",
"    rows = [state_to_pos[s][0] for s in states]\n",
"    cols = [state_to_pos[s][1] for s in states]\n",
"\n",
"    # NaN everywhere (walls included), then V[s] written into the cell of each state s.\n",
"    grid_values = np.full((n_rows, n_cols), np.nan)\n",
"    grid_values[rows, cols] = [V[s] for s in states]\n",
"\n",
"    _fig, ax = plt.subplots(figsize=figsize)\n",
"    im = ax.imshow(grid_values, cmap=\"magma\")\n",
"    plt.colorbar(im, ax=ax)\n",
"\n",
"    # White, centred label with two decimals; text(x, y) takes (column j, row i).\n",
"    for s, (i, j) in state_to_pos.items():\n",
"        ax.text(j, i, f\"{V[s]:.2f}\", ha=\"center\", va=\"center\", color=\"white\", fontsize=9)\n",
"\n",
"    # Remove axis ticks and set title\n",
"    ax.set_xticks([])\n",
"    ax.set_yticks([])\n",
"    ax.set_title(title)\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c1ab67f0-bd5e-4ffe-b655-aec030401b78",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def plot_policy(policy: np.ndarray, title=\"Policy\") -> None:\n",
"    \"\"\"Draw the maze with the action chosen by ``policy`` shown as an arrow in each cell.\n",
"\n",
"    Goal, trap and start cells show their letter instead (G blue, X red, S green);\n",
"    walls are shaded.\n",
"    \"\"\"\n",
"    _fig, ax = plt.subplots(figsize=figsize)\n",
"\n",
"    # 1.0 on walls and 0.0 elsewhere, so only walls are shaded.\n",
"    wall_grid = np.array(\n",
"        [[1.0 if maze_str[i][j] == \"#\" else 0.0 for j in range(n_cols)] for i in range(n_rows)],\n",
"    )\n",
"    ax.imshow(wall_grid, cmap=\"Greys\", alpha=0.5)\n",
"\n",
"    for s, (i, j) in state_to_pos.items():\n",
"        if maze_str[i][j] == \"#\":\n",
"            continue  # defensive only: state_to_pos holds no wall cells by construction\n",
"\n",
"        if s in goal_states:\n",
"            text, style = \"G\", {\"fontweight\": \"bold\", \"color\": \"blue\"}\n",
"        elif s in trap_states:\n",
"            text, style = \"X\", {\"fontweight\": \"bold\", \"color\": \"red\"}\n",
"        elif s == start_state:\n",
"            text, style = \"S\", {\"fontweight\": \"bold\", \"color\": \"green\"}\n",
"        else:\n",
"            text, style = action_names[policy[s]], {\"color\": \"black\"}\n",
"        ax.text(j, i, text, ha=\"center\", va=\"center\", fontsize=14, **style)\n",
"\n",
"    # Ticks on the cell borders (half-integer coordinates) give a grid without labels.\n",
"    ax.set_xticks(np.arange(-0.5, n_cols, 1))\n",
"    ax.set_yticks(np.arange(-0.5, n_rows, 1))\n",
"    ax.set_xticklabels([])\n",
"    ax.set_yticklabels([])\n",
"    ax.grid(visible=True)\n",
"    ax.set_title(title)\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "813121a9",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 3.3 Evaluating a policy"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"id": "ceb5dfe2",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[0 3 2 1 1 3 0 2 0 0 2 3 2 3 2 3 2 0 3 1 2 1 0 3 3 2 1 3 2 1 1 0 0 2 3 0 3\n",
|
||
" 3 1 2 0 3 2 1 0 3 1 3 2 3 3 0 1 1 1 0 2 0 2 2 3 2]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Random deterministic policy: one action drawn uniformly from {0, ..., len(ACTIONS) - 1} per state.\n",
"random_policy = rng.integers(0, len(ACTIONS), size=n_states)\n",
"print(random_policy)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"id": "8f3e2ac2",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Value function under random policy:\n",
|
||
"[ -0.2 -0.295 -0.534 -1.301 -1.394 -1.467 -0.827 -1.176 -20.\n",
|
||
" -0.2 -0.207 -6.086 -0.548 -0.2 -0.201 -0.204 -0.28 -0.481\n",
|
||
" -20. -0.546 -0.566 -0.2 -1.126 -0.588 -0.201 -0.203 -0.209\n",
|
||
" -20. -20. -1.769 -1.186 -1.222 -0.229 -20. -1.944 -0.862\n",
|
||
" -0.824 -1.381 -18.279 -20. -7.557 -6.924 -0.44 -5.78 -17.248\n",
|
||
" -7.364 -0.214 -0.207 -18.427 -17.386 -16.41 -16.347 -17.519 -18.483\n",
|
||
" -20. -0.013 -20. -15.666 -20. -20. 20. 5.496]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"V_random = policy_evaluation(random_policy, P, R, gamma)  # default theta and max_iter\n",
"print(\"Value function under random policy:\")\n",
"print(V_random)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 44,
|
||
"id": "cf45291e",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x800 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x800 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
"plot_values(V=V_random, title=\"Value function: random policy\")\n",
"plot_policy(random_policy, title=\"Random Policy\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 45,
|
||
"id": "5a82a3b7",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x800 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x800 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
"# Hand-written policy, listed one maze row at a time. States are numbered row by row,\n",
"# left to right, so each comment gives the state indices of that row. The actions\n",
"# chosen in traps and in the goal do not matter: those states are absorbing.\n",
"my_policy = np.array([\n",
"    # row 1: states 0-8\n",
"    A_DOWN, A_DOWN, A_LEFT, A_LEFT, A_LEFT, A_LEFT, A_LEFT, A_LEFT, A_LEFT,\n",
"    # row 2: states 9-12\n",
"    A_DOWN, A_DOWN, A_UP, A_UP,\n",
"    # row 3: states 13-20\n",
"    A_DOWN, A_LEFT, A_LEFT, A_LEFT, A_LEFT, A_UP, A_UP, A_LEFT,\n",
"    # row 4: states 21-23\n",
"    A_DOWN, A_UP, A_UP,\n",
"    # row 5: states 24-31\n",
"    A_RIGHT, A_RIGHT, A_DOWN, A_UP, A_RIGHT, A_RIGHT, A_RIGHT, A_UP,\n",
"    # row 6: states 32-33\n",
"    A_DOWN, A_UP,\n",
"    # row 7: states 34-42\n",
"    A_RIGHT, A_RIGHT, A_RIGHT, A_RIGHT, A_DOWN, A_RIGHT, A_RIGHT, A_RIGHT, A_DOWN,\n",
"    # row 8: states 43-47\n",
"    A_UP, A_DOWN, A_UP, A_RIGHT, A_DOWN,\n",
"    # row 9: states 48-55\n",
"    A_UP, A_LEFT, A_LEFT, A_RIGHT, A_RIGHT, A_UP, A_LEFT, A_DOWN,\n",
"    # row 10: states 56-61\n",
"    A_UP, A_UP, A_LEFT, A_UP, A_DOWN, A_LEFT,\n",
"])\n",
"\n",
"V_my_policy = policy_evaluation(policy=my_policy, P=P, R=R, gamma=gamma)\n",
"\n",
"plot_values(V=V_my_policy, title=\"Value function: my policy\")\n",
"plot_policy(policy=my_policy, title=\"My policy\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "446c93e4",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "studies",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.13.9"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|