ArtStudies/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb
2026-03-06 15:28:31 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "fef45687",
"metadata": {},
"source": [
"# RL Project: Atari Tennis Tournament\n",
"\n",
"This notebook implements four reinforcement learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
"\n",
"1. **SARSA**: semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
"2. **Q-Learning**: off-policy linear approximation (inspired by Lab 5B)\n",
"3. **Monte Carlo**: episodic returns with linear approximation (inspired by Lab 4)\n",
"4. **DQN**: PyTorch MLP with experience replay and a target network (inspired by Lab 6A and classic DQN)\n",
"\n",
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "b50d7174",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: mps\n"
]
}
],
"source": [
"import pickle\n",
"from collections import deque\n",
"from pathlib import Path\n",
"\n",
"import ale_py # noqa: F401 — registers ALE environments\n",
"import gymnasium as gym\n",
"import supersuit as ss\n",
"from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n",
"from pettingzoo.atari import tennis_v3\n",
"from tqdm.auto import tqdm\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"DEVICE = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {DEVICE}\")\n"
]
},
{
"cell_type": "markdown",
"id": "86047166",
"metadata": {},
"source": [
"# Configuration & Checkpoints\n",
"\n",
"We use a **checkpoint** system (`pickle` serialization) to save and restore trained agent weights. This enables an incremental workflow:\n",
"- Train one agent at a time and save its weights\n",
"- Resume later without retraining previous agents\n",
"- Load all checkpoints for the final evaluation"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "ff3486a4",
"metadata": {},
"outputs": [],
"source": [
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"\n",
"def get_path(name: str) -> Path:\n",
"    \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n",
"    base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n",
"    return CHECKPOINT_DIR / (base + \".pkl\")\n"
]
},
{
"cell_type": "markdown",
"id": "ec691487",
"metadata": {},
"source": [
"# Utility Functions\n",
"\n",
"## Observation Normalization\n",
"\n",
"The Tennis environment produces image observations of shape `(4, 84, 84)` after preprocessing (grayscale + resize + frame stack).\n",
"We flatten them into 1D `float64` vectors and divide by 255 so that every feature lies in $[0, 1]$, as in Lab 7 (continuous feature normalization).\n",
"\n",
"## ε-greedy Policy\n",
"\n",
"Follows the pattern from Lab 5B (`epsilon_greedy`) and Lab 7 (`epsilon_greedy_action`):\n",
"- With probability ε: random action (exploration)\n",
"- With probability 1 - ε: action maximizing $\\hat{q}(s, a)$ with uniform tie-breaking (`np.flatnonzero`)"
]
},
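{
"cell_type": "markdown",
"id": "eps-greedy-worked-example",
"metadata": {},
"source": [
"As a worked illustration of the ε-greedy rule above (a sketch, assuming the 18-action set of Atari Tennis; the symbols $A$ for the action set and $A^*$ for the set of maximizing actions are introduced here and do not come from the labs), the selection probabilities are:\n",
"\n",
"$$\\pi(a \\mid s) = \\frac{\\varepsilon}{|A|} + \\begin{cases} \\dfrac{1 - \\varepsilon}{|A^*|} & \\text{if } a \\in A^* \\\\\\\\ 0 & \\text{otherwise} \\end{cases}$$\n",
"\n",
"With $\\varepsilon = 0.1$, $|A| = 18$ and a unique greedy action, the greedy action is chosen with probability $0.1/18 + 0.9 \\approx 0.906$ and every other action with probability $0.1/18 \\approx 0.0056$. For the linear agents, whose weights start at zero, all Q-values are initially equal, so $A^* = A$ and the policy is uniform for every $\\varepsilon$ until the first non-zero update."
]
},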
{
"cell_type": "code",
"execution_count": 45,
"id": "be85c130",
"metadata": {},
"outputs": [],
"source": [
"def normalize_obs(observation: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Flatten and normalize an observation to a 1D float64 vector.\n",
"\n",
"    Replicates the /255.0 normalization used in all agents from the original project.\n",
"    For image observations of shape (4, 84, 84), this produces a vector of length 28_224.\n",
"\n",
"    Args:\n",
"        observation: Raw observation array from the environment.\n",
"\n",
"    Returns:\n",
"        1D numpy array of dtype float64, values in [0, 1].\n",
"\n",
"    \"\"\"\n",
"    return observation.flatten().astype(np.float64) / 255.0\n",
"\n",
"\n",
"def epsilon_greedy(\n",
"    q_values: np.ndarray,\n",
"    epsilon: float,\n",
"    rng: np.random.Generator,\n",
") -> int:\n",
"    \"\"\"Select an action using an ε-greedy policy with fair tie-breaking.\n",
"\n",
"    Follows the same logic as Lab 5B epsilon_greedy and Lab 7 epsilon_greedy_action:\n",
"    - With probability epsilon: choose a random action (exploration).\n",
"    - With probability 1-epsilon: choose the action with highest Q-value (exploitation).\n",
"    - If multiple actions share the maximum Q-value, break ties uniformly at random.\n",
"\n",
"    Handles edge cases: empty q_values, NaN/Inf values.\n",
"\n",
"    Args:\n",
"        q_values: Array of Q-values for each action, shape (n_actions,).\n",
"        epsilon: Exploration probability in [0, 1].\n",
"        rng: NumPy random number generator.\n",
"\n",
"    Returns:\n",
"        Selected action index.\n",
"\n",
"    \"\"\"\n",
"    q_values = np.asarray(q_values, dtype=np.float64).reshape(-1)\n",
"\n",
"    if q_values.size == 0:\n",
"        msg = \"q_values is empty.\"\n",
"        raise ValueError(msg)\n",
"\n",
"    if rng.random() < epsilon:\n",
"        return int(rng.integers(0, q_values.size))\n",
"\n",
"    # Handle NaN/Inf values safely\n",
"    finite_mask = np.isfinite(q_values)\n",
"    if not np.any(finite_mask):\n",
"        return int(rng.integers(0, q_values.size))\n",
"\n",
"    safe_q = q_values.copy()\n",
"    safe_q[~finite_mask] = -np.inf\n",
"    max_val = np.max(safe_q)\n",
"    best = np.flatnonzero(safe_q == max_val)\n",
"\n",
"    if best.size == 0:\n",
"        return int(rng.integers(0, q_values.size))\n",
"\n",
"    return int(rng.choice(best))\n"
]
},
{
"cell_type": "markdown",
"id": "bb53da28",
"metadata": {},
"source": [
"# Agent Definitions\n",
"\n",
"## Base Class `Agent`\n",
"\n",
"All agents share a common interface with the same method signatures: `get_action`, `update`, `save`, `load`.\n",
"Serialization uses `pickle` (compatible with numpy arrays) by default; `DQNAgent` overrides `save` and `load` to use `torch.save` and `torch.load`."
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "ded9b1fb",
"metadata": {},
"outputs": [],
"source": [
"class Agent:\n",
"    \"\"\"Base class for reinforcement learning agents.\n",
"\n",
"    All agents share this interface so they are compatible with the tournament system.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, seed: int, action_space: int) -> None:\n",
"        \"\"\"Initialize the agent with its action space and a reproducible RNG.\"\"\"\n",
"        self.action_space = action_space\n",
"        self.rng = np.random.default_rng(seed=seed)\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select an action from the current observation.\"\"\"\n",
"        raise NotImplementedError\n",
"\n",
"    def update(\n",
"        self,\n",
"        state: np.ndarray,\n",
"        action: int,\n",
"        reward: float,\n",
"        next_state: np.ndarray,\n",
"        done: bool,\n",
"        next_action: int | None = None,\n",
"    ) -> None:\n",
"        \"\"\"Update agent parameters from one transition.\"\"\"\n",
"\n",
"    def save(self, filename: str) -> None:\n",
"        \"\"\"Save the agent state to disk using pickle.\"\"\"\n",
"        with Path(filename).open(\"wb\") as f:\n",
"            pickle.dump(self.__dict__, f)\n",
"\n",
"    def load(self, filename: str) -> None:\n",
"        \"\"\"Load the agent state from disk.\"\"\"\n",
"        with Path(filename).open(\"rb\") as f:\n",
"            self.__dict__.update(pickle.load(f))  # noqa: S301\n"
]
},
{
"cell_type": "markdown",
"id": "8a4eae79",
"metadata": {},
"source": [
"## Random Agent (baseline)\n",
"\n",
"Serves as a reference to evaluate the performance of learning agents."
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "78bdc9d2",
"metadata": {},
"outputs": [],
"source": [
"class RandomAgent(Agent):\n",
"    \"\"\"A simple agent that selects actions uniformly at random (baseline).\"\"\"\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select a random action, ignoring the observation and epsilon.\"\"\"\n",
"        _ = observation, epsilon\n",
"        return int(self.rng.integers(0, self.action_space))\n"
]
},
{
"cell_type": "markdown",
"id": "5f679032",
"metadata": {},
"source": [
"## SARSA Agent — Linear Approximation (Semi-gradient)\n",
"\n",
"This agent combines:\n",
"- **Linear approximation** from Lab 7 (`SarsaAgent`): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$\n",
"- **On-policy SARSA update** from Lab 5B (`train_sarsa`): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
"\n",
"The semi-gradient update rule is:\n",
"$$W_a \\leftarrow W_a + \\alpha \\cdot \\delta \\cdot \\phi(s)$$\n",
"\n",
"where $\\phi(s)$ is the normalized observation vector (analogous to tile coding features in Lab 7, but in dense form)."
]
},
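{
"cell_type": "markdown",
"id": "sarsa-worked-example",
"metadata": {},
"source": [
"A short derivation (a sketch under the definitions above, not taken from the labs) shows how one semi-gradient step changes the estimate for the visited pair. Since $\\hat{q}(s, a) = W_a^\\top \\phi(s)$, applying $W_a \\leftarrow W_a + \\alpha \\, \\delta \\, \\phi(s)$ gives:\n",
"\n",
"$$\\hat{q}_{\\text{new}}(s, a) = \\big(W_a + \\alpha \\, \\delta \\, \\phi(s)\\big)^\\top \\phi(s) = \\hat{q}(s, a) + \\alpha \\, \\delta \\, \\lVert \\phi(s) \\rVert^2$$\n",
"\n",
"Toy example with hypothetical values $\\phi(s) = (1, 0.5)$, $W_a = (0, 0)$, $r = 1$, $\\gamma = 0.99$, $\\hat{q}(s', a') = 0$ and $\\alpha = 0.001$:\n",
"\n",
"$$\\delta = 1 + 0.99 \\cdot 0 - 0 = 1, \\qquad W_a \\leftarrow (0.001, \\; 0.0005), \\qquad \\hat{q}_{\\text{new}}(s, a) = 0.001 \\cdot 1 \\cdot 1.25 = 0.00125$$\n",
"\n",
"Because each of the 28,224 features lies in $[0, 1]$, $\\lVert \\phi(s) \\rVert^2$ can be as large as 28,224, so the change in $\\hat{q}(s, a)$ can reach roughly $28 \\, \\delta$ even with $\\alpha = 0.001$. This is one motivation for keeping the learning rate small and for clipping $\\delta$ in the implementation."
]
},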
{
"cell_type": "code",
"execution_count": 48,
"id": "c124ed9a",
"metadata": {},
"outputs": [],
"source": [
"class SarsaAgent(Agent):\n",
"    \"\"\"Semi-gradient SARSA agent with linear function approximation.\n",
"\n",
"    Inspired by:\n",
"    - Lab 7 SarsaAgent: linear q(s,a) = W_a . phi(s), semi-gradient update\n",
"    - Lab 5B train_sarsa: on-policy TD target using Q(s', a')\n",
"\n",
"    The weight matrix W has shape (n_actions, n_features).\n",
"    For a given state s, q(s, a) = W[a] @ phi(s) is the dot product\n",
"    of the action's weight row with the normalized observation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        n_features: int,\n",
"        n_actions: int,\n",
"        alpha: float = 0.001,\n",
"        gamma: float = 0.99,\n",
"        seed: int = 42,\n",
"    ) -> None:\n",
"        \"\"\"Initialize SARSA agent with linear weights.\n",
"\n",
"        Args:\n",
"            n_features: Dimension of the feature vector phi(s).\n",
"            n_actions: Number of discrete actions.\n",
"            alpha: Learning rate (kept small for high-dim features).\n",
"            gamma: Discount factor.\n",
"            seed: RNG seed for reproducibility.\n",
"\n",
"        \"\"\"\n",
"        super().__init__(seed, n_actions)\n",
"        self.n_features = n_features\n",
"        self.alpha = alpha\n",
"        self.gamma = gamma\n",
"        # Weight matrix: one row per action, analogous to Lab 7's self.w\n",
"        # but organized as (n_actions, n_features) for dense features.\n",
"        self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
"\n",
"    def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
"        \"\"\"Compute Q-values for all actions given feature vector phi(s).\n",
"\n",
"        Equivalent to Lab 7's self.q(s, a) = self.w[idx].sum()\n",
"        but using dense linear approximation: q(s, a) = W[a] @ phi.\n",
"\n",
"        Args:\n",
"            phi: Normalized feature vector, shape (n_features,).\n",
"\n",
"        Returns:\n",
"            Array of Q-values, shape (n_actions,).\n",
"\n",
"        \"\"\"\n",
"        return self.W @ phi  # shape (n_actions,)\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select action using ε-greedy policy over linear Q-values.\n",
"\n",
"        Same pattern as Lab 7 SarsaAgent.eps_greedy:\n",
"        compute q-values for all actions, then apply epsilon_greedy.\n",
"        \"\"\"\n",
"        phi = normalize_obs(observation)\n",
"        q_vals = self._q_values(phi)\n",
"        return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
"    def update(\n",
"        self,\n",
"        state: np.ndarray,\n",
"        action: int,\n",
"        reward: float,\n",
"        next_state: np.ndarray,\n",
"        done: bool,\n",
"        next_action: int | None = None,\n",
"    ) -> None:\n",
"        \"\"\"Perform one semi-gradient SARSA update.\n",
"\n",
"        Follows the SARSA update from Lab 5B train_sarsa:\n",
"            td_target = r + gamma * Q(s', a') * (0 if done else 1)\n",
"            Q(s, a) += alpha * (td_target - Q(s, a))\n",
"\n",
"        In continuous form with linear approximation (Lab 7 SarsaAgent.update):\n",
"            delta = target - q(s, a)\n",
"            W[a] += alpha * delta * phi(s)\n",
"\n",
"        Args:\n",
"            state: Current observation.\n",
"            action: Action taken.\n",
"            reward: Reward received.\n",
"            next_state: Next observation.\n",
"            done: Whether the episode ended.\n",
"            next_action: Action chosen in next state (required for SARSA).\n",
"\n",
"        \"\"\"\n",
"        phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
"        q_sa = float(self.W[action] @ phi)  # current estimate q(s, a)\n",
"        if not np.isfinite(q_sa):\n",
"            q_sa = 0.0\n",
"\n",
"        if done:\n",
"            # Terminal: no future value (Lab 5B: gamma * Q[s2, a2] * 0)\n",
"            target = reward\n",
"        else:\n",
"            # On-policy: use q(s', a') where a' is the actual next action\n",
"            # This is the key SARSA property (Lab 5B)\n",
"            phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
"            if next_action is None:\n",
"                next_action = 0  # fallback, should not happen in practice\n",
"            q_sp_ap = float(self.W[next_action] @ phi_next)\n",
"            if not np.isfinite(q_sp_ap):\n",
"                q_sp_ap = 0.0\n",
"            target = float(reward) + self.gamma * q_sp_ap\n",
"\n",
"        # Semi-gradient update: W[a] += alpha * delta * phi(s)\n",
"        # Analogous to Lab 7: self.w[idx] += self.alpha * delta\n",
"        if not np.isfinite(target):\n",
"            return\n",
"\n",
"        delta = float(target - q_sa)\n",
"        if not np.isfinite(delta):\n",
"            return\n",
"\n",
"        td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
"        self.W[action] += self.alpha * td_step * phi\n",
"        self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
]
},
{
"cell_type": "markdown",
"id": "d4e18536",
"metadata": {},
"source": [
"## Q-Learning Agent — Linear Approximation (Off-policy)\n",
"\n",
"Same architecture as SARSA but with the **off-policy update** from Lab 5B (`train_q_learning`):\n",
"\n",
"$$\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$$\n",
"\n",
"The key difference from SARSA: the target uses $\\max_{a'} \\hat{q}(s', a')$ instead of $\\hat{q}(s', a')$ for the action $a'$ actually chosen. The update therefore evaluates the greedy policy regardless of which exploratory action the agent takes next."
]
},
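{
"cell_type": "markdown",
"id": "qlearning-vs-sarsa-worked-example",
"metadata": {},
"source": [
"To make the difference concrete, here is a small numerical comparison of the two TD errors (hypothetical values chosen for illustration, with three actions instead of the full action set). Suppose $r = 0$, $\\gamma = 0.99$, $\\hat{q}(s, a) = 0.3$, the next-state estimates are $\\hat{q}(s', \\cdot) = (0.2, \\; 0.5, \\; -0.1)$, and the ε-greedy policy happens to explore with $a' = 0$:\n",
"\n",
"$$\\delta_{\\text{SARSA}} = 0 + 0.99 \\cdot 0.2 - 0.3 = -0.102$$\n",
"\n",
"$$\\delta_{\\text{Q-learning}} = 0 + 0.99 \\cdot \\max(0.2, \\; 0.5, \\; -0.1) - 0.3 = 0.99 \\cdot 0.5 - 0.3 = 0.195$$\n",
"\n",
"The same transition lowers the SARSA estimate and raises the Q-learning estimate. The two errors coincide whenever the chosen $a'$ is a greedy action."
]
},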
{
"cell_type": "code",
"execution_count": 49,
"id": "f5b5b9ea",
"metadata": {},
"outputs": [],
"source": [
"class QLearningAgent(Agent):\n",
"    \"\"\"Q-Learning agent with linear function approximation (off-policy).\n",
"\n",
"    Inspired by:\n",
"    - Lab 5B train_q_learning: off-policy TD target using max_a' Q(s', a')\n",
"    - Lab 7 SarsaAgent: linear approximation q(s,a) = W[a] @ phi(s)\n",
"\n",
"    The only difference from SarsaAgent is the TD target:\n",
"    SARSA uses Q(s', a') (on-policy), Q-Learning uses max_a' Q(s', a') (off-policy).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        n_features: int,\n",
"        n_actions: int,\n",
"        alpha: float = 0.001,\n",
"        gamma: float = 0.99,\n",
"        seed: int = 42,\n",
"    ) -> None:\n",
"        \"\"\"Initialize Q-Learning agent with linear weights.\n",
"\n",
"        Args:\n",
"            n_features: Dimension of the feature vector phi(s).\n",
"            n_actions: Number of discrete actions.\n",
"            alpha: Learning rate.\n",
"            gamma: Discount factor.\n",
"            seed: RNG seed.\n",
"\n",
"        \"\"\"\n",
"        super().__init__(seed, n_actions)\n",
"        self.n_features = n_features\n",
"        self.alpha = alpha\n",
"        self.gamma = gamma\n",
"        self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
"\n",
"    def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
"        \"\"\"Compute Q-values for all actions: q(s, a) = W[a] @ phi for each a.\"\"\"\n",
"        return self.W @ phi\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n",
"        phi = normalize_obs(observation)\n",
"        q_vals = self._q_values(phi)\n",
"        return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
"    def update(\n",
"        self,\n",
"        state: np.ndarray,\n",
"        action: int,\n",
"        reward: float,\n",
"        next_state: np.ndarray,\n",
"        done: bool,\n",
"        next_action: int | None = None,\n",
"    ) -> None:\n",
"        \"\"\"Perform one Q-learning update.\n",
"\n",
"        Follows Lab 5B train_q_learning:\n",
"            td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
"            Q[s, a] += alpha * (td_target - Q[s, a])\n",
"\n",
"        In continuous form with linear approximation:\n",
"            delta = target - q(s, a)\n",
"            W[a] += alpha * delta * phi(s)\n",
"        \"\"\"\n",
"        _ = next_action  # Q-learning is off-policy: next_action is not used\n",
"        phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
"        q_sa = float(self.W[action] @ phi)\n",
"        if not np.isfinite(q_sa):\n",
"            q_sa = 0.0\n",
"\n",
"        if done:\n",
"            # Terminal state: no future value\n",
"            # Lab 5B: gamma * np.max(Q[s2]) * (0 if terminated else 1)\n",
"            target = reward\n",
"        else:\n",
"            # Off-policy: use max over all actions in next state\n",
"            # This is the key Q-learning property (Lab 5B)\n",
"            phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
"            q_next_all = self._q_values(phi_next)  # q(s', a') for all a'\n",
"            q_next_max = float(np.max(q_next_all))\n",
"            if not np.isfinite(q_next_max):\n",
"                q_next_max = 0.0\n",
"            target = float(reward) + self.gamma * q_next_max\n",
"\n",
"        if not np.isfinite(target):\n",
"            return\n",
"\n",
"        delta = float(target - q_sa)\n",
"        if not np.isfinite(delta):\n",
"            return\n",
"\n",
"        td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
"        self.W[action] += self.alpha * td_step * phi\n",
"        self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
]
},
{
"cell_type": "markdown",
"id": "79e6b39f",
"metadata": {},
"source": [
"## Monte Carlo Agent — Linear Approximation (Every-visit)\n",
"\n",
"This agent is inspired by Lab 4 (`mc_control_epsilon_soft`):\n",
"- Accumulates transitions in an episode buffer `(state, action, reward)`\n",
"- At the end of the episode (`done=True`), computes **cumulative returns** by traversing the buffer backward:\n",
" $$G \\leftarrow \\gamma \\cdot G + r$$\n",
"- Updates weights with the semi-gradient rule:\n",
" $$W_a \\leftarrow W_a + \\alpha \\cdot (G - \\hat{q}(s, a)) \\cdot \\phi(s)$$\n",
"\n",
"Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating.\n",
"\n",
"> **Note**: This agent currently has **checkpoint loading issues** — the saved weights fail to restore properly, causing the agent to behave as if untrained during evaluation. The training code itself works correctly."
]
},
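{
"cell_type": "markdown",
"id": "mc-returns-worked-example",
"metadata": {},
"source": [
"As a worked example of the backward recursion above (hypothetical three-step episode, chosen for illustration), take rewards $r_0 = 0$, $r_1 = 0$, $r_2 = 1$ and $\\gamma = 0.99$. Starting from $G = 0$ and traversing the buffer from the last step to the first:\n",
"\n",
"$$G_2 = 0.99 \\cdot 0 + 1 = 1, \\qquad G_1 = 0.99 \\cdot 1 + 0 = 0.99, \\qquad G_0 = 0.99 \\cdot 0.99 + 0 = 0.9801$$\n",
"\n",
"Each $G_t$ then serves as the regression target for $\\hat{q}(s_t, a_t)$ in the weight update."
]
},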
{
"cell_type": "code",
"execution_count": 50,
"id": "7a3aa454",
"metadata": {},
"outputs": [],
"source": [
"class MonteCarloAgent(Agent):\n",
"    \"\"\"Monte Carlo control agent with linear function approximation.\n",
"\n",
"    Inspired by Lab 4 mc_control_epsilon_soft:\n",
"    - Accumulates transitions in an episode buffer\n",
"    - At episode end (done=True), computes discounted returns backward:\n",
"      G = gamma * G + r (same as Lab 4's reversed loop)\n",
"    - Updates weights with semi-gradient: W[a] += alpha * (G - q(s,a)) * phi(s)\n",
"\n",
"    Unlike TD methods (SARSA, Q-Learning), no update occurs until the episode ends.\n",
"\n",
"    Performance optimizations over naive per-step implementation:\n",
"    - float32 weights & features (halves memory bandwidth, faster SIMD)\n",
"    - Raw observations stored compactly as uint8, batch-normalized at episode end\n",
"    - Returns computed in one backward pass; weights updated in chunks (Q-values via einsum)\n",
"    - Single weight sanitization per episode instead of per-step\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        n_features: int,\n",
"        n_actions: int,\n",
"        alpha: float = 0.001,\n",
"        gamma: float = 0.99,\n",
"        seed: int = 42,\n",
"    ) -> None:\n",
"        \"\"\"Initialize Monte Carlo agent with linear weights and episode buffers.\"\"\"\n",
"        super().__init__(seed, n_actions)\n",
"        self.n_features = n_features\n",
"        self.alpha = alpha\n",
"        self.gamma = gamma\n",
"        self.W = np.zeros((n_actions, n_features), dtype=np.float32)\n",
"        self._obs_buf: list[np.ndarray] = []\n",
"        self._act_buf: list[int] = []\n",
"        self._rew_buf: list[float] = []\n",
"\n",
"    def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
"        return self.W @ phi\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n",
"        phi = observation.flatten().astype(np.float32) / np.float32(255.0)\n",
"        q_vals = self._q_values(phi)\n",
"        return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
"    def update(\n",
"        self,\n",
"        state: np.ndarray,\n",
"        action: int,\n",
"        reward: float,\n",
"        next_state: np.ndarray,\n",
"        done: bool,\n",
"        next_action: int | None = None,\n",
"    ) -> None:\n",
"        \"\"\"Accumulate episode transitions and update weights at episode end.\"\"\"\n",
"        _ = next_state, next_action\n",
"\n",
"        self._obs_buf.append(state)\n",
"        self._act_buf.append(action)\n",
"        self._rew_buf.append(reward)\n",
"\n",
"        if not done:\n",
"            return\n",
"\n",
"        n = len(self._rew_buf)\n",
"        actions = np.array(self._act_buf, dtype=np.intp)\n",
"\n",
"        returns = np.empty(n, dtype=np.float32)\n",
"        G = np.float32(0.0)\n",
"        gamma32 = np.float32(self.gamma)\n",
"        for i in range(n - 1, -1, -1):\n",
"            G = gamma32 * G + np.float32(self._rew_buf[i])\n",
"            returns[i] = G\n",
"\n",
"        alpha32 = np.float32(self.alpha)\n",
"        chunk_size = 500\n",
"        for start in range(0, n, chunk_size):\n",
"            end = min(start + chunk_size, n)\n",
"            cs = end - start\n",
"\n",
"            raw = np.array(self._obs_buf[start:end])\n",
"            phi = raw.reshape(cs, -1).astype(np.float32)\n",
"            phi /= np.float32(255.0)\n",
"\n",
"            # Batch update: every delta in the chunk uses the weights from the chunk start\n",
"            ca = actions[start:end]\n",
"            q_sa = np.einsum(\"ij,ij->i\", self.W[ca], phi)\n",
"\n",
"            deltas = np.clip(returns[start:end] - q_sa, -1000.0, 1000.0)\n",
"\n",
"            for a in range(self.action_space):\n",
"                mask = ca == a\n",
"                if not np.any(mask):\n",
"                    continue\n",
"                self.W[a] += alpha32 * (deltas[mask] @ phi[mask])\n",
"\n",
"        self.W = np.nan_to_num(self.W, nan=0.0, posinf=1e6, neginf=-1e6)\n",
"\n",
"        self._obs_buf.clear()\n",
"        self._act_buf.clear()\n",
"        self._rew_buf.clear()\n"
]
},
{
"cell_type": "markdown",
"id": "d5766fe9",
"metadata": {},
"source": [
"## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n",
"\n",
"This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon, falling back to CUDA or CPU).\n",
"\n",
"**Network architecture** (implemented as a `torch.nn.Module`):\n",
"$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
"\n",
"**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
"- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n",
"- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n",
"- **Gradient clipping**: prevents exploding gradients in deep networks\n",
"- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes\n",
"\n",
"> **Note**: This agent currently has **checkpoint loading issues** — the saved `.pt` checkpoint fails to restore properly (device mismatch / state dict incompatibility), causing errors during evaluation. The training code itself works correctly."
]
},
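{
"cell_type": "markdown",
"id": "dqn-target-and-loss",
"metadata": {},
"source": [
"The update performed in `DQNAgent.update` can be summarized in standard DQN notation. This is a sketch: the symbols $\\theta$ (Q-network parameters), $\\theta^-$ (target-network parameters), $d_i$ (terminal flag) and $B$ (minibatch) are introduced here and do not appear elsewhere in the notebook. For each sampled transition $(s_i, a_i, r_i, s'_i, d_i)$ the target is:\n",
"\n",
"$$y_i = r_i + (1 - d_i) \\, \\gamma \\max_{a'} \\hat{q}(s'_i, a'; \\theta^-)$$\n",
"\n",
"and the loss is the Huber loss averaged over the minibatch, assuming the default settings of `nn.SmoothL1Loss`:\n",
"\n",
"$$L(\\theta) = \\frac{1}{|B|} \\sum_{i \\in B} \\ell\\big(\\hat{q}(s_i, a_i; \\theta) - y_i\\big), \\qquad \\ell(e) = \\begin{cases} \\tfrac{1}{2} e^2 & \\text{if } |e| < 1 \\\\\\\\ |e| - \\tfrac{1}{2} & \\text{otherwise} \\end{cases}$$\n",
"\n",
"Gradients flow only through $\\hat{q}(s_i, a_i; \\theta)$, since $y_i$ is computed from the target network without gradient tracking."
]
},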
{
"cell_type": "code",
"execution_count": 51,
"id": "9c777493",
"metadata": {},
"outputs": [],
"source": [
"class QNetwork(nn.Module):\n",
"    \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n",
"\n",
"    def __init__(self, n_features: int, n_actions: int) -> None:\n",
"        \"\"\"Initialize the Q-network architecture.\"\"\"\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(\n",
"            nn.Linear(n_features, 256),\n",
"            nn.ReLU(),\n",
"            nn.Linear(256, 256),\n",
"            nn.ReLU(),\n",
"            nn.Linear(256, n_actions),\n",
"        )\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Compute Q-values for input state tensor x.\"\"\"\n",
"        return self.net(x)\n",
"\n",
"\n",
"class ReplayBuffer:\n",
"    \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n",
"\n",
"    def __init__(self, capacity: int) -> None:\n",
"        \"\"\"Initialize the replay buffer with a maximum capacity.\"\"\"\n",
"        self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n",
"\n",
"    def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n",
"        \"\"\"Add a transition to the replay buffer.\"\"\"\n",
"        self.buffer.append((state, action, reward, next_state, done))\n",
"\n",
"    def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n",
"        \"\"\"Sample a random minibatch of transitions from the buffer.\"\"\"\n",
"        indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n",
"        batch = [self.buffer[i] for i in indices]\n",
"        states = np.array([t[0] for t in batch])\n",
"        actions = np.array([t[1] for t in batch])\n",
"        rewards = np.array([t[2] for t in batch])\n",
"        next_states = np.array([t[3] for t in batch])\n",
"        dones = np.array([t[4] for t in batch], dtype=np.float32)\n",
"        return states, actions, rewards, next_states, dones\n",
"\n",
"    def __len__(self) -> int:\n",
"        \"\"\"Return the current number of transitions stored in the buffer.\"\"\"\n",
"        return len(self.buffer)\n",
"\n",
"\n",
"class DQNAgent(Agent):\n",
"    \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n",
"\n",
"    Inspired by:\n",
"    - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n",
"    - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n",
"\n",
"    Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        n_features: int,\n",
"        n_actions: int,\n",
"        lr: float = 1e-4,\n",
"        gamma: float = 0.99,\n",
"        buffer_size: int = 50_000,\n",
"        batch_size: int = 128,\n",
"        target_update_freq: int = 1000,\n",
"        seed: int = 42,\n",
"    ) -> None:\n",
"        \"\"\"Initialize DQN agent.\n",
"\n",
"        Args:\n",
"            n_features: Input feature dimension.\n",
"            n_actions: Number of discrete actions.\n",
"            lr: Learning rate for Adam optimizer.\n",
"            gamma: Discount factor.\n",
"            buffer_size: Maximum replay buffer capacity.\n",
"            batch_size: Minibatch size for updates.\n",
"            target_update_freq: Steps between target network syncs.\n",
"            seed: RNG seed.\n",
"\n",
"        \"\"\"\n",
"        super().__init__(seed, n_actions)\n",
"        self.n_features = n_features\n",
"        self.lr = lr\n",
"        self.gamma = gamma\n",
"        self.batch_size = batch_size\n",
"        self.target_update_freq = target_update_freq\n",
"        self.update_step = 0\n",
"\n",
"        # Q-network and target network on GPU\n",
"        torch.manual_seed(seed)\n",
"        self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
"        self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
"        self.target_net.load_state_dict(self.q_net.state_dict())\n",
"        self.target_net.eval()\n",
"\n",
"        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n",
"        self.loss_fn = nn.SmoothL1Loss()  # Huber loss — more robust than MSE\n",
"\n",
"        # Experience replay buffer\n",
"        self.replay_buffer = ReplayBuffer(buffer_size)\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select action using epsilon-greedy policy over Q-network outputs.\"\"\"\n",
"        if self.rng.random() < epsilon:\n",
"            return int(self.rng.integers(0, self.action_space))\n",
"\n",
"        phi = normalize_obs(observation)\n",
"        with torch.no_grad():\n",
"            state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n",
"            q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n",
"        return epsilon_greedy(q_vals, 0.0, self.rng)\n",
"\n",
"    def update(\n",
"        self,\n",
"        state: np.ndarray,\n",
"        action: int,\n",
"        reward: float,\n",
"        next_state: np.ndarray,\n",
"        done: bool,\n",
"        next_action: int | None = None,\n",
"    ) -> None:\n",
"        \"\"\"Store transition and perform a minibatch DQN update.\"\"\"\n",
"        _ = next_action  # DQN is off-policy\n",
"\n",
"        # Store transition\n",
"        phi_s = normalize_obs(state)\n",
"        phi_sp = normalize_obs(next_state)\n",
"        self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n",
"\n",
"        if len(self.replay_buffer) < self.batch_size:\n",
"            return\n",
"\n",
"        # Sample minibatch\n",
"        states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n",
"            self.batch_size, self.rng,\n",
"        )\n",
"\n",
"        # Convert to tensors on device\n",
"        states_t = torch.from_numpy(states_b).float().to(DEVICE)\n",
"        actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n",
"        rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n",
"        next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n",
"        dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n",
"\n",
"        # Current Q-values for taken actions\n",
"        q_values = self.q_net(states_t)\n",
"        q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n",
"\n",
"        # Target Q-values (off-policy: max over actions in next state)\n",
"        with torch.no_grad():\n",
"            q_next = self.target_net(next_states_t).max(dim=1).values\n",
"            targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n",
"\n",
"        # Compute loss and update\n",
"        loss = self.loss_fn(q_curr, targets)\n",
"        self.optimizer.zero_grad()\n",
"        loss.backward()\n",
"        nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n",
"        self.optimizer.step()\n",
"\n",
"        # Sync target network periodically\n",
"        self.update_step += 1\n",
"        if self.update_step % self.target_update_freq == 0:\n",
"            self.target_net.load_state_dict(self.q_net.state_dict())\n",
"\n",
"    def save(self, filename: str) -> None:\n",
"        \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n",
"        torch.save(\n",
"            {\n",
"                \"q_net\": self.q_net.state_dict(),\n",
"                \"target_net\": self.target_net.state_dict(),\n",
"                \"optimizer\": self.optimizer.state_dict(),\n",
"                \"update_step\": self.update_step,\n",
"                \"n_features\": self.n_features,\n",
"                \"action_space\": self.action_space,\n",
"            },\n",
"            filename,\n",
"        )\n",
"\n",
"    def load(self, filename: str) -> None:\n",
"        \"\"\"Load agent state from a torch checkpoint.\"\"\"\n",
"        checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n",
"        self.q_net.load_state_dict(checkpoint[\"q_net\"])\n",
"        self.target_net.load_state_dict(checkpoint[\"target_net\"])\n",
"        self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n",
"        self.update_step = checkpoint[\"update_step\"]\n",
"        self.q_net.to(DEVICE)\n",
"        self.target_net.to(DEVICE)\n",
"        self.target_net.eval()\n"
]
},
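{
"cell_type": "markdown",
"id": "dqn-replay-memory-estimate",
"metadata": {},
"source": [
"A rough estimate of the replay buffer's memory footprint follows from the defaults above. It is only a back-of-the-envelope sketch: it assumes a full buffer with `buffer_size = 50_000`, counts only the two `float64` state vectors that `normalize_obs` produces per transition, and ignores Python object overhead:\n",
"\n",
"$$28{,}224 \\times 8 \\text{ bytes} = 225{,}792 \\text{ bytes per state}$$\n",
"\n",
"$$2 \\times 225{,}792 \\text{ bytes} = 451{,}584 \\text{ bytes per transition}$$\n",
"\n",
"$$50{,}000 \\times 451{,}584 \\text{ bytes} = 22{,}579{,}200{,}000 \\text{ bytes} \\approx 22.6 \\text{ GB}$$\n",
"\n",
"Under the same assumptions, storing the raw `uint8` frames and normalizing at sampling time would need one eighth of that, about 2.8 GB. This is a possible design alternative, not something the notebook implements."
]
},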
{
"cell_type": "markdown",
"id": "91e51dc8",
"metadata": {},
"source": [
"## Tennis Environment\n",
"\n",
"Creation of the Atari Tennis environment via Gymnasium (`ALE/Tennis-v5`) with standard wrappers:\n",
"- **Grayscale**: `obs_type=\"grayscale\"` — single-channel observations\n",
"- **Resize**: `ResizeObservation(84, 84)` — downscale to 84×84\n",
"- **Frame stack**: `FrameStackObservation(4)` — stack 4 consecutive frames\n",
"\n",
"The final observation is an array of shape `(4, 84, 84)`, which flattens to 28,224 features.\n",
"\n",
"The agent plays against the **built-in Atari AI opponent**."
]
},
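{
"cell_type": "markdown",
"id": "feature-dimension-arithmetic",
"metadata": {},
"source": [
"The feature dimension quoted above, and the resulting size of a linear agent's weight matrix, can be checked by hand. The second line is a sketch that assumes the 18-action set of Atari Tennis and `float64` weights as in `SarsaAgent` and `QLearningAgent`:\n",
"\n",
"$$n_{\\text{features}} = 4 \\times 84 \\times 84 = 28{,}224$$\n",
"\n",
"$$|W| = 18 \\times 28{,}224 = 508{,}032 \\text{ weights} \\;\\Rightarrow\\; 508{,}032 \\times 8 \\text{ bytes} \\approx 4.06 \\text{ MB}$$"
]
},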
{
"cell_type": "code",
"execution_count": 52,
"id": "f9a973dd",
"metadata": {},
"outputs": [],
"source": [
"def create_env() -> gym.Env:\n",
"    \"\"\"Create the ALE/Tennis-v5 environment with preprocessing wrappers.\n",
"\n",
"    Applies:\n",
"    - obs_type=\"grayscale\": grayscale observation (210, 160)\n",
"    - ResizeObservation(84, 84): downscale to 84x84\n",
"    - FrameStackObservation(4): stack 4 consecutive frames -> (4, 84, 84)\n",
"\n",
"    Returns:\n",
"        Gymnasium environment ready for training.\n",
"\n",
"    \"\"\"\n",
"    env = gym.make(\"ALE/Tennis-v5\", obs_type=\"grayscale\")\n",
"    env = ResizeObservation(env, shape=(84, 84))\n",
"    return FrameStackObservation(env, stack_size=4)\n"
]
},
{
"cell_type": "markdown",
"id": "18cb28d8",
"metadata": {},
"source": [
"## Training & Evaluation Infrastructure\n",
"\n",
"Functions for training and evaluating agents in the single-agent Gymnasium environment:\n",
"\n",
"1. **`train_agent`** — Pre-trains an agent against the built-in AI for a given number of episodes with ε-greedy exploration\n",
"2. **`plot_training_curves`** — Plots the training reward history (moving average) for all agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06b91580",
"metadata": {},
"outputs": [],
"source": [
"AGENT_COLORS = {\n",
"    \"Random\": \"#95a5a6\",\n",
"    \"SARSA\": \"#3498db\",\n",
"    \"Q-Learning\": \"#e74c3c\",\n",
"    \"MonteCarlo\": \"#2ecc71\",\n",
"    \"DQN\": \"#9b59b6\",\n",
"}\n",
"\n",
"\n",
"def train_agent(\n",
"    env: gym.Env,\n",
"    agent: Agent,\n",
"    name: str,\n",
"    *,\n",
"    episodes: int = 5000,\n",
"    epsilon_start: float = 1.0,\n",
"    epsilon_end: float = 0.05,\n",
"    epsilon_decay: float = 0.999,\n",
"    max_steps: int = 5000,\n",
") -> list[float]:\n",
"    \"\"\"Pre-train an agent against the built-in Atari AI opponent.\n",
"\n",
"    Each agent learns independently by playing full episodes. This is the\n",
"    pre-training phase (not self-play): the agent interacts with the\n",
"    environment's built-in opponent and updates its parameters after each\n",
"    transition.\n",
"\n",
"    Args:\n",
"        env: Gymnasium ALE/Tennis-v5 environment.\n",
"        agent: Agent instance to train.\n",
"        name: Display name for the progress bar.\n",
"        episodes: Number of training episodes.\n",
"        epsilon_start: Initial exploration rate.\n",
"        epsilon_end: Minimum exploration rate.\n",
"        epsilon_decay: Multiplicative decay per episode.\n",
"        max_steps: Maximum steps per episode.\n",
"\n",
"    Returns:\n",
"        List of total rewards per episode.\n",
"\n",
"    \"\"\"\n",
"    rewards_history: list[float] = []\n",
"    epsilon = epsilon_start\n",
"\n",
"    pbar = tqdm(range(episodes), desc=f\"Training {name}\", leave=True)\n",
"\n",
"    for _ep in pbar:\n",
"        obs, _info = env.reset()\n",
"        obs = np.asarray(obs)\n",
"        total_reward = 0.0\n",
"\n",
"        # The first action is chosen before the loop so that the on-policy\n",
"        # next action can be passed to agent.update at every step.\n",
"        action = agent.get_action(obs, epsilon=epsilon)\n",
"\n",
"        for _step in range(max_steps):\n",
"            next_obs, reward, terminated, truncated, _info = env.step(action)\n",
"            next_obs = np.asarray(next_obs)\n",
"            done = terminated or truncated\n",
"            reward = float(reward)\n",
"            total_reward += reward\n",
"\n",
"            next_action = agent.get_action(next_obs, epsilon=epsilon) if not done else None\n",
"\n",
"            agent.update(\n",
"                state=obs,\n",
"                action=action,\n",
"                reward=reward,\n",
"                next_state=next_obs,\n",
"                done=done,\n",
"                next_action=next_action,\n",
"            )\n",
"\n",
"            if done:\n",
"                break\n",
"\n",
"            obs = next_obs\n",
"            action = next_action\n",
"\n",
"        rewards_history.append(total_reward)\n",
"        epsilon = max(epsilon_end, epsilon * epsilon_decay)\n",
"\n",
"        recent_window = 50\n",
"        if len(rewards_history) >= recent_window:\n",
"            recent_avg = np.mean(rewards_history[-recent_window:])\n",
"            pbar.set_postfix(\n",
"                avg50=f\"{recent_avg:.1f}\",\n",
"                eps=f\"{epsilon:.3f}\",\n",
"                rew=f\"{total_reward:.0f}\",\n",
"            )\n",
"\n",
"    return rewards_history\n",
"\n",
"\n",
"def plot_training_curves(\n",
"    training_histories: dict[str, list[float]],\n",
"    path: str,\n",
"    window: int = 100,\n",
") -> None:\n",
"    \"\"\"Plot training reward curves for all agents on a single figure.\n",
"\n",
"    Uses a moving average to smooth the curves.\n",
"\n",
"    Args:\n",
"        training_histories: Dict mapping agent names to reward lists.\n",
"        path: File path to save the plot image.\n",
"        window: Moving average window size.\n",
"\n",
"    \"\"\"\n",
"    fig, ax = plt.subplots(figsize=(13, 6))\n",
"\n",
"    for name, rewards in training_histories.items():\n",
"        color = AGENT_COLORS.get(name, \"#7f8c8d\")\n",
"        if len(rewards) >= window:\n",
"            ma = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n",
"            x = np.arange(window - 1, len(rewards))\n",
"            ax.plot(x, ma, label=name, color=color, linewidth=2)\n",
"            # Shaded band: 0.3 x the global std, a visual cue rather than a confidence interval.\n",
"            band = np.std(rewards) * 0.3\n",
"            ax.fill_between(x, ma - band, ma + band, color=color, alpha=0.1)\n",
"        else:\n",
"            ax.plot(rewards, label=f\"{name} (raw)\", color=color, linewidth=1.5, alpha=0.7)\n",
"\n",
"    ax.set_xlabel(\"Episodes\", fontsize=12)\n",
"    ax.set_ylabel(f\"Average Reward (window = {window})\", fontsize=12)\n",
"    ax.set_title(\"Training Curves (vs Built-in AI)\", fontsize=14, fontweight=\"bold\")\n",
"    ax.legend(fontsize=10, framealpha=0.9)\n",
"    ax.grid(visible=True, alpha=0.3, linestyle=\"--\")\n",
"    ax.axhline(y=0, color=\"gray\", linewidth=0.8, linestyle=\":\", alpha=0.5)\n",
"    ax.spines[\"top\"].set_visible(False)\n",
"    ax.spines[\"right\"].set_visible(False)\n",
"    fig.tight_layout()\n",
"    # Create the output directory first; savefig fails if it does not exist.\n",
"    Path(path).parent.mkdir(parents=True, exist_ok=True)\n",
"    plt.savefig(path, dpi=150, bbox_inches=\"tight\")\n",
"    plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "9605e9c4",
"metadata": {},
"source": [
"## Agent Instantiation & Incremental Training (One Agent at a Time)\n",
"\n",
"**Environment**: `ALE/Tennis-v5` (grayscale, 84×84×4 frames → 28,224 features, 18 actions).\n",
"\n",
"**Agents**:\n",
"- **Random** — random baseline (no training needed)\n",
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
"- **Q-Learning** — linear approximation, off-policy\n",
"- **Monte Carlo** — first-visit MC with linear weights (⚠️ checkpoint loading issues)\n",
"- **DQN** — deep Q-network with experience replay and target network (⚠️ `.pt` checkpoint loading issues)\n",
"\n",
"**Workflow**:\n",
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
"2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
"3. Repeat later for another agent without retraining previous ones\n",
"4. Load all saved checkpoints before the final evaluation"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "6f6ba8df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Observation shape : (4, 84, 84)\n",
"Feature vector dim: 28224\n",
"Number of actions : 18\n"
]
}
],
"source": [
"# Create environment\n",
"env = create_env()\n",
"obs, _info = env.reset()\n",
"\n",
"n_actions = int(env.action_space.n)\n",
"n_features = int(np.prod(obs.shape))\n",
"\n",
"print(f\"Observation shape : {obs.shape}\")\n",
"print(f\"Feature vector dim: {n_features}\")\n",
"print(f\"Number of actions : {n_actions}\")\n",
"\n",
"# Instantiate agents\n",
"agent_random = RandomAgent(seed=42, action_space=n_actions)\n",
"agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions)\n",
"\n",
"agents = {\n",
"    \"Random\": agent_random,\n",
"    \"SARSA\": agent_sarsa,\n",
"    \"Q-Learning\": agent_q,\n",
"    \"MonteCarlo\": agent_mc,\n",
"    \"DQN\": agent_dqn,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "4d449701",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected agent: Q-Learning\n",
"Checkpoint path: checkpoints/q_learning.pkl\n",
"\n",
"============================================================\n",
"Training: Q-Learning (3000 episodes)\n",
"============================================================\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c3a441131c648749a7d938732265143",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training Q-Learning: 0%| | 0/3000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[55]\u001b[39m\u001b[32m, line 28\u001b[39m\n\u001b[32m 25\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mTraining: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mAGENT_TO_TRAIN\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mTRAINING_EPISODES\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m episodes)\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 26\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m=\u001b[39m\u001b[33m'\u001b[39m*\u001b[32m60\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m28\u001b[39m training_histories[AGENT_TO_TRAIN] = \u001b[43mtrain_agent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 29\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m=\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 30\u001b[39m \u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m=\u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 31\u001b[39m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[43mAGENT_TO_TRAIN\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 32\u001b[39m \u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mTRAINING_EPISODES\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 33\u001b[39m \u001b[43m \u001b[49m\u001b[43mepsilon_start\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 34\u001b[39m \u001b[43m \u001b[49m\u001b[43mepsilon_end\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 35\u001b[39m \u001b[43m \u001b[49m\u001b[43mepsilon_decay\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.999\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 36\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 38\u001b[39m avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-\u001b[32m100\u001b[39m:])\n\u001b[32m 39\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m-> \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mAGENT_TO_TRAIN\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m avg reward (last 100 eps): \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_last_100\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[53]\u001b[39m\u001b[32m, line 45\u001b[39m, in \u001b[36mtrain_agent\u001b[39m\u001b[34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[39m\n\u001b[32m 42\u001b[39m action = agent.get_action(obs, epsilon=epsilon)\n\u001b[32m 44\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m next_obs, reward, terminated, truncated, _info = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 46\u001b[39m next_obs = np.asarray(next_obs)\n\u001b[32m 47\u001b[39m done = terminated \u001b[38;5;129;01mor\u001b[39;00m truncated\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/stateful_observation.py:425\u001b[39m, in \u001b[36mFrameStackObservation.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 414\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 415\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 416\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 417\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Steps through the environment, appending the observation to the frame buffer.\u001b[39;00m\n\u001b[32m 418\u001b[39m \n\u001b[32m 419\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 423\u001b[39m \u001b[33;03m Stacked observations, reward, terminated, truncated, and info from the environment\u001b[39;00m\n\u001b[32m 424\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m425\u001b[39m obs, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 426\u001b[39m \u001b[38;5;28mself\u001b[39m.obs_queue.append(obs)\n\u001b[32m 428\u001b[39m updated_obs = deepcopy(\n\u001b[32m 429\u001b[39m concatenate(\u001b[38;5;28mself\u001b[39m.env.observation_space, \u001b[38;5;28mself\u001b[39m.obs_queue, \u001b[38;5;28mself\u001b[39m.stacked_obs)\n\u001b[32m 430\u001b[39m )\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:560\u001b[39m, in \u001b[36mObservationWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 556\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 557\u001b[39m \u001b[38;5;28mself\u001b[39m, action: ActType\n\u001b[32m 558\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 559\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m560\u001b[39m observation, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 561\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.observation(observation), reward, terminated, truncated, info\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:393\u001b[39m, in \u001b[36mOrderEnforcing.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ResetNeeded(\u001b[33m\"\u001b[39m\u001b[33mCannot call env.step() before calling env.reset()\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:327\u001b[39m, in \u001b[36mWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 323\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 324\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 325\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 326\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:285\u001b[39m, in \u001b[36mPassiveEnvChecker.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 283\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env_step_passive_checker(\u001b[38;5;28mself\u001b[39m.env, action)\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/ale_py/env.py:305\u001b[39m, in \u001b[36mAtariEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 303\u001b[39m reward = \u001b[32m0.0\u001b[39m\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(frameskip):\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m reward += \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43male\u001b[49m\u001b[43m.\u001b[49m\u001b[43mact\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrength\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 307\u001b[39m is_terminal = \u001b[38;5;28mself\u001b[39m.ale.game_over(with_truncation=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 308\u001b[39m is_truncated = \u001b[38;5;28mself\u001b[39m.ale.game_truncated()\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"AGENT_TO_TRAIN = \"Q-Learning\"\n",
"TRAINING_EPISODES = 3000\n",
"FORCE_RETRAIN = True\n",
"\n",
"if AGENT_TO_TRAIN not in agents:\n",
"    msg = f\"Unknown agent '{AGENT_TO_TRAIN}'. Available: {list(agents)}\"\n",
"    raise ValueError(msg)\n",
"\n",
"training_histories: dict[str, list[float]] = {}\n",
"agent = agents[AGENT_TO_TRAIN]\n",
"checkpoint_path = get_path(AGENT_TO_TRAIN)\n",
"\n",
"print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n",
"print(f\"Checkpoint path: {checkpoint_path}\")\n",
"\n",
"if AGENT_TO_TRAIN == \"Random\":\n",
"    print(\"Random is a baseline and is not trained.\")\n",
"    training_histories[AGENT_TO_TRAIN] = []\n",
"elif checkpoint_path.exists() and not FORCE_RETRAIN:\n",
"    agent.load(str(checkpoint_path))\n",
"    print(\"Checkpoint found -> weights loaded, training skipped.\")\n",
"    training_histories[AGENT_TO_TRAIN] = []\n",
"else:\n",
"    print(f\"\\n{'='*60}\")\n",
"    print(f\"Training: {AGENT_TO_TRAIN} ({TRAINING_EPISODES} episodes)\")\n",
"    print(f\"{'='*60}\")\n",
"\n",
"    training_histories[AGENT_TO_TRAIN] = train_agent(\n",
"        env=env,\n",
"        agent=agent,\n",
"        name=AGENT_TO_TRAIN,\n",
"        episodes=TRAINING_EPISODES,\n",
"        epsilon_start=1.0,\n",
"        epsilon_end=0.05,\n",
"        epsilon_decay=0.999,\n",
"    )\n",
"\n",
"    avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n",
"    print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n",
"\n",
"    agent.save(str(checkpoint_path))\n",
"    print(\"Checkpoint saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a13a65df",
"metadata": {},
"outputs": [],
"source": [
"plot_training_curves(\n",
"    training_histories, f\"plots/{AGENT_TO_TRAIN}_training_curves.png\", window=100,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "56959ccd",
"metadata": {},
"source": [
"# Multi-Agent Evaluation & Tournament"
]
},
{
"cell_type": "markdown",
"id": "0698f673",
"metadata": {},
"source": [
"## Loading Checkpoints & Running Evaluation\n",
"\n",
"Before evaluation, we load the saved weights from the `checkpoints/` directory. Only SARSA and Q-Learning are loaded here, because Monte Carlo and DQN have the checkpoint loading issues noted above. If a checkpoint is missing, the agent plays with its initial weights (zeros), so every action has the same estimated value and its behavior reflects no learning."
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "5315eb02",
"metadata": {},
"outputs": [],
"source": [
"for name in [\"SARSA\", \"Q-Learning\"]:\n",
"    checkpoint_path = get_path(name)\n",
"    if checkpoint_path.exists():\n",
"        agents[name].load(str(checkpoint_path))\n",
"    else:\n",
"        print(f\"Warning: Missing checkpoint for {name}\")\n"
]
},
{
"cell_type": "markdown",
"id": "3047c4b0",
"metadata": {},
"source": [
"After individual pre-training against Atari's built-in AI, we move on to **head-to-head evaluation** between our agents.\n",
"\n",
"**PettingZoo Environment**: unlike Gymnasium (single player vs built-in AI), PettingZoo (`tennis_v3`) allows **two Python agents** to play against each other. The same preprocessing wrappers are applied (grayscale, resize to 84×84, 4-frame stacking) via **SuperSuit**.\n",
"\n",
"**Match Protocol**: each matchup is played over **two legs** with swapped positions (`first_0` / `second_0`) to eliminate any side-of-court advantage. Results are tallied as wins, losses, and draws."
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "d04d37e0",
"metadata": {},
"outputs": [],
"source": [
"from pettingzoo.utils.env import AECEnv\n",
"\n",
"\n",
"def create_tournament_env() -> AECEnv:\n",
"    \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n",
"    env = tennis_v3.env(obs_type=\"grayscale_image\")\n",
"    env = ss.resize_v1(env, x_size=84, y_size=84)\n",
"    env = ss.frame_stack_v1(env, 4)\n",
"    # Move the stacked frames to the first axis: (84, 84, 4) -> (4, 84, 84).\n",
"    return ss.observation_lambda_v0(env, lambda obs, _: np.transpose(obs, (2, 0, 1)))\n",
"\n",
"\n",
"def run_match(\n",
"    env: AECEnv,\n",
"    agent_first: Agent,\n",
"    agent_second: Agent,\n",
"    episodes: int = 10,\n",
"    max_steps: int = 4000,\n",
") -> dict[str, int]:\n",
"    \"\"\"Run multiple PettingZoo episodes between two agents.\n",
"\n",
"    `max_steps` counts agent turns, so it is shared between both players.\n",
"\n",
"    Returns episode counts keyed by outcome:\n",
"    {'first': wins of agent_first, 'second': wins of agent_second, 'draw': ties}.\n",
"    \"\"\"\n",
"    wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n",
"\n",
"    for _ep in range(episodes):\n",
"        env.reset()\n",
"        rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n",
"\n",
"        for step_idx, agent_id in enumerate(env.agent_iter()):\n",
"            obs, reward, termination, truncation, _info = env.last()\n",
"            done = termination or truncation\n",
"            rewards[agent_id] += float(reward)\n",
"\n",
"            # A finished agent must step with None; the step cap is enforced by the break below.\n",
"            if done:\n",
"                action = None\n",
"            else:\n",
"                current_agent = agent_first if agent_id == \"first_0\" else agent_second\n",
"                action = current_agent.get_action(np.asarray(obs), epsilon=0.0)\n",
"\n",
"            env.step(action)\n",
"\n",
"            if step_idx + 1 >= max_steps:\n",
"                break\n",
"\n",
"        # Determine winner for the current episode\n",
"        if rewards[\"first_0\"] > rewards[\"second_0\"]:\n",
"            wins[\"first\"] += 1\n",
"        elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n",
"            wins[\"second\"] += 1\n",
"        else:\n",
"            wins[\"draw\"] += 1\n",
"\n",
"    return wins\n"
]
},
{
"cell_type": "markdown",
"id": "8fd8215a",
"metadata": {},
"source": [
"## Observation Space Alignment Check: Gymnasium vs PettingZoo"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "fd235ac7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gymnasium: (4, 84, 84) uint8 0 214\n",
"PettingZoo: (4, 84, 84) uint8 0 214\n"
]
}
],
"source": [
"env_gym = create_env()\n",
"obs_gym, _ = env_gym.reset()\n",
"print(\n",
"    \"Gymnasium:\",\n",
"    np.asarray(obs_gym).shape,\n",
"    np.asarray(obs_gym).dtype,\n",
"    np.asarray(obs_gym).min(),\n",
"    np.asarray(obs_gym).max(),\n",
")\n",
"\n",
"env_pz = create_tournament_env()\n",
"env_pz.reset()\n",
"for agent_id in env_pz.agent_iter():\n",
"    obs_pz, *_ = env_pz.last()\n",
"    print(\n",
"        \"PettingZoo:\",\n",
"        np.asarray(obs_pz).shape,\n",
"        np.asarray(obs_pz).dtype,\n",
"        np.asarray(obs_pz).min(),\n",
"        np.asarray(obs_pz).max(),\n",
"    )\n",
"    break\n"
]
},
{
"cell_type": "markdown",
"id": "5cfcfdb0",
"metadata": {},
"source": [
"## Evaluation against the Built-in Atari AI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bee640bb",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_vs_atari(\n",
"    env: gym.Env,\n",
"    agents: dict[str, Agent],\n",
"    episodes: int = 20,\n",
"    max_steps: int = 5000,\n",
") -> dict[str, float]:\n",
"    \"\"\"Evaluate each agent against the built-in Atari AI opponent over multiple episodes.\"\"\"\n",
"    results = {}\n",
"    for name, agent in agents.items():\n",
"        total_rewards = []\n",
"        for _ep in range(episodes):\n",
"            obs, _info = env.reset()\n",
"            obs = np.asarray(obs)\n",
"            ep_reward = 0.0\n",
"\n",
"            for _step in range(max_steps):\n",
"                # Greedy policy (epsilon=0.0): no exploration during evaluation.\n",
"                action = agent.get_action(obs, epsilon=0.0)\n",
"                next_obs, reward, terminated, truncated, _info = env.step(action)\n",
"                ep_reward += float(reward)\n",
"\n",
"                if terminated or truncated:\n",
"                    break\n",
"\n",
"                obs = np.asarray(next_obs)\n",
"\n",
"            total_rewards.append(ep_reward)\n",
"\n",
"        avg_reward = float(np.mean(total_rewards))\n",
"        results[name] = avg_reward\n",
"        print(\n",
"            f\"{name} vs Atari AI: average score = {avg_reward:.2f} over {episodes} episodes\",\n",
"        )\n",
"\n",
"    return results\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94a804cc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation against built-in Atari AI opponent:\n",
"Random vs Atari AI: average score = -23.90 over 10 episodes\n",
"SARSA vs Atari AI: average score = 0.00 over 10 episodes\n",
"Q-Learning vs Atari AI: average score = 0.00 over 10 episodes\n",
"MonteCarlo vs Atari AI: average score = -23.80 over 10 episodes\n",
"DQN vs Atari AI: average score = -24.00 over 10 episodes\n"
]
}
],
"source": [
"print(\"Evaluation against built-in Atari AI opponent:\")\n",
"atari_results = evaluate_vs_atari(env, agents, episodes=10)\n"
]
},
{
"cell_type": "markdown",
"id": "150e6764",
"metadata": {},
"source": [
"## Evaluation against the Random Agent (Baseline)\n",
"\n",
"To quantify whether our agents have actually learned, we also evaluate them against the **Random agent** baseline. A properly trained agent should achieve a **win rate significantly above 50%** against a random opponent.\n",
"\n",
"Each agent plays **two legs** (one in each position) of 10 episodes each, for a total of 20 episodes. Draws are excluded, so the win rate is computed over decisive matches only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b85a88f",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_vs_random(\n",
"    agents: dict[str, Agent],\n",
"    random_agent_name: str = \"Random\",\n",
"    episodes_per_leg: int = 10,\n",
") -> dict[str, float]:\n",
"    \"\"\"Evaluate each agent against the specified random agent and compute win rates.\"\"\"\n",
"    win_rates = {}\n",
"    env = create_tournament_env()\n",
"    agent_random = agents[random_agent_name]\n",
"\n",
"    for name, agent in agents.items():\n",
"        if name == random_agent_name:\n",
"            continue\n",
"\n",
"        # Two legs with swapped positions to cancel any side-of-court advantage.\n",
"        leg1 = run_match(env, agent, agent_random, episodes=episodes_per_leg)\n",
"        leg2 = run_match(env, agent_random, agent, episodes=episodes_per_leg)\n",
"\n",
"        total_wins = leg1[\"first\"] + leg2[\"second\"]\n",
"        total_matches = (episodes_per_leg * 2) - (leg1[\"draw\"] + leg2[\"draw\"])\n",
"\n",
"        # With no decisive match, fall back to a neutral 50 % win rate.\n",
"        if total_matches == 0:\n",
"            win_rates[name] = 0.5\n",
"        else:\n",
"            win_rates[name] = total_wins / total_matches\n",
"\n",
"        print(\n",
"            f\"{name} vs {random_agent_name}: {total_wins} wins out of {total_matches} decisive matches (Win rate: {win_rates[name]:.1%})\",\n",
"        )\n",
"\n",
"    env.close()\n",
"    return win_rates\n",
"\n",
"\n",
"def plot_evaluation_results(\n",
"    atari_results: dict[str, float],\n",
"    random_win_rates: dict[str, float],\n",
"    save_path: str = \"plots/evaluation_results.png\",\n",
") -> None:\n",
"    \"\"\"Plot evaluation results as two side-by-side subplots.\"\"\"\n",
"    agent_names = list(atari_results.keys())\n",
"    colors = [AGENT_COLORS.get(n, \"#7f8c8d\") for n in agent_names]\n",
"\n",
"    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"    # --- Left panel: Average reward vs Atari AI ---\n",
"    atari_scores = [atari_results[n] for n in agent_names]\n",
"    x = np.arange(len(agent_names))\n",
"    bars1 = ax1.bar(x, atari_scores, color=colors, edgecolor=\"white\", width=0.6)\n",
"    ax1.set_xticks(x)\n",
"    ax1.set_xticklabels(agent_names, rotation=30, ha=\"right\", fontsize=10)\n",
"    ax1.set_ylabel(\"Average Reward\", fontsize=11)\n",
"    ax1.set_title(\"Performance vs Atari Built-in AI\", fontsize=13, fontweight=\"bold\")\n",
"    ax1.axhline(y=0, color=\"gray\", linestyle=\"--\", linewidth=0.8, alpha=0.5)\n",
"    ax1.grid(axis=\"y\", alpha=0.3, linestyle=\"--\")\n",
"    ax1.spines[\"top\"].set_visible(False)\n",
"    ax1.spines[\"right\"].set_visible(False)\n",
"    for bar, score in zip(bars1, atari_scores):\n",
"        va = \"bottom\" if score >= 0 else \"top\"\n",
"        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),\n",
"                 f\"{score:+.1f}\", ha=\"center\", va=va, fontsize=9, fontweight=\"bold\")\n",
"\n",
"    # --- Right panel: Win rate vs Random ---\n",
"    trained = [n for n in agent_names if n in random_win_rates]\n",
"    trained_colors = [AGENT_COLORS.get(n, \"#7f8c8d\") for n in trained]\n",
"    rates_pct = [random_win_rates[n] * 100 for n in trained]\n",
"    x2 = np.arange(len(trained))\n",
"    bars2 = ax2.bar(x2, rates_pct, color=trained_colors, edgecolor=\"white\", width=0.6)\n",
"    ax2.set_xticks(x2)\n",
"    ax2.set_xticklabels(trained, rotation=30, ha=\"right\", fontsize=10)\n",
"    ax2.set_ylabel(\"Win Rate (%)\", fontsize=11)\n",
"    ax2.set_title(\"Win Rate vs Random Baseline\", fontsize=13, fontweight=\"bold\")\n",
"    ax2.set_ylim(0, 110)\n",
"    ax2.axhline(y=50, color=\"gray\", linestyle=\"--\", linewidth=0.8, alpha=0.5, label=\"50 % baseline\")\n",
"    ax2.legend(fontsize=8, loc=\"upper right\")\n",
"    ax2.grid(axis=\"y\", alpha=0.3, linestyle=\"--\")\n",
"    ax2.spines[\"top\"].set_visible(False)\n",
"    ax2.spines[\"right\"].set_visible(False)\n",
"    for bar, wr in zip(bars2, rates_pct):\n",
"        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.5,\n",
"                 f\"{wr:.0f} %\", ha=\"center\", va=\"bottom\", fontsize=9, fontweight=\"bold\")\n",
"\n",
"    fig.suptitle(\"Agent Evaluation Summary\", fontsize=15, fontweight=\"bold\", y=1.02)\n",
"    fig.tight_layout()\n",
"    # Create the output directory first; savefig fails if it does not exist.\n",
"    Path(save_path).parent.mkdir(parents=True, exist_ok=True)\n",
"    plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
"    plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a053644e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation against the Random agent\n",
"SARSA vs Random: 0 wins out of 20 decisive matches (Win rate: 0.0%)\n",
"Q-Learning vs Random: 0 wins out of 20 decisive matches (Win rate: 0.0%)\n",
"MonteCarlo vs Random: 8 wins out of 18 decisive matches (Win rate: 44.4%)\n",
"DQN vs Random: 6 wins out of 12 decisive matches (Win rate: 50.0%)\n"
]
}
],
"source": [
"print(\"Evaluation against the Random agent\")\n",
"win_rates_vs_random = evaluate_vs_random(\n",
"    agents, random_agent_name=\"Random\", episodes_per_leg=10,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fe97c3d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABHsAAAGyCAYAAAB0jsg1AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIQZJREFUeJzt3X2MFtXZB+CzgICmgloKCEWpWkWLgoJsQY2xoZJosPzRlKoBSvyo1RoLaQVEwW+srxqSukpErf5RC2rEGCFYpRJjpSGCJNoKRlGhRhao5aOooDBvZprdsrhYF3efnb33upIRZnZmn7PPkZ17f3vmnKosy7IEAAAAQAgdWrsBAAAAADQfYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAew57XnrppTR69OjUp0+fVFVVlZ5++un/ec3SpUvTaaedlrp06ZKOO+649MgjjxxoewEA2hS1EwBQ+rBnx44dadCgQammpuYrnf/uu++m888/P51zzjlp1apV6Ve/+lW69NJL03PPPXcg7QUAaFPUTgBApVVlWZYd8MVVVWnBggVpzJgx+z1nypQpaeHChemNN96oP/bTn/40bdmyJS1evPhAXxoAoM1ROwEAldCppV9g2bJlaeTIkQ2OjRo1qhjhsz87d+4stjp79uxJH330UfrmN79ZFEkAQDnlv0Pavn178bh3hw6mBqxU7ZRTPwFA25S1QP3U4mHPhg0bUq9evRocy/e3bduWPvnkk3TwwQd/4ZpZs2alm266qaWbBgC0kPXr16dvf/vb3t8K1U459RMAtG3rm7F+avGw50BMmzYtTZ48uX5/69at6aijjiq+8G7durVq2wCA/csDiX79+qVDDz3U21Rh6icAaJu2tUD91OJhT+/evVNtbW2DY/l+Htrs7zdT+apd+bav/BphDwCUn8euK1s75dRPANC2VTXjtDUt/jD98OHD05IlSxoce/7554vjAAConQCA5tXksOff//53sYR6vtUtrZ7/fd26dfVDiMePH19//hVXXJHWrl2brr322rR69ep03333pccffzxNmjSpOb8OAIBSUjsBAKUPe1599dV06qmnFlsun1sn//uMGTOK/Q8//LA++Ml95zvfKZZez0fzDBo0KN19993pwQcfLFaVAACITu0EAFRaVZav8dUGJivq3r17MVGzOXsAoLzcs8tDXwBA+71nt/icPQAAAABUjrAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAoL2HPTU1Nal///6pa9euqbq6Oi1fvvxLz589e3Y64YQT0sEHH5z69euXJk2alD799NMDbTMAQJujfgIAShv2zJ8/P02ePDnNnDkzrVy5Mg
0aNCiNGjUqbdy4sdHzH3vssTR16tTi/DfffDM99NBDxee47rrrmqP9AAClp34CAEod9txzzz3psssuSxMnTkwnnXRSmjNnTjrkkEPSww8/3Oj5r7zySjrjjDPSRRddVIwGOvfcc9OFF174P0cDAQBEoX4CAEob9uzatSutWLEijRw58r+foEOHYn/ZsmWNXjNixIjimrpwZ+3atWnRokXpvPPO2+/r7Ny5M23btq3BBgDQFqmfAIBK69SUkzdv3px2796devXq1eB4vr969epGr8lH9OTXnXnmmSnLsvT555+nK6644ksf45o1a1a66aabmtI0AIBSUj8BAOFW41q6dGm6/fbb03333VfM8fPUU0+lhQsXpltuuWW/10ybNi1t3bq1flu/fn1LNxMAoDTUTwBAxUb29OjRI3Xs2DHV1tY2OJ7v9+7du9FrbrjhhjRu3Lh06aWXFvsnn3xy2rFjR7r88svT9OnTi8fA9tWlS5diAwBo69RPAECpR/Z07tw5DRkyJC1ZsqT+2J49e4r94cOHN3rNxx9//IVAJw+McvljXQAAkamfAIBSj+zJ5cuuT5gwIQ0dOjQNGzYszZ49uxipk6/OlRs/fnzq27dvMe9ObvTo0cUKFKeeemqqrq5Ob7/9djHaJz9eF/oAAESmfgIASh32jB07Nm3atCnNmDEjbdiwIQ0ePDgtXry4ftLmdevWNRjJc/3116eqqqrizw8++CB961vfKoKe2267rXm/EgCAklI/AQCVVJW1gWep8qXXu3fvXkzW3K1bt9ZuDgCwH+7Z5aEvAKD93rNbfDUuAAAAACpH2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAADQ3sOempqa1L9//9S1a9dUXV2dli9f/qXnb9myJV111VXpyCOPTF26dEnHH398WrRo0YG2GQCgzVE/AQCV0qmpF8yfPz9Nnjw5zZkzpwh6Zs+enUaNGpXWrFmTevbs+YXzd+3alX74wx8WH3vyySdT37590/vvv58OO+yw5voaAABKTf0EAFRSVZZlWVMuyAOe008/Pd17773F/p49e1K/fv3S1VdfnaZOnfqF8/NQ6P/+7//S6tWr00EHHXRAjdy2bVvq3r172rp1a+rWrdsBfQ4AoOW5ZzdO/QQAVLJ+atJjXPkonRUrVqSRI0f+9xN06FDsL1u2rNFrnnnmmTR8+PDiMa5evXqlgQMHpttvvz3t3r17v6+zc+fO4ovdewMAaIvUTwBApTUp7Nm8eXMR0uShzd7y/Q0bNjR6zdq1a4vHt/Lr8nl6brjhhnT33XenW2+9db+vM2vWrCLVqtvykUMAAG2R+gkACLcaV/6YVz5fzwMPPJCGDBmSxo4dm6ZPn1483rU/06ZNK4Yv1W3r169v6WYCAJSG+gkAqNgEzT169EgdO3
ZMtbW1DY7n+7179270mnwFrnyunvy6OieeeGIxEigf1ty5c+cvXJOv2JVvAABtnfoJACj1yJ48mMlH5yxZsqTBb57y/XxensacccYZ6e233y7Oq/PWW28VIVBjQQ8AQCTqJwCg9I9x5cuuz507Nz366KPpzTffTL/4xS/Sjh070sSJE4uPjx8/vngMq07+8Y8++ihdc801RcizcOHCYoLmfMJmAID2QP0EAJT2Ma5cPufOpk2b0owZM4pHsQYPHpwWL15cP2nzunXrihW66uSTKz/33HNp0qRJ6ZRTTkl9+/Ytgp8pU6Y071cCAFBS6icAoJKqsizLUjtccx4AaH7u2eWhLwCg/d6zW3w1LgAAAAAqR9gDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAA0N7DnpqamtS/f//UtWvXVF1dnZYvX/6Vrps3b16qqqpKY8aMOZCXBQBos9RPAEBpw5758+enyZMnp5kzZ6aVK1emQYMGpVGjRqWNGzd+6XXvvfde+vWvf53OOuusr9NeAIA2R/0EAJQ67LnnnnvSZZddliZOnJhOOumkNGfOnHTIIYekhx9+eL/X7N69O1188cXppptuSsccc8zXbTMAQJuifgIAShv27Nq1K61YsSKNHDnyv5+gQ4dif9myZfu97uabb049e/ZMl1xyyVd6nZ07d6Zt27Y12AAA2iL1EwBQ6rBn8+bNxSidXr16NTie72/YsKHRa15++eX00EMPpblz537l15k1a1bq3r17/davX7+mNBMAoDTUTwBAqNW4tm/fnsaNG1cEPT169PjK102bNi1t3bq1flu/fn1LNhMAoDTUTwDA19WpKSfngU3Hjh1TbW1tg+P5fu/evb9w/jvvvFNMzDx69Oj6Y3v27PnPC3fqlNasWZOOPfbYL1zXpUuXYgMAaOvUTwBAqUf2dO7cOQ0ZMiQtWbKkQXiT7w8fPvwL5w8YMCC9/vrradWqVfXbBRdckM4555zi7x7PAgCiUz8BAKUe2ZPLl12fMGFCGjp0aBo2bFiaPXt22rFjR7E6V278+PGpb9++xbw7Xbt2TQMHDmxw/WGHHVb8ue9xAICo1E8AQKnDnrFjx6ZNmzalGTNmFJMyDx48OC1evLh+0uZ169YVK3QBAKB+AgAqryrLsiyVXL70er4qVz5Zc7du3Vq7OQDAfrhnl4e+AID2e882BAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAw
AAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQHsPe2pqalL//v1T165dU3V1dVq+fPl+z507d24666yz0uGHH15sI0eO/NLzAQAiUj8BAKUNe+bPn58mT56cZs6cmVauXJkGDRqURo0alTZu3Njo+UuXLk0XXnhhevHFF9OyZctSv3790rnnnps++OCD5mg/AEDpqZ8AgEqqyrIsa8oF+Uie008/Pd17773F/p49e4oA5+qrr05Tp079n9fv3r27GOGTXz9+/Piv9Jrbtm1L3bt3T1u3bk3dunVrSnMBgApyz26c+gkAqGT91KSRPbt27UorVqwoHsWq/wQdOhT7+aidr+Ljjz9On332WTriiCP2e87OnTuLL3bvDQCgLVI/AQCV1qSwZ/PmzcXInF69ejU4nu9v2LDhK32OKVOmpD59+jQIjPY1a9asItWq2/KRQwAAbZH6CQAIvRrXHXfckebNm5cWLFhQTO68P9OmTSuGL9Vt69evr2QzAQBKQ/0EADRVp6ac3KNHj9SxY8dUW1vb4Hi+37t37y+99q677iqKlRdeeCGdcsopX3puly5dig0AoK1TPwEApR7Z07lz5zRkyJC0ZMmS+mP5BM35/vDhw/d73Z133pluueWWtHjx4jR06NCv12IAgDZE/QQAlHpkTy5fdn3ChAlFaDNs2LA0e/bstGPHjjRx4sTi4/kKW3379i3m3cn99re/TTNmzEiPPfZY6t+/f/3cPt/4xjeKDQAgOvUTAFDqsGfs2LFp06ZNRYCTBzeDBw8uRuzUTdq8bt26YoWuOvfff3+xCsWPf/zjBp9n5syZ6cYbb2yOrwEAoNTUTwBAJVVlWZaldrjmPADQ/Nyzy0NfAED7vWdXdDUuAAAAAFqWsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAACgvYc9NTU1qX///qlr166puro6LV++/EvPf+KJJ9KAAQOK808++eS0aNGiA20vAECbpH4CAEob9syfPz9Nnjw5zZw5M61cuTINGjQojRo1Km3cuLHR81955ZV04YUXpksuuS
S99tpracyYMcX2xhtvNEf7AQBKT/0EAFRSVZZlWVMuyEfynH766enee+8t9vfs2ZP69euXrr766jR16tQvnD927Ni0Y8eO9Oyzz9Yf+/73v58GDx6c5syZ85Vec9u2bal79+5p69atqVu3bk1pLgBQQe7ZjVM/AQCVrJ86NeXkXbt2pRUrVqRp06bVH+vQoUMaOXJkWrZsWaPX5MfzkUB7y0cCPf300/t9nZ07dxZbnfwLrnsDAIDyqrtXN/F3SaGpnwCAStdPTQp7Nm/enHbv3p169erV4Hi+v3r16kav2bBhQ6Pn58f3Z9asWemmm276wvF8BBEAUH7//Oc/i99QoX4CACpfPzUp7KmUfOTQ3qOBtmzZko4++ui0bt06hWMrp4154LZ+/XqP07UyfVEe+qIc9EN55KNxjzrqqHTEEUe0dlPaHfVTOfn+VB76ohz0Q3noi9j1U5PCnh49eqSOHTum2traBsfz/d69ezd6TX68KefnunTpUmz7yhMuc/a0vrwP9EM56Ivy0BfloB/KI3/Mm/9QP5Hz/ak89EU56Ify0Bcx66cmfabOnTunIUOGpCVLltQfyydozveHDx/e6DX58b3Pzz3//PP7PR8AIBL1EwBQaU1+jCt/vGrChAlp6NChadiwYWn27NnFalsTJ04sPj5+/PjUt2/fYt6d3DXXXJPOPvvsdPfdd6fzzz8/zZs3L7366qvpgQceaP6vBgCghNRPAECpw558KfVNmzalGTNmFJMs50uoL168uH4S5nxenb2HHo0YMSI99thj6frrr0/XXXdd+u53v1usxDVw4MCv/Jr5I10zZ85s9NEuKkc/lIe+KA99UQ76oTz0RePUT+2XfxPloS/KQT+Uh76I3RdVmbVRAQAAAMIweyIAAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAilN2FNTU5P69++funbtmqqrq9Py5cu/9PwnnngiDRgwoDj/5JNPTosWLapYWyNrSj/MnTs3nXXWWenwww8vtpEjR/7PfqNl+mJv8+bNS1VVVWnMmDHe7lbqiy1btqSrrroqHXnkkcWM+scff7zvUa3QD7Nnz04nnHBCOvjgg1O/fv3SpEmT0qefftocTWm3XnrppTR69OjUp0+f4vtMvrrm/7J06dJ02mmnFf8WjjvuuPTII49UpK3tgdqpPNRP5aF+Kge1U3mon9px/ZSVwLx587LOnTtnDz/8cPa3v/0tu+yyy7LDDjssq62tbfT8v/zlL1nHjh2zO++8M/v73/+eXX/99dlBBx2Uvf766xVveyRN7YeLLrooq6mpyV577bXszTffzH72s59l3bt3z/7xj39UvO3tvS/qvPvuu1nfvn2zs846K/vRj35UsfZG1tS+2LlzZzZ06NDsvPPOy15++eWiT5YuXZqtWrWq4m1vz/3whz/8IevSpUvxZ94Hzz33XHbkkUdmkyZNqnjbI1m0aFE2ffr07KmnnsryEmLBggVfev7atWuzQw45JJs8eXJxv/7d735X3L8XL15csTZHpXYqD/VTeaifykHtVB7qp/ZdP5Ui7Bk2bFh21VVX1e/v3r0769OnTzZr1qxGz//JT36SnX/++Q2OVVdXZz//+c9bvK2RNbUf9vX5559nhx56aPboo4+2YCvbhwPpi/z9HzFiRPbggw9mEyZMEPa0Ul/cf//92THHHJPt2rWruZrAAfRDfu4PfvCDBsfyG+YZZ5zh/WwmX6VYufbaa7Pvfe97DY6NHTs2GzVqlH74mtRO5aF+Kg/1UzmoncpD/dS+66dWf4xr165dacWKFcUjQHU6dOhQ7C9btqzRa/Lje5+fGzVq1H7Pp2X6YV8ff/xx+uyzz9IRRxzhLW+Fvrj55ptTz5490yWXXOL9b8W+eOaZZ9Lw4cOLx7h69eqVBg4cmG6//fa0e/du/VLBfhgxYkRxTd2jXmvXri0epTvvvPP0QwW5X7cMtVN5qJ/KQ/1UDmqn8lA/tV3NVT91Sq1s8+
bNxQ9B+Q9Fe8v3V69e3eg1GzZsaPT8/DiV64d9TZkypXgOcd//MWn5vnj55ZfTQw89lFatWuXtbuW+yEOFP//5z+niiy8uwoW33347XXnllUUQOnPmTP1ToX646KKLiuvOPPPMfARr+vzzz9MVV1yRrrvuOn1QQfu7X2/bti198sknxXxKNJ3aqTzUT+WhfioHtVN5qJ/aruaqn1p9ZA8x3HHHHcXEwAsWLCgmT6Vytm/fnsaNG1dMmN2jRw9vfSvbs2dPMcLqgQceSEOGDEljx45N06dPT3PmzGntprUr+aR2+Yiq++67L61cuTI99dRTaeHChemWW25p7aYB1FM/tR71U3moncpD/RRLq4/syX847dixY6qtrW1wPN/v3bt3o9fkx5tyPi3TD3Xuuuuuolh54YUX0imnnOLtrnBfvPPOO+m9994rZnjf+6aZ69SpU1qzZk069thj9UsF+iKXr8B10EEHFdfVOfHEE4uEPh9O27lzZ31RgX644YYbihD00ksvLfbzVRt37NiRLr/88iJ8yx8Do+Xt737drVs3o3q+BrVTeaifykP9VA5qp/JQP7VdzVU/tXq1m//gk//2e8mSJQ1+UM3383kvGpMf3/v83PPPP7/f82mZfsjdeeedxW/KFy9enIYOHeqtboW+GDBgQHr99deLR7jqtgsuuCCdc845xd/zJaepTF/kzjjjjOLRrbrALffWW28VIZCgp3L9kM8htm+gUxfA/WduPCrB/bplqJ3KQ/1UHuqnclA7lYf6qe1qtvopK8mScPkSuY888kixtNjll19eLKm7YcOG4uPjxo3Lpk6d2mDp9U6dOmV33XVXseT3zJkzLb3eCv1wxx13FEshP/nkk9mHH35Yv23fvr05mtOuNbUv9mU1rtbri3Xr1hWr0v3yl7/M1qxZkz377LNZz549s1tvvbUZW9X+NLUf8vtC3g9//OMfi+Ur//SnP2XHHntssZojBy7//v7aa68VW15C3HPPPcXf33///eLjeR/kfbHv0qG/+c1vivt1TU2NpdebidqpPNRP5aF+Kge1U3mon9p3/VSKsCeXrx1/1FFHFeFBvkTcX//61/qPnX322cUPr3t7/PHHs+OPP744P1+WbOHCha3Q6nia0g9HH3108T/rvlv+QxaV7Yt9CXtaty9eeeWVrLq6uggn8mXYb7vttuzzzz9v5la1P03ph88++yy78cYbi4Cna9euWb9+/bIrr7wy+9e//tVKrY/hxRdfbPT7ft17n/+Z98W+1wwePLjot/zfw+9///tWan08aqfyUD+Vh/qpHNRO5aF+ar/1U1X+n+YddAQAAABAa2n1OXsAAAAAaD7CHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAFIc/w/JLyUiKOJTlAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 1400x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_evaluation_results(\n",
" atari_results, win_rates_vs_random, save_path=\"plots/evaluation_results.png\",\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "b9cb30f5",
"metadata": {},
"source": [
"## Championship Match: SARSA vs Q-Learning\n",
"\n",
"The final showdown pits the two trained agents against each other: **SARSA** (on-policy) versus **Q-Learning** (off-policy).\n",
"\n",
"This match directly compares the two TD learning strategies:\n",
"- **SARSA** updates its weights following the policy it actually executes (on-policy): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
"- **Q-Learning** learns the optimal policy independently of exploration (off-policy): $\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
"\n",
"The championship is played over **2 × 20 episodes** with swapped positions to ensure fairness."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4031bde5",
"metadata": {},
"outputs": [],
"source": [
"def run_all_matchups(\n",
" agents: dict[str, Agent],\n",
" episodes_per_leg: int = 20,\n",
") -> np.ndarray:\n",
" \"\"\"Run a full round-robin tournament and return an n*n win matrix.\n",
"\n",
" matrix[i][j] = wins of agent_i (as first_0) vs agent_j (as second_0).\n",
" Diagonal entries are NaN (no self-play).\n",
" \"\"\"\n",
" env = create_tournament_env()\n",
" names = list(agents.keys())\n",
" n = len(names)\n",
" matrix = np.full((n, n), np.nan)\n",
"\n",
" for i in range(n):\n",
" for j in range(n):\n",
" if i == j:\n",
" continue\n",
" result = run_match(\n",
" env, agents[names[i]], agents[names[j]], episodes=episodes_per_leg,\n",
" )\n",
" matrix[i, j] = result[\"first\"]\n",
"\n",
" env.close()\n",
" return matrix\n",
"\n",
"\n",
"def plot_matchup_matrix(\n",
" matrix: np.ndarray,\n",
" agent_names: list[str],\n",
" episodes_per_leg: int = 20,\n",
" save_path: str = \"plots/championship_matrix.png\",\n",
") -> None:\n",
" \"\"\"Plot the n*n matchup matrix as an annotated heatmap.\"\"\"\n",
" n = len(agent_names)\n",
" pct_matrix = matrix / episodes_per_leg * 100\n",
"\n",
" fig, ax = plt.subplots(figsize=(8, 7))\n",
" masked = np.ma.array(pct_matrix, mask=np.isnan(pct_matrix))\n",
" cmap = plt.cm.RdYlGn.copy()\n",
" cmap.set_bad(color=\"#f0f0f0\")\n",
"\n",
" im = ax.imshow(masked, cmap=cmap, vmin=0, vmax=100, aspect=\"equal\")\n",
"\n",
" ax.set_xticks(np.arange(n))\n",
" ax.set_yticks(np.arange(n))\n",
" ax.set_xticklabels(agent_names, rotation=45, ha=\"right\", fontsize=10)\n",
" ax.set_yticklabels(agent_names, fontsize=10)\n",
" ax.set_xlabel(\"Opponent (second_0)\", fontsize=11)\n",
" ax.set_ylabel(\"Agent (first_0)\", fontsize=11)\n",
" ax.set_title(\"Championship Matchup Matrix\", fontsize=14, fontweight=\"bold\", pad=12)\n",
"\n",
" for i in range(n):\n",
" for j in range(n):\n",
" if i == j:\n",
" ax.text(j, i, \"—\", ha=\"center\", va=\"center\", color=\"#aaa\", fontsize=14)\n",
" else:\n",
" wins = int(matrix[i, j])\n",
" pct = pct_matrix[i, j]\n",
" txt_color = \"white\" if pct > 70 or pct < 30 else \"black\"\n",
" ax.text(j, i, f\"{wins}\\n({pct:.0f} %)\",\n",
" ha=\"center\", va=\"center\", color=txt_color,\n",
" fontsize=9, fontweight=\"bold\", linespacing=1.4)\n",
"\n",
" cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n",
" cbar.set_label(f\"Win Rate (%) — {episodes_per_leg} eps per matchup\", fontsize=10)\n",
" cbar.set_ticks([0, 25, 50, 75, 100])\n",
" cbar.set_ticklabels([\"0 %\", \"25 %\", \"50 %\", \"75 %\", \"100 %\"])\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b07b403c",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mTypeError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/lambda_wrappers/observation_lambda.py:68\u001b[39m, in \u001b[36maec_observation_lambda._modify_observation\u001b[39m\u001b[34m(self, agent, observation)\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m68\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mchange_observation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mold_obs_space\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n",
"\u001b[31mTypeError\u001b[39m: basic_obs_wrapper.<locals>.change_obs() takes 2 positional arguments but 3 were given",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[41]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m matchup_matrix = \u001b[43mrun_all_matchups\u001b[49m\u001b[43m(\u001b[49m\u001b[43magents\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes_per_leg\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 19\u001b[39m, in \u001b[36mrun_all_matchups\u001b[39m\u001b[34m(agents, episodes_per_leg)\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i == j:\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m result = \u001b[43mrun_match\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 20\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magents\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnames\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magents\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnames\u001b[49m\u001b[43m[\u001b[49m\u001b[43mj\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepisodes_per_leg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 22\u001b[39m matrix[i, j] = result[\u001b[33m\"\u001b[39m\u001b[33mfirst\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 24\u001b[39m env.close()\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[28]\u001b[39m\u001b[32m, line 37\u001b[39m, in \u001b[36mrun_match\u001b[39m\u001b[34m(env, agent_first, agent_second, episodes, max_steps)\u001b[39m\n\u001b[32m 34\u001b[39m current_agent = agent_first \u001b[38;5;28;01mif\u001b[39;00m agent_id == \u001b[33m\"\u001b[39m\u001b[33mfirst_0\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m agent_second\n\u001b[32m 35\u001b[39m action = current_agent.get_action(np.asarray(obs), epsilon=\u001b[32m0.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m37\u001b[39m \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m step_idx + \u001b[32m1\u001b[39m >= max_steps:\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/base_aec_wrapper.py:46\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 43\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m.terminations[agent] \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.truncations[agent]):\n\u001b[32m 44\u001b[39m action = \u001b[38;5;28mself\u001b[39m._modify_action(agent, action)\n\u001b[32m---> \u001b[39m\u001b[32m46\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 48\u001b[39m \u001b[38;5;28mself\u001b[39m._update_step(\u001b[38;5;28mself\u001b[39m.agent_selection)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:70\u001b[39m, in \u001b[36mOrderEnforcingWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 69\u001b[39m \u001b[38;5;28mself\u001b[39m._has_updated = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:47\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 46\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, action: ActionType) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m47\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/generic_wrappers/utils/shared_wrapper_util.py:69\u001b[39m, in \u001b[36mshared_wrapper_aec.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 66\u001b[39m \u001b[38;5;28msuper\u001b[39m().step(action)\n\u001b[32m 67\u001b[39m \u001b[38;5;28mself\u001b[39m.add_modifiers(\u001b[38;5;28mself\u001b[39m.agents)\n\u001b[32m 68\u001b[39m \u001b[38;5;28mself\u001b[39m.modifiers[\u001b[38;5;28mself\u001b[39m.agent_selection].modify_obs(\n\u001b[32m---> \u001b[39m\u001b[32m69\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43magent_selection\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 70\u001b[39m )\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:75\u001b[39m, in \u001b[36mOrderEnforcingWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 73\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 74\u001b[39m EnvLogger.error_observe_before_reset()\n\u001b[32m---> \u001b[39m\u001b[32m75\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:41\u001b[39m, in \u001b[36mBaseWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mobserve\u001b[39m(\u001b[38;5;28mself\u001b[39m, agent: AgentID) -> ObsType | \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m41\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/base_aec_wrapper.py:38\u001b[39m, in \u001b[36mBaseWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 34\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mobserve\u001b[39m(\u001b[38;5;28mself\u001b[39m, agent):\n\u001b[32m 35\u001b[39m obs = \u001b[38;5;28msuper\u001b[39m().observe(\n\u001b[32m 36\u001b[39m agent\n\u001b[32m 37\u001b[39m ) \u001b[38;5;66;03m# problem is in this line, the obs is sometimes a different size from the obs space\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m38\u001b[39m observation = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_modify_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m observation\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/lambda_wrappers/observation_lambda.py:70\u001b[39m, in \u001b[36maec_observation_lambda._modify_observation\u001b[39m\u001b[34m(self, agent, observation)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.change_observation_fn(observation, old_obs_space, agent)\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mchange_observation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mold_obs_space\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/generic_wrappers/basic_wrappers.py:19\u001b[39m, in \u001b[36mbasic_obs_wrapper.<locals>.change_obs\u001b[39m\u001b[34m(obs, obs_space)\u001b[39m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_obs\u001b[39m(obs, obs_space):\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodule\u001b[49m\u001b[43m.\u001b[49m\u001b[43mchange_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobs_space\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/basic_transforms/resize.py:21\u001b[39m, in \u001b[36mchange_observation\u001b[39m\u001b[34m(obs, obs_space, resize)\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_obs_space\u001b[39m(obs_space, param):\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m convert_box(\u001b[38;5;28;01mlambda\u001b[39;00m obs: change_observation(obs, obs_space, param), obs_space)\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_observation\u001b[39m(obs, obs_space, resize):\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtinyscaler\u001b[39;00m\n\u001b[32m 24\u001b[39m xsize, ysize, linear_interp = resize\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"matchup_matrix = run_all_matchups(agents, episodes_per_leg=20)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9668071a",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'matchup_matrix' is not defined",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[42]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m plot_matchup_matrix(\u001b[43mmatchup_matrix\u001b[49m, episodes_per_leg=EPISODES_PER_LEG)\n",
"\u001b[31mNameError\u001b[39m: name 'matchup_matrix' is not defined"
]
}
],
"source": [
"plot_matchup_matrix(matchup_matrix, list(agents.keys()), episodes_per_leg=20)\n"
]
},
{
"cell_type": "markdown",
"id": "32c37e5d",
"metadata": {},
"source": [
"# Conclusion\n",
"\n",
"This project implemented and compared five Reinforcement Learning agents on Atari Tennis:\n",
"\n",
"| Agent | Type | Policy | Update Rule |\n",
"|-------|------|--------|-------------|\n",
"| **Random** | Baseline | Uniform random | None |\n",
"| **SARSA** | TD(0), on-policy | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
"| **Q-Learning** | TD(0), off-policy | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
"| **Monte Carlo** | First-visit MC | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (G_t - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
"| **DQN** | Deep Q-Network | ε-greedy | Neural network (MLP 256→256) with experience replay and target network |\n",
"\n",
"**Architecture**:\n",
"- **Linear agents** (SARSA, Q-Learning, Monte Carlo): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$ with $\\phi(s) \\in \\mathbb{R}^{28\\,224}$ (4 grayscale 84×84 frames, normalized)\n",
"- **DQN**: MLP network (28,224 → 256 → 256 → 18) trained with Adam optimizer, Huber loss, and periodic target network sync\n",
"\n",
"**Methodology**:\n",
"1. **Pre-training** each agent individually against Atari's built-in AI (5,000 episodes, ε decaying from 1.0 to 0.05)\n",
"2. **Evaluation vs Random** to validate learning (expected win rate > 50%)\n",
"3. **Head-to-head tournament** in matches via PettingZoo (2 × 20 episodes)\n",
"\n",
"> ⚠️ **Known Issue**: Monte Carlo and DQN agent checkpoints have loading issues. Their code is preserved here for reference."
]
},
{
"cell_type": "markdown",
"id": "15f80f84",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "studies (3.13.9)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}