ArtStudies/M2/Reinforcement Learning/project/Project.ipynb
Arthur DANJOU 63ebb3ec8d Enhance Monte Carlo agent with performance optimizations and memory efficiency
- Updated weight and feature storage to use float32 for reduced memory bandwidth.
- Implemented compact storage for raw observations as uint8, batch-normalized at episode end.
- Introduced vectorized return computation and chunk-based weight updates using einsum.
- Reduced weight sanitization to once per episode instead of per-step.
- Refactored action selection and return calculation for improved efficiency.
2026-03-04 18:25:43 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "fef45687",
"metadata": {},
"source": [
"# RL Project: Atari Tennis Tournament\n",
"\n",
"This notebook implements three Reinforcement Learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium), and compares them with a random baseline:\n",
"\n",
"1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
"2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
"3. **Monte Carlo** — Every-visit MC control with linear approximation (inspired by Lab 4)\n",
"\n",
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b50d7174",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import pickle\n",
"from pathlib import Path\n",
"\n",
"import ale_py # noqa: F401 — registers ALE environments\n",
"import gymnasium as gym\n",
"import supersuit as ss\n",
"from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n",
"from pettingzoo.atari import tennis_v3\n",
"from tqdm.auto import tqdm\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ff3486a4",
"metadata": {},
"outputs": [],
"source": [
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"\n",
"def get_path(name: str) -> Path:\n",
" \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n",
" base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n",
" return CHECKPOINT_DIR / (base + \".pkl\")\n"
]
},
{
"cell_type": "markdown",
"id": "ec691487",
"metadata": {},
"source": [
"# Utility Functions\n",
"\n",
"## Observation Normalization\n",
"\n",
"The Tennis environment produces image observations of shape `(4, 84, 84)` after preprocessing (grayscale + resize + frame stack).\n",
"We flatten them into 1D vectors and divide by 255, as in Lab 7 (continuous feature normalization). SARSA and Q-Learning use `float64` features; the Monte Carlo agent uses `float32` to halve memory traffic.\n",
"\n",
"## ε-greedy Policy\n",
"\n",
"Follows the pattern from Lab 5B (`epsilon_greedy`) and Lab 7 (`epsilon_greedy_action`):\n",
"- With probability ε: random action (exploration)\n",
"- With probability 1 - ε: action maximizing $\\hat{q}(s, a)$ with uniform tie-breaking (`np.flatnonzero`)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "be85c130",
"metadata": {},
"outputs": [],
"source": [
"def normalize_obs(observation: np.ndarray) -> np.ndarray:\n",
" \"\"\"Flatten and normalize an observation to a 1D float64 vector.\n",
"\n",
" Replicates the /255.0 normalization used in all agents from the original project.\n",
" For image observations of shape (4, 84, 84), this produces a vector of length 28_224.\n",
"\n",
" Args:\n",
" observation: Raw observation array from the environment.\n",
"\n",
" Returns:\n",
" 1D numpy array of dtype float64, values in [0, 1].\n",
"\n",
" \"\"\"\n",
" return observation.flatten().astype(np.float64) / 255.0\n",
"\n",
"\n",
"def epsilon_greedy(\n",
" q_values: np.ndarray,\n",
" epsilon: float,\n",
" rng: np.random.Generator,\n",
") -> int:\n",
" \"\"\"Select an action using an ε-greedy policy with fair tie-breaking.\n",
"\n",
" Follows the same logic as Lab 5B epsilon_greedy and Lab 7 epsilon_greedy_action:\n",
" - With probability epsilon: choose a random action (exploration).\n",
" - With probability 1-epsilon: choose the action with highest Q-value (exploitation).\n",
" - If multiple actions share the maximum Q-value, break ties uniformly at random.\n",
"\n",
" Handles edge cases: empty q_values, NaN/Inf values.\n",
"\n",
" Args:\n",
" q_values: Array of Q-values for each action, shape (n_actions,).\n",
" epsilon: Exploration probability in [0, 1].\n",
" rng: NumPy random number generator.\n",
"\n",
" Returns:\n",
" Selected action index.\n",
"\n",
" \"\"\"\n",
" q_values = np.asarray(q_values, dtype=np.float64).reshape(-1)\n",
"\n",
" if q_values.size == 0:\n",
" msg = \"q_values is empty.\"\n",
" raise ValueError(msg)\n",
"\n",
" if rng.random() < epsilon:\n",
" return int(rng.integers(0, q_values.size))\n",
"\n",
" # Handle NaN/Inf values safely\n",
" finite_mask = np.isfinite(q_values)\n",
" if not np.any(finite_mask):\n",
" return int(rng.integers(0, q_values.size))\n",
"\n",
" safe_q = q_values.copy()\n",
" safe_q[~finite_mask] = -np.inf\n",
" max_val = np.max(safe_q)\n",
" best = np.flatnonzero(safe_q == max_val)\n",
"\n",
" if best.size == 0:\n",
" return int(rng.integers(0, q_values.size))\n",
"\n",
" return int(rng.choice(best))\n"
]
},
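{
"cell_type": "markdown",
"id": "tiebreak-demo",
"metadata": {},
"source": [
"All agents start from `W = 0`, so at the beginning every action has exactly the same Q-value and the tie-breaking rule alone decides which action is greedy. The minimal sketch below re-implements only the greedy branch. It is not the notebook's `epsilon_greedy`, and the helper name `greedy_with_ties`, the seed, and the draw counts are illustrative. The sketch contrasts `np.argmax`, which always returns the first maximal index, with uniform tie-breaking via `np.flatnonzero`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def greedy_with_ties(q_values, rng):\n",
"    # Indices of every action that attains the maximum Q-value\n",
"    best = np.flatnonzero(q_values == np.max(q_values))\n",
"    # Pick one of them uniformly at random\n",
"    return int(rng.choice(best))\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"q = np.zeros(18)  # all-zero weights: identical Q-values for all 18 actions\n",
"\n",
"argmax_picks = {int(np.argmax(q)) for _ in range(1000)}\n",
"tie_picks = np.bincount([greedy_with_ties(q, rng) for _ in range(18_000)], minlength=18)\n",
"\n",
"print(argmax_picks)                      # {0}: np.argmax always returns the first index\n",
"print(tie_picks.min(), tie_picks.max())  # every action is chosen roughly 1000 times\n",
"```\n",
"\n",
"With `np.argmax`, a purely greedy agent evaluated straight from `W = 0` would therefore always choose action 0 (NOOP in ALE's action set)."
]
},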
{
"cell_type": "markdown",
"id": "bb53da28",
"metadata": {},
"source": [
"# Agent Definitions\n",
"\n",
"## Base Class `Agent`\n",
"\n",
"All agents share a common interface (`get_action`, `update`, `save`, `load`), so the training and evaluation functions can handle any of them.\n",
"Serialization uses `pickle`, which stores numpy arrays (weights, RNG state) directly."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ded9b1fb",
"metadata": {},
"outputs": [],
"source": [
"class Agent:\n",
" \"\"\"Base class for reinforcement learning agents.\n",
"\n",
" All agents share this interface so they are compatible with the tournament system.\n",
" \"\"\"\n",
"\n",
" def __init__(self, seed: int, action_space: int) -> None:\n",
" \"\"\"Initialize the agent with its action space and a reproducible RNG.\"\"\"\n",
" self.action_space = action_space\n",
" self.rng = np.random.default_rng(seed=seed)\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" \"\"\"Select an action from the current observation.\"\"\"\n",
" raise NotImplementedError\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Update agent parameters from one transition.\"\"\"\n",
"\n",
" def save(self, filename: str) -> None:\n",
" \"\"\"Save the agent state to disk using pickle.\"\"\"\n",
" with Path(filename).open(\"wb\") as f:\n",
" pickle.dump(self.__dict__, f)\n",
"\n",
" def load(self, filename: str) -> None:\n",
" \"\"\"Load the agent state from disk.\"\"\"\n",
" with Path(filename).open(\"rb\") as f:\n",
" self.__dict__.update(pickle.load(f)) # noqa: S301\n"
]
},
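{
"cell_type": "markdown",
"id": "pickle-roundtrip",
"metadata": {},
"source": [
"`save` pickles `self.__dict__`, and `load` merges it back into an already constructed instance. Restoring a checkpoint therefore means creating an agent first and then calling `load` on it. Every attribute, including the NumPy RNG state, is overwritten. The stand-alone sketch below mirrors that pattern with a hypothetical `TinyAgent` class and a temporary file:\n",
"\n",
"```python\n",
"import pickle\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"class TinyAgent:\n",
"    # Hypothetical stand-in for Agent: same save/load pattern, tiny weights\n",
"    def __init__(self, seed):\n",
"        self.rng = np.random.default_rng(seed)\n",
"        self.W = np.zeros((2, 3))\n",
"\n",
"    def save(self, filename):\n",
"        with Path(filename).open('wb') as f:\n",
"            pickle.dump(self.__dict__, f)\n",
"\n",
"    def load(self, filename):\n",
"        with Path(filename).open('rb') as f:\n",
"            self.__dict__.update(pickle.load(f))\n",
"\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    path = Path(tmp) / 'tiny_agent.pkl'\n",
"    trained = TinyAgent(seed=0)\n",
"    trained.W[1] = [0.5, -1.0, 2.0]  # pretend training changed these weights\n",
"    trained.save(path)\n",
"\n",
"    restored = TinyAgent(seed=123)  # fresh instance with a different initial state\n",
"    restored.load(path)\n",
"\n",
"print(restored.W[1])                                  # the trained weights\n",
"print(restored.rng.random() == trained.rng.random())  # True: RNG state restored too\n",
"```\n",
"\n",
"As a consequence, constructor arguments such as `alpha` or `n_features` passed when re-creating an agent are replaced by the pickled values."
]
},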
{
"cell_type": "markdown",
"id": "8a4eae79",
"metadata": {},
"source": [
"## Random Agent (baseline)\n",
"\n",
"Serves as a reference to evaluate the performance of learning agents."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "78bdc9d2",
"metadata": {},
"outputs": [],
"source": [
"class RandomAgent(Agent):\n",
" \"\"\"A simple agent that selects actions uniformly at random (baseline).\"\"\"\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" \"\"\"Select a random action, ignoring the observation and epsilon.\"\"\"\n",
" _ = observation, epsilon\n",
" return int(self.rng.integers(0, self.action_space))\n"
]
},
{
"cell_type": "markdown",
"id": "5f679032",
"metadata": {},
"source": [
"## SARSA Agent — Linear Approximation (Semi-gradient)\n",
"\n",
"This agent combines:\n",
"- **Linear approximation** from Lab 7 (`SarsaAgent`): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$\n",
"- **On-policy SARSA update** from Lab 5B (`train_sarsa`): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
"\n",
"The semi-gradient update rule is:\n",
"$$W_a \\leftarrow W_a + \\alpha \\cdot \\delta \\cdot \\phi(s)$$\n",
"\n",
"where $\\phi(s)$ is the normalized observation vector (analogous to tile coding features in Lab 7, but in dense form)."
]
},
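{
"cell_type": "markdown",
"id": "sarsa-worked-update",
"metadata": {},
"source": [
"As a worked illustration of this rule, the sketch below performs one semi-gradient SARSA update by hand on a toy problem with 3 features and 2 actions. All numbers (features, weights, reward, step size) are made up to keep the arithmetic readable:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"alpha, gamma = 0.5, 0.9\n",
"W = np.array([[0.0, 1.0, 0.0],   # weights of action 0\n",
"              [1.0, 0.0, 1.0]])  # weights of action 1\n",
"phi_s = np.array([1.0, 1.0, 0.0])     # phi(s)\n",
"phi_next = np.array([0.0, 1.0, 1.0])  # phi(s')\n",
"a, r, a_next = 0, 1.0, 1  # action taken, reward, action actually chosen in s'\n",
"\n",
"q_sa = W[a] @ phi_s                # q(s, a) = 1.0\n",
"q_next = W[a_next] @ phi_next      # q(s', a') = 1.0\n",
"delta = r + gamma * q_next - q_sa  # 1 + 0.9 * 1 - 1 = 0.9\n",
"W[a] += alpha * delta * phi_s      # only row a moves, along phi(s)\n",
"\n",
"print(delta)  # approximately 0.9\n",
"print(W)      # row 0 is now approximately [0.45, 1.45, 0.0]; row 1 is unchanged\n",
"```\n",
"\n",
"Only the row of the action taken changes. The gradient of $\\hat{q}(s, a; \\mathbf{W})$ with respect to $\\mathbf{W}$ is $\\phi(s)$ in row $a$ and zero elsewhere, which is why `update` only touches `self.W[action]`."
]
},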
{
"cell_type": "code",
"execution_count": 6,
"id": "c124ed9a",
"metadata": {},
"outputs": [],
"source": [
"class SarsaAgent(Agent):\n",
" \"\"\"Semi-gradient SARSA agent with linear function approximation.\n",
"\n",
" Inspired by:\n",
" - Lab 7 SarsaAgent: linear q(s,a) = W_a . phi(s), semi-gradient update\n",
" - Lab 5B train_sarsa: on-policy TD target using Q(s', a')\n",
"\n",
" The weight matrix W has shape (n_actions, n_features).\n",
" For a given state s, q(s, a) = W[a] @ phi(s) is the dot product\n",
" of the action's weight row with the normalized observation.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" alpha: float = 0.001,\n",
" gamma: float = 0.99,\n",
" seed: int = 42,\n",
" ) -> None:\n",
" \"\"\"Initialize SARSA agent with linear weights.\n",
"\n",
" Args:\n",
" n_features: Dimension of the feature vector phi(s).\n",
" n_actions: Number of discrete actions.\n",
" alpha: Learning rate (kept small for high-dim features).\n",
" gamma: Discount factor.\n",
" seed: RNG seed for reproducibility.\n",
"\n",
" \"\"\"\n",
" super().__init__(seed, n_actions)\n",
" self.n_features = n_features\n",
" self.alpha = alpha\n",
" self.gamma = gamma\n",
" # Weight matrix: one row per action, analogous to Lab 7's self.w\n",
" # but organized as (n_actions, n_features) for dense features.\n",
" self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
"\n",
" def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
" \"\"\"Compute Q-values for all actions given feature vector phi(s).\n",
"\n",
" Equivalent to Lab 7's self.q(s, a) = self.w[idx].sum()\n",
" but using dense linear approximation: q(s, a) = W[a] @ phi.\n",
"\n",
" Args:\n",
" phi: Normalized feature vector, shape (n_features,).\n",
"\n",
" Returns:\n",
" Array of Q-values, shape (n_actions,).\n",
"\n",
" \"\"\"\n",
" return self.W @ phi # shape (n_actions,)\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" \"\"\"Select action using ε-greedy policy over linear Q-values.\n",
"\n",
" Same pattern as Lab 7 SarsaAgent.eps_greedy:\n",
" compute q-values for all actions, then apply epsilon_greedy.\n",
" \"\"\"\n",
" phi = normalize_obs(observation)\n",
" q_vals = self._q_values(phi)\n",
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Perform one semi-gradient SARSA update.\n",
"\n",
" Follows the SARSA update from Lab 5B train_sarsa:\n",
" td_target = r + gamma * Q(s', a') * (0 if done else 1)\n",
" Q(s, a) += alpha * (td_target - Q(s, a))\n",
"\n",
" In continuous form with linear approximation (Lab 7 SarsaAgent.update):\n",
" delta = target - q(s, a)\n",
" W[a] += alpha * delta * phi(s)\n",
"\n",
" Args:\n",
" state: Current observation.\n",
" action: Action taken.\n",
" reward: Reward received.\n",
" next_state: Next observation.\n",
" done: Whether the episode ended.\n",
" next_action: Action chosen in next state (required for SARSA).\n",
"\n",
" \"\"\"\n",
" phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
" q_sa = float(self.W[action] @ phi) # current estimate q(s, a)\n",
" if not np.isfinite(q_sa):\n",
" q_sa = 0.0\n",
"\n",
" if done:\n",
" # Terminal: no future value (Lab 5B: gamma * Q[s2, a2] * 0)\n",
" target = reward\n",
" else:\n",
" # On-policy: use q(s', a') where a' is the actual next action\n",
" # This is the key SARSA property (Lab 5B)\n",
" phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
" if next_action is None:\n",
" next_action = 0 # fallback, should not happen in practice\n",
" q_sp_ap = float(self.W[next_action] @ phi_next)\n",
" if not np.isfinite(q_sp_ap):\n",
" q_sp_ap = 0.0\n",
" target = float(reward) + self.gamma * q_sp_ap\n",
"\n",
" # Semi-gradient update: W[a] += alpha * delta * phi(s)\n",
" # Analogous to Lab 7: self.w[idx] += self.alpha * delta\n",
" if not np.isfinite(target):\n",
" return\n",
"\n",
" delta = float(target - q_sa)\n",
" if not np.isfinite(delta):\n",
" return\n",
"\n",
" td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
" self.W[action] += self.alpha * td_step * phi\n",
" self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
]
},
{
"cell_type": "markdown",
"id": "d4e18536",
"metadata": {},
"source": [
"## Q-Learning Agent — Linear Approximation (Off-policy)\n",
"\n",
"Same architecture as SARSA but with the **off-policy update** from Lab 5B (`train_q_learning`):\n",
"\n",
"$$\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$$\n",
"\n",
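"The key difference from SARSA: we use $\\max_{a'} \\hat{q}(s', a')$ instead of $\\hat{q}(s', a')$, where $a'$ is the action actually chosen. In the tabular case, this lets Q-Learning learn the greedy (optimal) policy while following an exploratory one. With linear function approximation, off-policy updates are no longer guaranteed to converge."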
]
},
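{
"cell_type": "markdown",
"id": "td-target-comparison",
"metadata": {},
"source": [
"To make the difference concrete, the sketch below computes both TD targets for the same hypothetical transition, in which exploration happened to select a poor next action $a'$. The Q-values, reward, and $\\gamma$ are illustrative:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"gamma, r = 0.99, 0.0\n",
"q_next = np.array([2.0, -1.0, 0.5])  # hypothetical q(s', .) for 3 actions\n",
"a_next = 1  # exploratory action actually taken in s' (the worst one here)\n",
"\n",
"sarsa_target = r + gamma * q_next[a_next]       # on-policy: value of the action taken\n",
"q_learning_target = r + gamma * np.max(q_next)  # off-policy: value of the greedy action\n",
"\n",
"print(sarsa_target, q_learning_target)  # -0.99 1.98\n",
"```\n",
"\n",
"The exploratory move pulls SARSA's target down, so SARSA estimates the value of the ε-greedy policy it actually follows. Q-Learning ignores $a'$ and backs up the value of the greedy action."
]
},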
{
"cell_type": "code",
"execution_count": 7,
"id": "f5b5b9ea",
"metadata": {},
"outputs": [],
"source": [
"class QLearningAgent(Agent):\n",
" \"\"\"Q-Learning agent with linear function approximation (off-policy).\n",
"\n",
" Inspired by:\n",
" - Lab 5B train_q_learning: off-policy TD target using max_a' Q(s', a')\n",
" - Lab 7 SarsaAgent: linear approximation q(s,a) = W[a] @ phi(s)\n",
"\n",
" The only difference from SarsaAgent is the TD target:\n",
" SARSA uses Q(s', a') (on-policy), Q-Learning uses max_a' Q(s', a') (off-policy).\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" alpha: float = 0.001,\n",
" gamma: float = 0.99,\n",
" seed: int = 42,\n",
" ) -> None:\n",
" \"\"\"Initialize Q-Learning agent with linear weights.\n",
"\n",
" Args:\n",
" n_features: Dimension of the feature vector phi(s).\n",
" n_actions: Number of discrete actions.\n",
" alpha: Learning rate.\n",
" gamma: Discount factor.\n",
" seed: RNG seed.\n",
"\n",
" \"\"\"\n",
" super().__init__(seed, n_actions)\n",
" self.n_features = n_features\n",
" self.alpha = alpha\n",
" self.gamma = gamma\n",
" self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
"\n",
" def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
" \"\"\"Compute Q-values for all actions: q(s, a) = W[a] @ phi for each a.\"\"\"\n",
" return self.W @ phi\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n",
" phi = normalize_obs(observation)\n",
" q_vals = self._q_values(phi)\n",
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Perform one Q-learning update.\n",
"\n",
" Follows Lab 5B train_q_learning:\n",
" td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
" Q[s, a] += alpha * (td_target - Q[s, a])\n",
"\n",
" In continuous form with linear approximation:\n",
" delta = target - q(s, a)\n",
" W[a] += alpha * delta * phi(s)\n",
" \"\"\"\n",
" _ = next_action # Q-learning is off-policy: next_action is not used\n",
" phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
" q_sa = float(self.W[action] @ phi)\n",
" if not np.isfinite(q_sa):\n",
" q_sa = 0.0\n",
"\n",
" if done:\n",
" # Terminal state: no future value\n",
" # Lab 5B: gamma * np.max(Q[s2]) * (0 if terminated else 1)\n",
" target = reward\n",
" else:\n",
" # Off-policy: use max over all actions in next state\n",
" # This is the key Q-learning property (Lab 5B)\n",
" phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
" q_next_all = self._q_values(phi_next) # q(s', a') for all a'\n",
" q_next_max = float(np.max(q_next_all))\n",
" if not np.isfinite(q_next_max):\n",
" q_next_max = 0.0\n",
" target = float(reward) + self.gamma * q_next_max\n",
"\n",
" if not np.isfinite(target):\n",
" return\n",
"\n",
" delta = float(target - q_sa)\n",
" if not np.isfinite(delta):\n",
" return\n",
"\n",
" td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
" self.W[action] += self.alpha * td_step * phi\n",
" self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
]
},
{
"cell_type": "markdown",
"id": "b7b63455",
"metadata": {},
"source": [
"## Monte Carlo Agent — Linear Approximation (Every-visit)\n",
"\n",
"This agent is inspired by Lab 4 (`mc_control_epsilon_soft`):\n",
"- Accumulates transitions in an episode buffer `(state, action, reward)`\n",
"- At the end of the episode (`done=True`), computes the **discounted return** of every step by traversing the buffer backward:\n",
" $$G \\leftarrow \\gamma \\cdot G + r$$\n",
"- Updates weights with the semi-gradient rule:\n",
" $$W_a \\leftarrow W_a + \\alpha \\cdot (G - \\hat{q}(s, a)) \\cdot \\phi(s)$$\n",
"\n",
"Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating. Every time step of the episode contributes an update (every-visit MC). First-visit and every-visit MC differ only when the same (s, a) pair recurs within an episode, and exact repeats of a 4-frame pixel stack are uncommon, so the two are expected to behave similarly here."
]
},
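{
"cell_type": "markdown",
"id": "mc-returns-check",
"metadata": {},
"source": [
"The backward recursion $G \\leftarrow \\gamma G + r$ produces the same values as the direct definition $G_t = \\sum_k \\gamma^k r_{t+k}$, where $r_t$ is the reward received after acting at step $t$, as in the agent's buffer. The sketch below checks this on a short made-up reward sequence:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"gamma = 0.9\n",
"# Hypothetical episode buffer: rewards[t] is the reward received after acting at step t\n",
"rewards = np.array([0.0, 0.0, 1.0, 0.0, -1.0])\n",
"n = len(rewards)\n",
"\n",
"# Backward recursion, as in the Monte Carlo agent: G = gamma * G + r\n",
"returns = np.empty(n)\n",
"G = 0.0\n",
"for t in range(n - 1, -1, -1):\n",
"    G = gamma * G + rewards[t]\n",
"    returns[t] = G\n",
"\n",
"# Direct definition: G_t = sum over k of gamma**k * rewards[t + k]\n",
"direct = np.array([sum(gamma**k * rewards[t + k] for k in range(n - t)) for t in range(n)])\n",
"\n",
"print(np.allclose(returns, direct))  # True\n",
"print(returns[0])                    # approximately 0.9**2 - 0.9**4 = 0.1539\n",
"```"
]
},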
{
"cell_type": "code",
"execution_count": 8,
"id": "3c9d74be",
"metadata": {},
"outputs": [],
"source": [
"class MonteCarloAgent(Agent):\n",
" \"\"\"Monte Carlo control agent with linear function approximation.\n",
"\n",
" Inspired by Lab 4 mc_control_epsilon_soft:\n",
" - Accumulates transitions in an episode buffer\n",
" - At episode end (done=True), computes discounted returns backward:\n",
" G = gamma * G + r (same as Lab 4's reversed loop)\n",
" - Updates weights with semi-gradient: W[a] += alpha * (G - q(s,a)) * phi(s)\n",
"\n",
" Unlike TD methods (SARSA, Q-Learning), no update occurs until the episode ends.\n",
"\n",
" Performance optimizations over naive per-step implementation:\n",
" - float32 weights & features (halves memory bandwidth, faster SIMD)\n",
" - Raw observations stored compactly as uint8, batch-normalized at episode end\n",
" - Vectorized return computation & chunk-based weight updates via einsum\n",
" - Single weight sanitization per episode instead of per-step\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" alpha: float = 0.001,\n",
" gamma: float = 0.99,\n",
" seed: int = 42,\n",
" ) -> None:\n",
" super().__init__(seed, n_actions)\n",
" self.n_features = n_features\n",
" self.alpha = alpha\n",
" self.gamma = gamma\n",
" self.W = np.zeros((n_actions, n_features), dtype=np.float32)\n",
" self._obs_buf: list[np.ndarray] = []\n",
" self._act_buf: list[int] = []\n",
" self._rew_buf: list[float] = []\n",
"\n",
" def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
" return self.W @ phi\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" phi = observation.flatten().astype(np.float32) / np.float32(255.0)\n",
" q_vals = self._q_values(phi)\n",
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" _ = next_state, next_action\n",
"\n",
" self._obs_buf.append(state)\n",
" self._act_buf.append(action)\n",
" self._rew_buf.append(reward)\n",
"\n",
" if not done:\n",
" return\n",
"\n",
" n = len(self._rew_buf)\n",
" actions = np.array(self._act_buf, dtype=np.intp)\n",
"\n",
" returns = np.empty(n, dtype=np.float32)\n",
" G = np.float32(0.0)\n",
" gamma32 = np.float32(self.gamma)\n",
" for i in range(n - 1, -1, -1):\n",
" G = gamma32 * G + np.float32(self._rew_buf[i])\n",
" returns[i] = G\n",
"\n",
" alpha32 = np.float32(self.alpha)\n",
" chunk_size = 500\n",
" for start in range(0, n, chunk_size):\n",
" end = min(start + chunk_size, n)\n",
" cs = end - start\n",
"\n",
" raw = np.array(self._obs_buf[start:end])\n",
" phi = raw.reshape(cs, -1).astype(np.float32)\n",
" phi /= np.float32(255.0)\n",
"\n",
" ca = actions[start:end]\n",
" q_sa = np.einsum(\"ij,ij->i\", self.W[ca], phi)\n",
"\n",
" deltas = np.clip(returns[start:end] - q_sa, -1000.0, 1000.0)\n",
"\n",
" for a in range(self.action_space):\n",
" mask = ca == a\n",
" if not np.any(mask):\n",
" continue\n",
" self.W[a] += alpha32 * (deltas[mask] @ phi[mask])\n",
"\n",
" self.W = np.nan_to_num(self.W, nan=0.0, posinf=1e6, neginf=-1e6)\n",
"\n",
" self._obs_buf.clear()\n",
" self._act_buf.clear()\n",
" self._rew_buf.clear()\n"
]
},
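{
"cell_type": "markdown",
"id": "mc-batch-update-check",
"metadata": {},
"source": [
"Within one chunk, the loop over actions adds $\\alpha \\sum_t \\delta_t \\phi(s_t)$ over the steps with $a_t = a$. Every $\\delta_t$ is computed with the weights as they were at the start of the chunk. The sketch below uses random toy data with arbitrary sizes, seed, and step size, and omits the +/-1000 clipping. It checks that this grouped update equals a scatter-add with `np.add.at`. It also shows that the grouped update differs slightly from applying the same steps one at a time, since a sequential loop refreshes $\\hat{q}(s, a)$ after every step:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_actions, n_features, n_steps, alpha = 4, 6, 50, 0.01\n",
"W0 = rng.normal(size=(n_actions, n_features)).astype(np.float32)\n",
"phi = rng.random((n_steps, n_features), dtype=np.float32)\n",
"actions = rng.integers(0, n_actions, size=n_steps)\n",
"returns = rng.normal(size=n_steps).astype(np.float32)\n",
"\n",
"# Chunk-style batch update, as in MonteCarloAgent.update: every delta uses W0\n",
"deltas = returns - np.einsum('ij,ij->i', W0[actions], phi)\n",
"W_grouped = W0.copy()\n",
"for a in range(n_actions):\n",
"    mask = actions == a\n",
"    W_grouped[a] += alpha * (deltas[mask] @ phi[mask])\n",
"\n",
"# The same batch step written as a scatter-add over weight rows\n",
"W_scatter = W0.copy()\n",
"np.add.at(W_scatter, actions, alpha * deltas[:, None] * phi)\n",
"\n",
"# Sequential per-step updates: each delta uses the freshly updated weights\n",
"W_seq = W0.copy()\n",
"for t in range(n_steps):\n",
"    a = actions[t]\n",
"    W_seq[a] += alpha * (returns[t] - W_seq[a] @ phi[t]) * phi[t]\n",
"\n",
"print(np.allclose(W_grouped, W_scatter, atol=1e-5))  # True\n",
"print(np.abs(W_grouped - W_seq).max())               # small but non-zero\n",
"```\n",
"\n",
"For small step sizes the difference stays small. It is the price of replacing up to 500 sequential vector updates with a few matrix products."
]
},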
{
"cell_type": "markdown",
"id": "91e51dc8",
"metadata": {},
"source": [
"## Tennis Environment\n",
"\n",
"Creation of the Atari Tennis environment via Gymnasium (`ALE/Tennis-v5`) with standard wrappers:\n",
"- **Grayscale**: `obs_type=\"grayscale\"` — single-channel observations\n",
"- **Resize**: `ResizeObservation(84, 84)` — downscale to 84×84\n",
"- **Frame stack**: `FrameStackObservation(4)` — stack 4 consecutive frames\n",
"\n",
"The final observation is an array of shape `(4, 84, 84)`, which flattens to 28,224 features.\n",
"\n",
"The agent plays against the **built-in Atari AI opponent**."
]
},
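{
"cell_type": "markdown",
"id": "feature-memory-calc",
"metadata": {},
"source": [
"The feature dimension and the size of each linear model follow directly from the observation shape. The short calculation below uses the 18 actions reported for `ALE/Tennis-v5` further down in this notebook. It also uses the dtypes chosen by the agents: `float64` weights for SARSA and Q-Learning, `float32` for Monte Carlo, and `uint8` for buffered raw frames:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"obs_shape = (4, 84, 84)  # stacked grayscale frames after preprocessing\n",
"n_actions = 18           # action count reported for ALE/Tennis-v5 in this notebook\n",
"\n",
"n_features = int(np.prod(obs_shape))\n",
"n_weights = n_actions * n_features\n",
"\n",
"mib = 2**20\n",
"w64_mib = n_weights * np.dtype(np.float64).itemsize / mib  # SARSA / Q-Learning weights\n",
"w32_mib = n_weights * np.dtype(np.float32).itemsize / mib  # Monte Carlo weights\n",
"step_uint8 = n_features * np.dtype(np.uint8).itemsize      # bytes per buffered raw frame stack\n",
"step_f32 = n_features * np.dtype(np.float32).itemsize      # bytes if stored as float32 features\n",
"\n",
"print(n_features, n_weights)                      # 28224 508032\n",
"print(f'{w64_mib:.2f} MiB vs {w32_mib:.2f} MiB')  # 3.88 MiB vs 1.94 MiB\n",
"print(step_uint8, step_f32)                       # 28224 112896\n",
"```\n",
"\n",
"For a 5,000-step episode (the default `max_steps`), the Monte Carlo buffer therefore holds about 135 MiB of `uint8` frames. Storing each step as a `float32` feature vector would take roughly 538 MiB instead."
]
},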
{
"cell_type": "code",
"execution_count": 9,
"id": "f9a973dd",
"metadata": {},
"outputs": [],
"source": [
"def create_env() -> gym.Env:\n",
" \"\"\"Create the ALE/Tennis-v5 environment with preprocessing wrappers.\n",
"\n",
" Applies:\n",
" - obs_type=\"grayscale\": grayscale observation (210, 160)\n",
" - ResizeObservation(84, 84): downscale to 84x84\n",
" - FrameStackObservation(4): stack 4 consecutive frames -> (4, 84, 84)\n",
"\n",
" Returns:\n",
" Gymnasium environment ready for training.\n",
"\n",
" \"\"\"\n",
" env = gym.make(\"ALE/Tennis-v5\", obs_type=\"grayscale\")\n",
" env = ResizeObservation(env, shape=(84, 84))\n",
" return FrameStackObservation(env, stack_size=4)\n"
]
},
{
"cell_type": "markdown",
"id": "18cb28d8",
"metadata": {},
"source": [
"## Training & Evaluation Infrastructure\n",
"\n",
"Functions for training and evaluating agents in the single-agent Gymnasium environment:\n",
"\n",
"1. **`train_agent`** — Pre-trains an agent against the built-in AI for a given number of episodes with ε-greedy exploration\n",
"2. **`evaluate_agent`** — Evaluates a trained agent (no exploration, ε = 0) and returns performance metrics\n",
"3. **`plot_training_curves`** — Plots the training reward history (moving average) for all agents\n",
"4. **`plot_evaluation_comparison`** — Bar chart comparing final evaluation scores across agents\n",
"5. **`evaluate_tournament`** — Evaluates all agents and produces a summary comparison"
]
},
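{
"cell_type": "markdown",
"id": "epsilon-schedule",
"metadata": {},
"source": [
"`train_agent` uses a multiplicative schedule with defaults `epsilon_start=1.0`, `epsilon_decay=0.999`, and `epsilon_end=0.05`. Under these defaults, ε hits its floor after the smallest $k$ with $0.999^k \\le 0.05$, i.e. $k = \\lceil \\ln 0.05 / \\ln 0.999 \\rceil$:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.05, 0.999\n",
"\n",
"# Smallest number of decays k such that epsilon_start * epsilon_decay**k is at most epsilon_end\n",
"k = math.ceil(math.log(epsilon_end / epsilon_start) / math.log(epsilon_decay))\n",
"\n",
"print(k)  # 2995\n",
"```\n",
"\n",
"With the default `episodes=5000`, roughly the last 40% of training therefore runs at the minimum exploration rate."
]
},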
{
"cell_type": "code",
"execution_count": 10,
"id": "06b91580",
"metadata": {},
"outputs": [],
"source": [
"def train_agent(\n",
" env: gym.Env,\n",
" agent: Agent,\n",
" name: str,\n",
" *,\n",
" episodes: int = 5000,\n",
" epsilon_start: float = 1.0,\n",
" epsilon_end: float = 0.05,\n",
" epsilon_decay: float = 0.999,\n",
" max_steps: int = 5000,\n",
") -> list[float]:\n",
" \"\"\"Pre-train an agent against the built-in Atari AI opponent.\n",
"\n",
" Each agent learns independently by playing full episodes against the\n",
" environment's built-in opponent (this is not self-play). `agent.update` is\n",
" called after every transition; the Monte Carlo agent only buffers it and\n",
" updates its weights once the episode ends.\n",
"\n",
" Args:\n",
" env: Gymnasium ALE/Tennis-v5 environment.\n",
" agent: Agent instance to train.\n",
" name: Display name for the progress bar.\n",
" episodes: Number of training episodes.\n",
" epsilon_start: Initial exploration rate.\n",
" epsilon_end: Minimum exploration rate.\n",
" epsilon_decay: Multiplicative decay per episode.\n",
" max_steps: Maximum steps per episode.\n",
"\n",
" Returns:\n",
" List of total rewards per episode.\n",
"\n",
" \"\"\"\n",
" rewards_history: list[float] = []\n",
" epsilon = epsilon_start\n",
"\n",
" pbar = tqdm(range(episodes), desc=f\"Training {name}\", leave=True)\n",
"\n",
" for _ep in pbar:\n",
" obs, _info = env.reset()\n",
" obs = np.asarray(obs)\n",
" total_reward = 0.0\n",
"\n",
" # Select first action\n",
" action = agent.get_action(obs, epsilon=epsilon)\n",
"\n",
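" for step in range(max_steps):\n",
" next_obs, reward, terminated, truncated, _info = env.step(action)\n",
" next_obs = np.asarray(next_obs)\n",
" # Hitting max_steps also ends the episode, so episodic agents (Monte Carlo)\n",
" # always flush their buffer before the next reset.\n",
" done = terminated or truncated or step == max_steps - 1\n",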
" reward = float(reward)\n",
" total_reward += reward\n",
"\n",
" # Select next action (needed for SARSA's on-policy update)\n",
" next_action = agent.get_action(next_obs, epsilon=epsilon) if not done else None\n",
"\n",
" # Update agent with the transition\n",
" agent.update(\n",
" state=obs,\n",
" action=action,\n",
" reward=reward,\n",
" next_state=next_obs,\n",
" done=done,\n",
" next_action=next_action,\n",
" )\n",
"\n",
" if done:\n",
" break\n",
"\n",
" obs = next_obs\n",
" action = next_action\n",
"\n",
" rewards_history.append(total_reward)\n",
" epsilon = max(epsilon_end, epsilon * epsilon_decay)\n",
"\n",
" # Update progress bar\n",
" recent_window = 50\n",
" if len(rewards_history) >= recent_window:\n",
" recent_avg = np.mean(rewards_history[-recent_window:])\n",
" pbar.set_postfix(\n",
" avg50=f\"{recent_avg:.1f}\",\n",
" eps=f\"{epsilon:.3f}\",\n",
" rew=f\"{total_reward:.0f}\",\n",
" )\n",
"\n",
" return rewards_history\n",
"\n",
"\n",
"def evaluate_agent(\n",
" env: gym.Env,\n",
" agent: Agent,\n",
" name: str,\n",
" *,\n",
" episodes: int = 20,\n",
" max_steps: int = 5000,\n",
") -> dict[str, object]:\n",
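" \"\"\"Evaluate a trained agent with no exploration (ε = 0).\n",
"\n",
" An episode counts as a win when its total reward (points won minus\n",
" points lost) is positive.\n",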
"\n",
" Args:\n",
" env: Gymnasium ALE/Tennis-v5 environment.\n",
" agent: Trained agent to evaluate.\n",
" name: Display name for the progress bar.\n",
" episodes: Number of evaluation episodes.\n",
" max_steps: Maximum steps per episode.\n",
"\n",
" Returns:\n",
" Dictionary with rewards list, mean, std, wins, and win rate.\n",
"\n",
" \"\"\"\n",
" rewards: list[float] = []\n",
" wins = 0\n",
"\n",
" for _ep in tqdm(range(episodes), desc=f\"Evaluating {name}\", leave=False):\n",
" obs, _info = env.reset()\n",
" total_reward = 0.0\n",
"\n",
" for _step in range(max_steps):\n",
" action = agent.get_action(np.asarray(obs), epsilon=0.0)\n",
" obs, reward, terminated, truncated, _info = env.step(action)\n",
" reward = float(reward)\n",
" total_reward += reward\n",
" if terminated or truncated:\n",
" break\n",
"\n",
" rewards.append(total_reward)\n",
" if total_reward > 0:\n",
" wins += 1\n",
"\n",
" return {\n",
" \"rewards\": rewards,\n",
" \"mean_reward\": float(np.mean(rewards)),\n",
" \"std_reward\": float(np.std(rewards)),\n",
" \"wins\": wins,\n",
" \"win_rate\": wins / episodes,\n",
" }\n",
"\n",
"\n",
"def plot_training_curves(\n",
" training_histories: dict[str, list[float]],\n",
" path: str,\n",
" window: int = 100,\n",
") -> None:\n",
" \"\"\"Plot training reward curves for all agents on a single figure.\n",
"\n",
" Uses a moving average to smooth the curves.\n",
"\n",
" Args:\n",
" training_histories: Dict mapping agent names to reward lists.\n",
" path: File path to save the plot image.\n",
" window: Moving average window size.\n",
"\n",
" \"\"\"\n",
" plt.figure(figsize=(12, 6))\n",
"\n",
" for name, rewards in training_histories.items():\n",
" if len(rewards) >= window:\n",
" ma = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n",
" plt.plot(np.arange(window - 1, len(rewards)), ma, label=name)\n",
" else:\n",
" plt.plot(rewards, label=f\"{name} (raw)\")\n",
"\n",
" plt.xlabel(\"Episodes\")\n",
" plt.ylabel(f\"Average Reward (Window={window})\")\n",
" plt.title(\"Training Curves (vs built-in AI)\")\n",
" plt.legend()\n",
" plt.grid(visible=True)\n",
" plt.tight_layout()\n",
" plt.savefig(path)\n",
" plt.show()\n",
"\n",
"\n",
"def plot_evaluation_comparison(results: dict[str, dict[str, object]]) -> None:\n",
" \"\"\"Bar chart comparing evaluation performance of all agents.\n",
"\n",
" Args:\n",
" results: Dict mapping agent names to evaluation result dicts.\n",
"\n",
" \"\"\"\n",
" names = list(results.keys())\n",
" means = [results[n][\"mean_reward\"] for n in names]\n",
" stds = [results[n][\"std_reward\"] for n in names]\n",
" win_rates = [results[n][\"win_rate\"] for n in names]\n",
"\n",
" _fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
" # Mean reward bar chart\n",
" colors = sns.color_palette(\"husl\", len(names))\n",
" axes[0].bar(names, means, yerr=stds, capsize=5, color=colors, edgecolor=\"black\")\n",
" axes[0].set_ylabel(\"Mean Reward\")\n",
" axes[0].set_title(\"Evaluation: Mean Reward per Agent (vs built-in AI)\")\n",
" axes[0].axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
" axes[0].grid(axis=\"y\", alpha=0.3)\n",
"\n",
" # Win rate bar chart\n",
" axes[1].bar(names, win_rates, color=colors, edgecolor=\"black\")\n",
" axes[1].set_ylabel(\"Win Rate\")\n",
" axes[1].set_title(\"Evaluation: Win Rate per Agent (vs built-in AI)\")\n",
" axes[1].set_ylim(0, 1)\n",
" axes[1].axhline(y=0.5, color=\"gray\", linestyle=\"--\", alpha=0.5, label=\"50% baseline\")\n",
" axes[1].legend()\n",
" axes[1].grid(axis=\"y\", alpha=0.3)\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"\n",
"def evaluate_tournament(\n",
" env: gym.Env,\n",
" agents: dict[str, Agent],\n",
" episodes_per_agent: int = 20,\n",
") -> dict[str, dict[str, object]]:\n",
" \"\"\"Evaluate all agents against the built-in AI and produce a comparison.\n",
"\n",
" Args:\n",
" env: Gymnasium ALE/Tennis-v5 environment.\n",
" agents: Dictionary mapping agent names to Agent instances.\n",
" episodes_per_agent: Number of evaluation episodes per agent.\n",
"\n",
" Returns:\n",
" Dict mapping agent names to their evaluation results.\n",
"\n",
" \"\"\"\n",
" results: dict[str, dict[str, object]] = {}\n",
" n_agents = len(agents)\n",
"\n",
" for idx, (name, agent) in enumerate(agents.items(), start=1):\n",
" print(f\"[Evaluation {idx}/{n_agents}] {name}\")\n",
" results[name] = evaluate_agent(\n",
" env, agent, name, episodes=episodes_per_agent,\n",
" )\n",
" mean_r = results[name][\"mean_reward\"]\n",
" wr = results[name][\"win_rate\"]\n",
" print(f\" -> Mean reward: {mean_r:.2f} | Win rate: {wr:.1%}\\n\")\n",
"\n",
" return results\n"
]
},
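{
"cell_type": "markdown",
"id": "moving-average-check",
"metadata": {},
"source": [
"In `plot_training_curves`, `np.convolve(..., mode='valid')` returns `len(rewards) - window + 1` averages, and the k-th one covers episodes k through k + window - 1. Plotting it at `np.arange(window - 1, len(rewards))` aligns each average with the last episode of its window. A toy check with made-up rewards and `window = 3`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rewards = np.array([-4.0, -2.0, 0.0, 2.0, 4.0, 6.0])  # made-up episode rewards\n",
"window = 3\n",
"\n",
"ma = np.convolve(rewards, np.ones(window) / window, mode='valid')\n",
"x = np.arange(window - 1, len(rewards))  # index of the last episode in each window\n",
"\n",
"print(x)   # [2 3 4 5]\n",
"print(ma)  # approximately [-2, 0, 2, 4]\n",
"```"
]
},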
{
"cell_type": "markdown",
"id": "9605e9c4",
"metadata": {},
"source": [
"## Agent Instantiation & Incremental Training (One Agent at a Time)\n",
"\n",
"**Environment**: `ALE/Tennis-v5` (grayscale, 84×84×4 frames → 28,224 features, 18 actions).\n",
"\n",
"**Agents**:\n",
"- **Random** — random baseline (no training needed)\n",
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
"- **Q-Learning** — linear approximation, off-policy\n",
"- **Monte Carlo** — linear approximation, every-visit returns\n",
"\n",
"**Workflow**:\n",
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
"2. Save its weights to `checkpoints/` (`.pkl`)\n",
"3. Repeat later for another agent without retraining previous ones\n",
"4. Load all saved checkpoints before the final evaluation"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6f6ba8df",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)\n",
"[Powered by Stella]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Observation shape : (4, 84, 84)\n",
"Feature vector dim: 28224\n",
"Number of actions : 18\n"
]
}
],
"source": [
"# Create environment\n",
"env = create_env()\n",
"obs, _info = env.reset()\n",
"\n",
"n_actions = int(env.action_space.n)\n",
"n_features = int(np.prod(obs.shape))\n",
"\n",
"print(f\"Observation shape : {obs.shape}\")\n",
"print(f\"Feature vector dim: {n_features}\")\n",
"print(f\"Number of actions : {n_actions}\")\n",
"\n",
"# Instantiate agents\n",
"agent_random = RandomAgent(seed=42, action_space=int(n_actions))\n",
"agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"\n",
"agents = {\n",
"    \"Random\": agent_random,\n",
"    \"SARSA\": agent_sarsa,\n",
"    \"Q-Learning\": agent_q,\n",
"    \"Monte Carlo\": agent_mc,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d449701",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected agent: Monte Carlo\n",
"Checkpoint path: checkpoints/monte_carlo.pkl\n",
"\n",
"============================================================\n",
"Training: Monte Carlo (2500 episodes)\n",
"============================================================\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d5d098d18014fe6b736683e0b8b2488",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training Monte Carlo: 0%| | 0/2500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"AGENT_TO_TRAIN = \"Monte Carlo\"  # Options: \"SARSA\", \"Q-Learning\", \"Monte Carlo\", \"Random\"\n",
"TRAINING_EPISODES = 2500\n",
"FORCE_RETRAIN = False\n",
"\n",
"if AGENT_TO_TRAIN not in agents:\n",
"    msg = f\"Unknown agent '{AGENT_TO_TRAIN}'. Available: {list(agents)}\"\n",
"    raise ValueError(msg)\n",
"\n",
"training_histories: dict[str, list[float]] = {}\n",
"agent = agents[AGENT_TO_TRAIN]\n",
"checkpoint_path = get_path(AGENT_TO_TRAIN)\n",
"\n",
"print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n",
"print(f\"Checkpoint path: {checkpoint_path}\")\n",
"\n",
"if AGENT_TO_TRAIN == \"Random\":\n",
"    print(\"Random is a baseline and is not trained.\")\n",
"    training_histories[AGENT_TO_TRAIN] = []\n",
"elif checkpoint_path.exists() and not FORCE_RETRAIN:\n",
"    agent.load(str(checkpoint_path))\n",
"    print(\"Checkpoint found -> weights loaded, training skipped.\")\n",
"    training_histories[AGENT_TO_TRAIN] = []\n",
"else:\n",
"    print(f\"\\n{'='*60}\")\n",
"    print(f\"Training: {AGENT_TO_TRAIN} ({TRAINING_EPISODES} episodes)\")\n",
"    print(f\"{'='*60}\")\n",
"\n",
"    training_histories[AGENT_TO_TRAIN] = train_agent(\n",
"        env=env,\n",
"        agent=agent,\n",
"        name=AGENT_TO_TRAIN,\n",
"        episodes=TRAINING_EPISODES,\n",
"        epsilon_start=1.0,\n",
"        epsilon_end=0.05,\n",
"        epsilon_decay=0.999,\n",
"    )\n",
"\n",
"    avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n",
"    print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n",
"\n",
"    agent.save(str(checkpoint_path))\n",
"    print(\"Checkpoint saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a13a65df",
"metadata": {},
"outputs": [],
"source": [
"plot_training_curves(\n",
"    training_histories, f\"plots/{AGENT_TO_TRAIN}_training_curves.png\", window=100,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "0e5f2c49",
"metadata": {},
"source": [
"## Final Evaluation\n",
"\n",
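"Each trained agent plays 20 episodes against the built-in Atari AI with exploration disabled (ε = 0), alongside the `Random` baseline; agents without a saved checkpoint are skipped.\n",
"Performance is compared via mean reward (with its standard deviation) and win rate."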
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70f5d5cd",
"metadata": {},
"outputs": [],
"source": [
"# Build evaluation set: Random + agents with existing checkpoints\n",
"eval_agents: dict[str, Agent] = {\"Random\": agents[\"Random\"]}\n",
"missing_agents: list[str] = []\n",
"\n",
"for name, agent in agents.items():\n",
"    if name == \"Random\":\n",
"        continue\n",
"\n",
"    checkpoint_path = get_path(name)\n",
"\n",
"    if checkpoint_path.exists():\n",
"        agent.load(str(checkpoint_path))\n",
"        eval_agents[name] = agent\n",
"    else:\n",
"        missing_agents.append(name)\n",
"\n",
"print(f\"Agents evaluated: {list(eval_agents.keys())}\")\n",
"if missing_agents:\n",
"    print(f\"Skipped (no checkpoint yet): {missing_agents}\")\n",
"\n",
"if len(eval_agents) < 2:\n",
"    raise RuntimeError(\"Train at least one non-random agent before final evaluation.\")\n",
"\n",
"results = evaluate_tournament(env, eval_agents, episodes_per_agent=20)\n",
"plot_evaluation_comparison(results)\n",
"\n",
"# Print summary table\n",
"print(f\"\\n{'Agent':<15} {'Mean Reward':>12} {'Std':>8} {'Win Rate':>10}\")\n",
"print(\"-\" * 48)\n",
"for name, res in results.items():\n",
"    print(f\"{name:<15} {res['mean_reward']:>12.2f} {res['std_reward']:>8.2f} {res['win_rate']:>10.1%}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d04d37e0",
"metadata": {},
"outputs": [],
"source": [
"def create_tournament_env():\n",
"    \"\"\"Create the two-player PettingZoo Tennis env with grayscale, 84x84, 4-frame-stack preprocessing.\"\"\"\n",
"    env = tennis_v3.env(obs_type=\"rgb_image\")\n",
"    env = ss.color_reduction_v0(env, mode=\"full\")\n",
"    env = ss.resize_v1(env, x_size=84, y_size=84)\n",
"    return ss.frame_stack_v1(env, 4)\n",
"\n",
"\n",
"def run_pz_match(\n",
"    env,\n",
"    agent_first: Agent,\n",
"    agent_second: Agent,\n",
"    episodes: int = 10,\n",
"    max_steps: int = 4000,\n",
") -> dict[str, int]:\n",
"    \"\"\"Run multiple PettingZoo episodes between two agents.\n",
"\n",
"    `max_steps` counts AEC steps across both players and truncates long episodes.\n",
"    Returns wins per seat label: {'first': ..., 'second': ..., 'draw': ...}.\n",
"    \"\"\"\n",
"    wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n",
"\n",
"    for _ep in range(episodes):\n",
"        env.reset()\n",
"        rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n",
"\n",
"        for step_idx, agent_id in enumerate(env.agent_iter()):\n",
"            obs, reward, termination, truncation, _info = env.last()\n",
"            done = termination or truncation\n",
"            rewards[agent_id] += float(reward)\n",
"\n",
"            if done:\n",
"                action = None  # the AEC API expects None for terminated/truncated agents\n",
"            else:\n",
"                current_agent = agent_first if agent_id == \"first_0\" else agent_second\n",
"                # supersuit's frame_stack_v1 stacks on the last axis -> (84, 84, 4);\n",
"                # move it first to match the (4, 84, 84) layout seen during training.\n",
"                state = np.moveaxis(np.asarray(obs), -1, 0)\n",
"                action = current_agent.get_action(state, epsilon=0.0)\n",
"\n",
"            env.step(action)\n",
"\n",
"            if step_idx + 1 >= max_steps:\n",
"                break\n",
"\n",
"        if rewards[\"first_0\"] > rewards[\"second_0\"]:\n",
"            wins[\"first\"] += 1\n",
"        elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n",
"            wins[\"second\"] += 1\n",
"        else:\n",
"            wins[\"draw\"] += 1\n",
"\n",
"    return wins\n",
"\n",
"\n",
"def run_pettingzoo_tournament(\n",
"    agents: dict[str, Agent],\n",
"    episodes_per_side: int = 10,\n",
") -> tuple[np.ndarray, list[str]]:\n",
"    \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n",
"    candidate_names = [name for name in agents if name != \"Random\"]\n",
"\n",
"    # Keep only agents that have a checkpoint\n",
"    ready_names: list[str] = []\n",
"    for name in candidate_names:\n",
"        checkpoint_path = get_path(name)\n",
"        if checkpoint_path.exists():\n",
"            agents[name].load(str(checkpoint_path))\n",
"            ready_names.append(name)\n",
"\n",
"    if len(ready_names) < 2:\n",
"        msg = \"Need at least 2 trained (checkpointed) non-random agents for PettingZoo tournament.\"\n",
"        raise RuntimeError(msg)\n",
"\n",
"    n = len(ready_names)\n",
"    win_matrix = np.full((n, n), np.nan)\n",
"    np.fill_diagonal(win_matrix, 0.5)\n",
"\n",
"    # Every unordered pair (i < j) plays exactly once\n",
"    for i, j in itertools.combinations(range(n), 2):\n",
"        name_i = ready_names[i]\n",
"        name_j = ready_names[j]\n",
"\n",
"        print(f\"Matchup: {name_i} vs {name_j}\")\n",
"        env = create_tournament_env()\n",
"\n",
"        # Leg 1: i as first_0, j as second_0\n",
"        leg1 = run_pz_match(\n",
"            env,\n",
"            agent_first=agents[name_i],\n",
"            agent_second=agents[name_j],\n",
"            episodes=episodes_per_side,\n",
"        )\n",
"\n",
"        # Leg 2: swap seats\n",
"        leg2 = run_pz_match(\n",
"            env,\n",
"            agent_first=agents[name_j],\n",
"            agent_second=agents[name_i],\n",
"            episodes=episodes_per_side,\n",
"        )\n",
"\n",
"        env.close()\n",
"\n",
"        # Aggregate across both seats; draws are left out of the win rate\n",
"        wins_i = leg1[\"first\"] + leg2[\"second\"]\n",
"        wins_j = leg1[\"second\"] + leg2[\"first\"]\n",
"\n",
"        decisive = wins_i + wins_j\n",
"        if decisive == 0:\n",
"            wr_i = 0.5\n",
"            wr_j = 0.5\n",
"        else:\n",
"            wr_i = wins_i / decisive\n",
"            wr_j = wins_j / decisive\n",
"\n",
"        win_matrix[i, j] = wr_i\n",
"        win_matrix[j, i] = wr_j\n",
"\n",
"        print(f\" -> {name_i}: {wins_i} wins | {name_j}: {wins_j} wins\\n\")\n",
"\n",
"    return win_matrix, ready_names\n",
"\n",
"\n",
"# Run tournament (non-random agents only)\n",
"win_matrix_pz, pz_names = run_pettingzoo_tournament(\n",
"    agents=agents,\n",
"    episodes_per_side=10,\n",
")\n",
"\n",
"# Plot win-rate matrix\n",
"plt.figure(figsize=(8, 6))\n",
"sns.heatmap(\n",
"    win_matrix_pz,\n",
"    annot=True,\n",
"    fmt=\".2f\",\n",
"    cmap=\"Blues\",\n",
"    vmin=0.0,\n",
"    vmax=1.0,\n",
"    xticklabels=pz_names,\n",
"    yticklabels=pz_names,\n",
")\n",
"plt.xlabel(\"Opponent\")\n",
"plt.ylabel(\"Agent\")\n",
"plt.title(\"PettingZoo Tournament Win Rate Matrix (Non-random agents)\")\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Rank agents by mean win rate vs others (excluding diagonal)\n",
"scores = {}\n",
"for idx, name in enumerate(pz_names):\n",
"    row = np.delete(win_matrix_pz[idx], idx)\n",
"    scores[name] = float(np.mean(row))\n",
"\n",
"ranking = sorted(scores.items(), key=lambda x: x[1], reverse=True)\n",
"print(\"Final ranking (PettingZoo tournament, non-random):\")\n",
"for rank_idx, (name, score) in enumerate(ranking, start=1):\n",
"    print(f\"{rank_idx}. {name:<12} | mean win rate: {score:.3f}\")\n",
"\n",
"print(f\"\\nBest agent: {ranking[0][0]}\")\n"
]
},
{
"cell_type": "markdown",
"id": "3f8b300d",
"metadata": {},
"source": [
"## PettingZoo Tournament (Agents vs Agents)\n",
"\n",
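"This tournament uses PettingZoo's two-player Tennis environment (`from pettingzoo.atari import tennis_v3`) so that the trained agents play directly against each other.\n",
"\n",
"- Checkpoints are loaded from `checkpoints/` (`.pkl`)\n",
"- `Random` is **excluded** from both the tournament and the ranking\n",
"- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
"- Win rates are computed over decisive games only: draws are ignored, and a pair with no decisive game scores 0.5\n",
"- A win-rate matrix and a final ranking by mean win rate are produced"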
]
},
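{
"cell_type": "markdown",
"id": "c4a9d0e7",
"metadata": {},
"source": [
"The seat-swap aggregation and the ranking can be isolated into two small pure functions, which makes them checkable without playing any Atari episode. This is a sketch: `seat_swapped_win_rates` and `rank_by_mean_win_rate` are hypothetical helpers that mirror the logic of `run_pettingzoo_tournament` and the ranking loop (draws are ignored, and a pair with no decisive game gets 0.5), and the leg results below are made-up numbers, not tournament results.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def seat_swapped_win_rates(leg1: dict[str, int], leg2: dict[str, int]) -> tuple[float, float]:\n",
"    \"\"\"Win rates of agents i and j over two legs with swapped seats.\n",
"\n",
"    In leg1, i plays first_0 and j plays second_0; in leg2 the seats are swapped.\n",
"    Draws are ignored; if no game is decisive, both agents get 0.5.\n",
"    \"\"\"\n",
"    wins_i = leg1[\"first\"] + leg2[\"second\"]\n",
"    wins_j = leg1[\"second\"] + leg2[\"first\"]\n",
"    decisive = wins_i + wins_j\n",
"    if decisive == 0:\n",
"        return 0.5, 0.5\n",
"    return wins_i / decisive, wins_j / decisive\n",
"\n",
"\n",
"def rank_by_mean_win_rate(win_matrix: np.ndarray, names: list[str]) -> list[tuple[str, float]]:\n",
"    \"\"\"Average each row without its diagonal entry and sort agents best-first.\"\"\"\n",
"    scores = {\n",
"        name: float(np.mean(np.delete(win_matrix[k], k)))\n",
"        for k, name in enumerate(names)\n",
"    }\n",
"    return sorted(scores.items(), key=lambda item: item[1], reverse=True)\n",
"\n",
"\n",
"# Made-up leg results for two agents A and B (illustration only)\n",
"wr_a, wr_b = seat_swapped_win_rates(\n",
"    {\"first\": 6, \"second\": 2, \"draw\": 2},  # A seated first_0\n",
"    {\"first\": 3, \"second\": 5, \"draw\": 2},  # B seated first_0\n",
")\n",
"matrix = np.array([[0.5, wr_a], [wr_b, 0.5]])\n",
"print(wr_a, wr_b)  # 0.6875 0.3125\n",
"print(rank_by_mean_win_rate(matrix, [\"A\", \"B\"]))  # [('A', 0.6875), ('B', 0.3125)]\n",
"```\n",
"\n",
"A wins 6 + 5 = 11 of the 16 decisive games and B wins 2 + 3 = 5; the 4 draws do not count. Because both agents play the same number of episodes in each seat, any advantage tied to `first_0` or `second_0` affects them equally."
]
},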
{
"cell_type": "markdown",
"id": "150e6764",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "studies (3.13.9)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}