{ "cells": [ { "cell_type": "markdown", "id": "fef45687", "metadata": {}, "source": [ "# RL Project: Atari Tennis Tournament\n", "\n", "This notebook implements three Reinforcement Learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n", "\n", "1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n", "2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n", "3. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n", "\n", "Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament." ] }, { "cell_type": "code", "execution_count": 1, "id": "b50d7174", "metadata": {}, "outputs": [], "source": [ "import itertools\n", "import pickle\n", "from pathlib import Path\n", "\n", "import ale_py # noqa: F401 — registers ALE environments\n", "import gymnasium as gym\n", "import supersuit as ss\n", "from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n", "from pettingzoo.atari import tennis_v3\n", "from tqdm.auto import tqdm\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import seaborn as sns\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "ff3486a4", "metadata": {}, "outputs": [], "source": [ "CHECKPOINT_DIR = Path(\"checkpoints\")\n", "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n", "\n", "\n", "def get_path(name: str) -> Path:\n", " \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n", " base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n", " return CHECKPOINT_DIR / (base + \".pkl\")\n" ] }, { "cell_type": "markdown", "id": "ec691487", "metadata": {}, "source": [ "# Utility Functions\n", "\n", "## Observation Normalization\n", "\n", "The Tennis environment produces image observations of shape `(4, 84, 84)` after preprocessing (grayscale + resize + frame stack).\n", "We normalize them into 1D 
`float64` vectors divided by 255, as in Lab 7 (continuous feature normalization).\n", "\n", "## ε-greedy Policy\n", "\n", "Follows the pattern from Lab 5B (`epsilon_greedy`) and Lab 7 (`epsilon_greedy_action`):\n", "- With probability ε: random action (exploration)\n", "- With probability 1−ε: action maximizing $\\hat{q}(s, a)$ with uniform tie-breaking (`np.flatnonzero`)" ] }, { "cell_type": "code", "execution_count": 3, "id": "be85c130", "metadata": {}, "outputs": [], "source": [ "def normalize_obs(observation: np.ndarray) -> np.ndarray:\n", " \"\"\"Flatten and normalize an observation to a 1D float64 vector.\n", "\n", " Replicates the /255.0 normalization used in all agents from the original project.\n", " For image observations of shape (4, 84, 84), this produces a vector of length 28_224.\n", "\n", " Args:\n", " observation: Raw observation array from the environment.\n", "\n", " Returns:\n", " 1D numpy array of dtype float64, values in [0, 1].\n", "\n", " \"\"\"\n", " return observation.flatten().astype(np.float64) / 255.0\n", "\n", "\n", "def epsilon_greedy(\n", " q_values: np.ndarray,\n", " epsilon: float,\n", " rng: np.random.Generator,\n", ") -> int:\n", " \"\"\"Select an action using an ε-greedy policy with fair tie-breaking.\n", "\n", " Follows the same logic as Lab 5B epsilon_greedy and Lab 7 epsilon_greedy_action:\n", " - With probability epsilon: choose a random action (exploration).\n", " - With probability 1-epsilon: choose the action with highest Q-value (exploitation).\n", " - If multiple actions share the maximum Q-value, break ties uniformly at random.\n", "\n", " Handles edge cases: empty q_values, NaN/Inf values.\n", "\n", " Args:\n", " q_values: Array of Q-values for each action, shape (n_actions,).\n", " epsilon: Exploration probability in [0, 1].\n", " rng: NumPy random number generator.\n", "\n", " Returns:\n", " Selected action index.\n", "\n", " \"\"\"\n", " q_values = np.asarray(q_values, dtype=np.float64).reshape(-1)\n", "\n", " if 
q_values.size == 0:\n", " msg = \"q_values is empty.\"\n", " raise ValueError(msg)\n", "\n", " if rng.random() < epsilon:\n", " return int(rng.integers(0, q_values.size))\n", "\n", " # Non-finite Q-values: act randomly if none are finite, otherwise exclude them from the argmax\n", " finite_mask = np.isfinite(q_values)\n", " if not np.any(finite_mask):\n", " return int(rng.integers(0, q_values.size))\n", "\n", " safe_q = q_values.copy()\n", " safe_q[~finite_mask] = -np.inf\n", " max_val = np.max(safe_q)\n", " best = np.flatnonzero(safe_q == max_val)\n", "\n", " if best.size == 0:\n", " return int(rng.integers(0, q_values.size))\n", "\n", " return int(rng.choice(best))\n" ] }, { "cell_type": "markdown", "id": "bb53da28", "metadata": {}, "source": [ "# Agent Definitions\n", "\n", "## Base Class `Agent`\n", "\n", "All agents share one interface with identical signatures: `get_action`, `update`, `save`, `load`.\n", "Serialization uses `pickle` (compatible with NumPy arrays)." ] }, { "cell_type": "code", "execution_count": 4, "id": "ded9b1fb", "metadata": {}, "outputs": [], "source": [ "class Agent:\n", " \"\"\"Base class for reinforcement learning agents.\n", "\n", " All agents share this interface so they are compatible with the tournament system.\n", " \"\"\"\n", "\n", " def __init__(self, seed: int, action_space: int) -> None:\n", " \"\"\"Initialize the agent with its action space and a reproducible RNG.\"\"\"\n", " self.action_space = action_space\n", " self.rng = np.random.default_rng(seed=seed)\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", " \"\"\"Select an action from the current observation.\"\"\"\n", " raise NotImplementedError\n", "\n", " def update(\n", " self,\n", " state: np.ndarray,\n", " action: int,\n", " reward: float,\n", " next_state: np.ndarray,\n", " done: bool,\n", " next_action: int | None = None,\n", " ) -> None:\n", " \"\"\"Update agent parameters from one transition.\"\"\"\n", "\n", " def save(self, filename: str) -> None:\n", " \"\"\"Save the agent state to disk using 
pickle.\"\"\"\n", " with Path(filename).open(\"wb\") as f:\n", " pickle.dump(self.__dict__, f)\n", "\n", " def load(self, filename: str) -> None:\n", " \"\"\"Load the agent state from disk.\"\"\"\n", " with Path(filename).open(\"rb\") as f:\n", " self.__dict__.update(pickle.load(f)) # noqa: S301\n" ] }, { "cell_type": "markdown", "id": "8a4eae79", "metadata": {}, "source": [ "## Random Agent (baseline)\n", "\n", "Serves as a reference to evaluate the performance of learning agents." ] }, { "cell_type": "code", "execution_count": 5, "id": "78bdc9d2", "metadata": {}, "outputs": [], "source": [ "class RandomAgent(Agent):\n", " \"\"\"A simple agent that selects actions uniformly at random (baseline).\"\"\"\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", " \"\"\"Select a random action, ignoring the observation and epsilon.\"\"\"\n", " _ = observation, epsilon\n", " return int(self.rng.integers(0, self.action_space))\n" ] }, { "cell_type": "markdown", "id": "5f679032", "metadata": {}, "source": [ "## SARSA Agent — Linear Approximation (Semi-gradient)\n", "\n", "This agent combines:\n", "- **Linear approximation** from Lab 7 (`SarsaAgent`): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$\n", "- **On-policy SARSA update** from Lab 5B (`train_sarsa`): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n", "\n", "The semi-gradient update rule is:\n", "$$W_a \\leftarrow W_a + \\alpha \\cdot \\delta \\cdot \\phi(s)$$\n", "\n", "where $\\phi(s)$ is the normalized observation vector (analogous to tile coding features in Lab 7, but in dense form)." ] }, { "cell_type": "code", "execution_count": 6, "id": "c124ed9a", "metadata": {}, "outputs": [], "source": [ "class SarsaAgent(Agent):\n", " \"\"\"Semi-gradient SARSA agent with linear function approximation.\n", "\n", " Inspired by:\n", " - Lab 7 SarsaAgent: linear q(s,a) = W_a . 
phi(s), semi-gradient update\n", " - Lab 5B train_sarsa: on-policy TD target using Q(s', a')\n", "\n", " The weight matrix W has shape (n_actions, n_features).\n", " For a given state s, q(s, a) = W[a] @ phi(s) is the dot product\n", " of the action's weight row with the normalized observation.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " n_features: int,\n", " n_actions: int,\n", " alpha: float = 0.001,\n", " gamma: float = 0.99,\n", " seed: int = 42,\n", " ) -> None:\n", " \"\"\"Initialize SARSA agent with linear weights.\n", "\n", " Args:\n", " n_features: Dimension of the feature vector phi(s).\n", " n_actions: Number of discrete actions.\n", " alpha: Learning rate (kept small for high-dim features).\n", " gamma: Discount factor.\n", " seed: RNG seed for reproducibility.\n", "\n", " \"\"\"\n", " super().__init__(seed, n_actions)\n", " self.n_features = n_features\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " # Weight matrix: one row per action, analogous to Lab 7's self.w\n", " # but organized as (n_actions, n_features) for dense features.\n", " self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n", "\n", " def _q_values(self, phi: np.ndarray) -> np.ndarray:\n", " \"\"\"Compute Q-values for all actions given feature vector phi(s).\n", "\n", " Equivalent to Lab 7's self.q(s, a) = self.w[idx].sum()\n", " but using dense linear approximation: q(s, a) = W[a] @ phi.\n", "\n", " Args:\n", " phi: Normalized feature vector, shape (n_features,).\n", "\n", " Returns:\n", " Array of Q-values, shape (n_actions,).\n", "\n", " \"\"\"\n", " return self.W @ phi # shape (n_actions,)\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", " \"\"\"Select action using ε-greedy policy over linear Q-values.\n", "\n", " Same pattern as Lab 7 SarsaAgent.eps_greedy:\n", " compute q-values for all actions, then apply epsilon_greedy.\n", " \"\"\"\n", " phi = normalize_obs(observation)\n", " q_vals = 
self._q_values(phi)\n", " return epsilon_greedy(q_vals, epsilon, self.rng)\n", "\n", " def update(\n", " self,\n", " state: np.ndarray,\n", " action: int,\n", " reward: float,\n", " next_state: np.ndarray,\n", " done: bool,\n", " next_action: int | None = None,\n", " ) -> None:\n", " \"\"\"Perform one semi-gradient SARSA update.\n", "\n", " Follows the SARSA update from Lab 5B train_sarsa:\n", " td_target = r + gamma * Q(s', a') * (0 if done else 1)\n", " Q(s, a) += alpha * (td_target - Q(s, a))\n", "\n", " In continuous form with linear approximation (Lab 7 SarsaAgent.update):\n", " delta = target - q(s, a)\n", " W[a] += alpha * delta * phi(s)\n", "\n", " Args:\n", " state: Current observation.\n", " action: Action taken.\n", " reward: Reward received.\n", " next_state: Next observation.\n", " done: Whether the episode ended.\n", " next_action: Action chosen in next state (required for SARSA).\n", "\n", " \"\"\"\n", " phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n", " q_sa = float(self.W[action] @ phi) # current estimate q(s, a)\n", " if not np.isfinite(q_sa):\n", " q_sa = 0.0\n", "\n", " if done:\n", " # Terminal: no future value (Lab 5B: gamma * Q[s2, a2] * 0)\n", " target = reward\n", " else:\n", " # On-policy: use q(s', a') where a' is the actual next action\n", " # This is the key SARSA property (Lab 5B)\n", " phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n", " if next_action is None:\n", " next_action = 0 # fallback, should not happen in practice\n", " q_sp_ap = float(self.W[next_action] @ phi_next)\n", " if not np.isfinite(q_sp_ap):\n", " q_sp_ap = 0.0\n", " target = float(reward) + self.gamma * q_sp_ap\n", "\n", " # Semi-gradient update: W[a] += alpha * delta * phi(s)\n", " # Analogous to Lab 7: self.w[idx] += self.alpha * delta\n", " if not np.isfinite(target):\n", " return\n", "\n", " delta = float(target - q_sa)\n", " if not np.isfinite(delta):\n", " return\n", "\n", " td_step = 
float(np.clip(delta, -1_000.0, 1_000.0))\n", " self.W[action] += self.alpha * td_step * phi\n", " self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n" ] }, { "cell_type": "markdown", "id": "d4e18536", "metadata": {}, "source": [ "## Q-Learning Agent — Linear Approximation (Off-policy)\n", "\n", "Same architecture as SARSA but with the **off-policy update** from Lab 5B (`train_q_learning`):\n", "\n", "$$\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$$\n", "\n", "The key difference from SARSA: we use $\\max_{a'} Q(s', a')$ instead of $Q(s', a')$ where $a'$ is the action actually chosen. This allows learning the optimal policy independently of the exploration policy." ] }, { "cell_type": "code", "execution_count": 7, "id": "f5b5b9ea", "metadata": {}, "outputs": [], "source": [ "class QLearningAgent(Agent):\n", " \"\"\"Q-Learning agent with linear function approximation (off-policy).\n", "\n", " Inspired by:\n", " - Lab 5B train_q_learning: off-policy TD target using max_a' Q(s', a')\n", " - Lab 7 SarsaAgent: linear approximation q(s,a) = W[a] @ phi(s)\n", "\n", " The only difference from SarsaAgent is the TD target:\n", " SARSA uses Q(s', a') (on-policy), Q-Learning uses max_a' Q(s', a') (off-policy).\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " n_features: int,\n", " n_actions: int,\n", " alpha: float = 0.001,\n", " gamma: float = 0.99,\n", " seed: int = 42,\n", " ) -> None:\n", " \"\"\"Initialize Q-Learning agent with linear weights.\n", "\n", " Args:\n", " n_features: Dimension of the feature vector phi(s).\n", " n_actions: Number of discrete actions.\n", " alpha: Learning rate.\n", " gamma: Discount factor.\n", " seed: RNG seed.\n", "\n", " \"\"\"\n", " super().__init__(seed, n_actions)\n", " self.n_features = n_features\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n", "\n", " def _q_values(self, phi: np.ndarray) -> np.ndarray:\n", 
" \"\"\"Compute Q-values for all actions: q(s, a) = W[a] @ phi for each a.\"\"\"\n", " return self.W @ phi\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", " \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n", " phi = normalize_obs(observation)\n", " q_vals = self._q_values(phi)\n", " return epsilon_greedy(q_vals, epsilon, self.rng)\n", "\n", " def update(\n", " self,\n", " state: np.ndarray,\n", " action: int,\n", " reward: float,\n", " next_state: np.ndarray,\n", " done: bool,\n", " next_action: int | None = None,\n", " ) -> None:\n", " \"\"\"Perform one Q-learning update.\n", "\n", " Follows Lab 5B train_q_learning:\n", " td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n", " Q[s, a] += alpha * (td_target - Q[s, a])\n", "\n", " In continuous form with linear approximation:\n", " delta = target - q(s, a)\n", " W[a] += alpha * delta * phi(s)\n", " \"\"\"\n", " _ = next_action # Q-learning is off-policy: next_action is not used\n", " phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n", " q_sa = float(self.W[action] @ phi)\n", " if not np.isfinite(q_sa):\n", " q_sa = 0.0\n", "\n", " if done:\n", " # Terminal state: no future value\n", " # Lab 5B: gamma * np.max(Q[s2]) * (0 if terminated else 1)\n", " target = reward\n", " else:\n", " # Off-policy: use max over all actions in next state\n", " # This is the key Q-learning property (Lab 5B)\n", " phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n", " q_next_all = self._q_values(phi_next) # q(s', a') for all a'\n", " q_next_max = float(np.max(q_next_all))\n", " if not np.isfinite(q_next_max):\n", " q_next_max = 0.0\n", " target = float(reward) + self.gamma * q_next_max\n", "\n", " if not np.isfinite(target):\n", " return\n", "\n", " delta = float(target - q_sa)\n", " if not np.isfinite(delta):\n", " return\n", "\n", " td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n", " 
self.W[action] += self.alpha * td_step * phi\n", " self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n" ] }, { "cell_type": "markdown", "id": "b7b63455", "metadata": {}, "source": [ "## Monte Carlo Agent — Linear Approximation (First-visit)\n", "\n", "This agent is inspired by Lab 4 (`mc_control_epsilon_soft`):\n", "- Accumulates transitions in an episode buffer `(state, action, reward)`\n", "- At the end of the episode (`done=True`), computes **discounted returns** by traversing the buffer backward:\n", " $$G \\leftarrow \\gamma \\cdot G + r$$\n", "- Updates weights with the semi-gradient rule:\n", " $$W_a \\leftarrow W_a + \\alpha \\cdot (G - \\hat{q}(s, a)) \\cdot \\phi(s)$$\n", "\n", "Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating." ] }, { "cell_type": "code", "execution_count": 8, "id": "3c9d74be", "metadata": {}, "outputs": [], "source": [ "class MonteCarloAgent(Agent):\n", " \"\"\"Monte Carlo control agent with linear function approximation.\n", "\n", " Inspired by Lab 4 mc_control_epsilon_soft:\n", " - Accumulates transitions in an episode buffer\n", " - At episode end (done=True), computes discounted returns backward:\n", " G = gamma * G + r (same as Lab 4's reversed loop)\n", " - Updates weights with semi-gradient: W[a] += alpha * (G - q(s,a)) * phi(s)\n", "\n", " Unlike TD methods (SARSA, Q-Learning), no update occurs until the episode ends.\n", "\n", " Performance optimizations over naive per-step implementation:\n", " - float32 weights & features (half the memory traffic of float64)\n", " - Raw observations buffered as received (uint8 frames), batch-normalized at episode end\n", " - Returns computed in one backward pass; weight updates batched per chunk and per action (einsum for q(s, a))\n", " - Single weight sanitization per episode instead of per-step\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " n_features: int,\n", " n_actions: int,\n", " alpha: float = 0.001,\n", " gamma: float = 0.99,\n", 
" seed: int = 42,\n", " ) -> None:\n", " super().__init__(seed, n_actions)\n", " self.n_features = n_features\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " self.W = np.zeros((n_actions, n_features), dtype=np.float32)\n", " self._obs_buf: list[np.ndarray] = []\n", " self._act_buf: list[int] = []\n", " self._rew_buf: list[float] = []\n", "\n", " def _q_values(self, phi: np.ndarray) -> np.ndarray:\n", " return self.W @ phi\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", " phi = observation.flatten().astype(np.float32) / np.float32(255.0)\n", " q_vals = self._q_values(phi)\n", " return epsilon_greedy(q_vals, epsilon, self.rng)\n", "\n", " def update(\n", " self,\n", " state: np.ndarray,\n", " action: int,\n", " reward: float,\n", " next_state: np.ndarray,\n", " done: bool,\n", " next_action: int | None = None,\n", " ) -> None:\n", " _ = next_state, next_action\n", "\n", " self._obs_buf.append(state)\n", " self._act_buf.append(action)\n", " self._rew_buf.append(reward)\n", "\n", " if not done:\n", " return\n", "\n", " n = len(self._rew_buf)\n", " actions = np.array(self._act_buf, dtype=np.intp)\n", "\n", " returns = np.empty(n, dtype=np.float32)\n", " G = np.float32(0.0)\n", " gamma32 = np.float32(self.gamma)\n", " for i in range(n - 1, -1, -1):\n", " G = gamma32 * G + np.float32(self._rew_buf[i])\n", " returns[i] = G\n", "\n", " alpha32 = np.float32(self.alpha)\n", " chunk_size = 500\n", " for start in range(0, n, chunk_size):\n", " end = min(start + chunk_size, n)\n", " cs = end - start\n", "\n", " raw = np.array(self._obs_buf[start:end])\n", " phi = raw.reshape(cs, -1).astype(np.float32)\n", " phi /= np.float32(255.0)\n", "\n", " ca = actions[start:end]\n", " q_sa = np.einsum(\"ij,ij->i\", self.W[ca], phi)\n", "\n", " deltas = np.clip(returns[start:end] - q_sa, -1000.0, 1000.0)\n", "\n", " for a in range(self.action_space):\n", " mask = ca == a\n", " if not np.any(mask):\n", " continue\n", " self.W[a] += alpha32 * 
(deltas[mask] @ phi[mask])\n", "\n", " self.W = np.nan_to_num(self.W, nan=0.0, posinf=1e6, neginf=-1e6)\n", "\n", " self._obs_buf.clear()\n", " self._act_buf.clear()\n", " self._rew_buf.clear()\n" ] }, { "cell_type": "markdown", "id": "91e51dc8", "metadata": {}, "source": [ "## Tennis Environment\n", "\n", "Creation of the Atari Tennis environment via Gymnasium (`ALE/Tennis-v5`) with standard wrappers:\n", "- **Grayscale**: `obs_type=\"grayscale\"` — single-channel observations\n", "- **Resize**: `ResizeObservation(84, 84)` — downscale to 84×84\n", "- **Frame stack**: `FrameStackObservation(4)` — stack 4 consecutive frames\n", "\n", "The final observation is an array of shape `(4, 84, 84)`, which flattens to 28,224 features.\n", "\n", "The agent plays against the **built-in Atari AI opponent**." ] }, { "cell_type": "code", "execution_count": 9, "id": "f9a973dd", "metadata": {}, "outputs": [], "source": [ "def create_env() -> gym.Env:\n", " \"\"\"Create the ALE/Tennis-v5 environment with preprocessing wrappers.\n", "\n", " Applies:\n", " - obs_type=\"grayscale\": grayscale observation (210, 160)\n", " - ResizeObservation(84, 84): downscale to 84x84\n", " - FrameStackObservation(4): stack 4 consecutive frames -> (4, 84, 84)\n", "\n", " Returns:\n", " Gymnasium environment ready for training.\n", "\n", " \"\"\"\n", " env = gym.make(\"ALE/Tennis-v5\", obs_type=\"grayscale\")\n", " env = ResizeObservation(env, shape=(84, 84))\n", " return FrameStackObservation(env, stack_size=4)\n" ] }, { "cell_type": "markdown", "id": "18cb28d8", "metadata": {}, "source": [ "## Training & Evaluation Infrastructure\n", "\n", "Functions for training and evaluating agents in the single-agent Gymnasium environment:\n", "\n", "1. **`train_agent`** — Pre-trains an agent against the built-in AI for a given number of episodes with ε-greedy exploration\n", "2. **`evaluate_agent`** — Evaluates a trained agent (no exploration, ε = 0) and returns performance metrics\n", "3. 
**`plot_training_curves`** — Plots the training reward history (moving average) for all agents\n", "4. **`plot_evaluation_comparison`** — Bar chart comparing final evaluation scores across agents\n", "5. **`evaluate_tournament`** — Evaluates all agents and produces a summary comparison" ] }, { "cell_type": "code", "execution_count": 10, "id": "06b91580", "metadata": {}, "outputs": [], "source": [ "def train_agent(\n", " env: gym.Env,\n", " agent: Agent,\n", " name: str,\n", " *,\n", " episodes: int = 5000,\n", " epsilon_start: float = 1.0,\n", " epsilon_end: float = 0.05,\n", " epsilon_decay: float = 0.999,\n", " max_steps: int = 5000,\n", ") -> list[float]:\n", " \"\"\"Pre-train an agent against the built-in Atari AI opponent.\n", "\n", " Each agent learns independently by playing episodes of at most\n", " `max_steps` steps. This is the pre-training phase (not self-play): the agent\n", " faces the environment's built-in opponent, and `agent.update` is called\n", " after each transition.\n", "\n", " Args:\n", " env: Gymnasium ALE/Tennis-v5 environment.\n", " agent: Agent instance to train.\n", " name: Display name for the progress bar.\n", " episodes: Number of training episodes.\n", " epsilon_start: Initial exploration rate.\n", " epsilon_end: Minimum exploration rate.\n", " epsilon_decay: Multiplicative decay per episode.\n", " max_steps: Maximum steps per episode; reaching the cap ends the episode.\n", "\n", " Returns:\n", " List of total rewards per episode.\n", "\n", " \"\"\"\n", " rewards_history: list[float] = []\n", " epsilon = epsilon_start\n", "\n", " pbar = tqdm(range(episodes), desc=f\"Training {name}\", leave=True)\n", "\n", " for _ep in pbar:\n", " obs, _info = env.reset()\n", " obs = np.asarray(obs)\n", " total_reward = 0.0\n", "\n", " # Select first action\n", " action = agent.get_action(obs, epsilon=epsilon)\n", "\n", " for _step in range(max_steps):\n", " next_obs, reward, terminated, truncated, _info = env.step(action)\n", " next_obs = np.asarray(next_obs)\n", " # Treat the step cap as an episode end so the Monte Carlo buffer is flushed\n", " done = terminated or truncated or _step == max_steps - 1\n", " reward = 
float(reward)\n", " total_reward += reward\n", "\n", " # Select next action (needed for SARSA's on-policy update)\n", " next_action = agent.get_action(next_obs, epsilon=epsilon) if not done else None\n", "\n", " # Update agent with the transition\n", " agent.update(\n", " state=obs,\n", " action=action,\n", " reward=reward,\n", " next_state=next_obs,\n", " done=done,\n", " next_action=next_action,\n", " )\n", "\n", " if done:\n", " break\n", "\n", " obs = next_obs\n", " action = next_action\n", "\n", " rewards_history.append(total_reward)\n", " epsilon = max(epsilon_end, epsilon * epsilon_decay)\n", "\n", " # Update progress bar\n", " recent_window = 50\n", " if len(rewards_history) >= recent_window:\n", " recent_avg = np.mean(rewards_history[-recent_window:])\n", " pbar.set_postfix(\n", " avg50=f\"{recent_avg:.1f}\",\n", " eps=f\"{epsilon:.3f}\",\n", " rew=f\"{total_reward:.0f}\",\n", " )\n", "\n", " return rewards_history\n", "\n", "\n", "def evaluate_agent(\n", " env: gym.Env,\n", " agent: Agent,\n", " name: str,\n", " *,\n", " episodes: int = 20,\n", " max_steps: int = 5000,\n", ") -> dict[str, object]:\n", " \"\"\"Evaluate a trained agent with no exploration (ε = 0).\n", "\n", " Args:\n", " env: Gymnasium ALE/Tennis-v5 environment.\n", " agent: Trained agent to evaluate.\n", " name: Display name for the progress bar.\n", " episodes: Number of evaluation episodes.\n", " max_steps: Maximum steps per episode.\n", "\n", " Returns:\n", " Dictionary with rewards list, mean, std, wins, and win rate.\n", "\n", " \"\"\"\n", " rewards: list[float] = []\n", " wins = 0\n", "\n", " for _ep in tqdm(range(episodes), desc=f\"Evaluating {name}\", leave=False):\n", " obs, _info = env.reset()\n", " total_reward = 0.0\n", "\n", " for _step in range(max_steps):\n", " action = agent.get_action(np.asarray(obs), epsilon=0.0)\n", " obs, reward, terminated, truncated, _info = env.step(action)\n", " reward = float(reward)\n", " total_reward += reward\n", " if terminated or 
truncated:\n", " break\n", "\n", " rewards.append(total_reward)\n", " if total_reward > 0:\n", " wins += 1\n", "\n", " return {\n", " \"rewards\": rewards,\n", " \"mean_reward\": float(np.mean(rewards)),\n", " \"std_reward\": float(np.std(rewards)),\n", " \"wins\": wins,\n", " \"win_rate\": wins / episodes,\n", " }\n", "\n", "\n", "def plot_training_curves(\n", " training_histories: dict[str, list[float]],\n", " path: str,\n", " window: int = 100,\n", ") -> None:\n", " \"\"\"Plot training reward curves for all agents on a single figure.\n", "\n", " Uses a moving average to smooth the curves.\n", "\n", " Args:\n", " training_histories: Dict mapping agent names to reward lists.\n", " path: File path to save the plot image.\n", " window: Moving average window size.\n", "\n", " \"\"\"\n", " plt.figure(figsize=(12, 6))\n", "\n", " for name, rewards in training_histories.items():\n", " if len(rewards) >= window:\n", " ma = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n", " plt.plot(np.arange(window - 1, len(rewards)), ma, label=name)\n", " else:\n", " plt.plot(rewards, label=f\"{name} (raw)\")\n", "\n", " plt.xlabel(\"Episodes\")\n", " plt.ylabel(f\"Average Reward (Window={window})\")\n", " plt.title(\"Training Curves (vs built-in AI)\")\n", " plt.legend()\n", " plt.grid(visible=True)\n", " plt.tight_layout()\n", " plt.savefig(path)\n", " plt.show()\n", "\n", "\n", "def plot_evaluation_comparison(results: dict[str, dict[str, object]]) -> None:\n", " \"\"\"Bar chart comparing evaluation performance of all agents.\n", "\n", " Args:\n", " results: Dict mapping agent names to evaluation result dicts.\n", "\n", " \"\"\"\n", " names = list(results.keys())\n", " means = [results[n][\"mean_reward\"] for n in names]\n", " stds = [results[n][\"std_reward\"] for n in names]\n", " win_rates = [results[n][\"win_rate\"] for n in names]\n", "\n", " _fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", " # Mean reward bar chart\n", " colors = 
sns.color_palette(\"husl\", len(names))\n", " axes[0].bar(names, means, yerr=stds, capsize=5, color=colors, edgecolor=\"black\")\n", " axes[0].set_ylabel(\"Mean Reward\")\n", " axes[0].set_title(\"Evaluation: Mean Reward per Agent (vs built-in AI)\")\n", " axes[0].axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n", " axes[0].grid(axis=\"y\", alpha=0.3)\n", "\n", " # Win rate bar chart\n", " axes[1].bar(names, win_rates, color=colors, edgecolor=\"black\")\n", " axes[1].set_ylabel(\"Win Rate\")\n", " axes[1].set_title(\"Evaluation: Win Rate per Agent (vs built-in AI)\")\n", " axes[1].set_ylim(0, 1)\n", " axes[1].axhline(y=0.5, color=\"gray\", linestyle=\"--\", alpha=0.5, label=\"50% baseline\")\n", " axes[1].legend()\n", " axes[1].grid(axis=\"y\", alpha=0.3)\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", "def evaluate_tournament(\n", " env: gym.Env,\n", " agents: dict[str, Agent],\n", " episodes_per_agent: int = 20,\n", ") -> dict[str, dict[str, object]]:\n", " \"\"\"Evaluate all agents against the built-in AI and produce a comparison.\n", "\n", " Args:\n", " env: Gymnasium ALE/Tennis-v5 environment.\n", " agents: Dictionary mapping agent names to Agent instances.\n", " episodes_per_agent: Number of evaluation episodes per agent.\n", "\n", " Returns:\n", " Dict mapping agent names to their evaluation results.\n", "\n", " \"\"\"\n", " results: dict[str, dict[str, object]] = {}\n", " n_agents = len(agents)\n", "\n", " for idx, (name, agent) in enumerate(agents.items(), start=1):\n", " print(f\"[Evaluation {idx}/{n_agents}] {name}\")\n", " results[name] = evaluate_agent(\n", " env, agent, name, episodes=episodes_per_agent,\n", " )\n", " mean_r = results[name][\"mean_reward\"]\n", " wr = results[name][\"win_rate\"]\n", " print(f\" -> Mean reward: {mean_r:.2f} | Win rate: {wr:.1%}\\n\")\n", "\n", " return results\n" ] }, { "cell_type": "markdown", "id": "9605e9c4", "metadata": {}, "source": [ "## Agent Instantiation & Incremental Training (One 
Agent at a Time)\n", "\n", "**Environment**: `ALE/Tennis-v5` (grayscale, 84×84×4 frames → 28,224 features, 18 actions).\n", "\n", "**Agents**:\n", "- **Random** — random baseline (no training needed)\n", "- **SARSA** — linear approximation, semi-gradient TD(0)\n", "- **Q-Learning** — linear approximation, off-policy\n", "- **Monte Carlo** — linear approximation, first-visit returns\n", "\n", "**Workflow**:\n", "1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n", "2. Save its weights to `checkpoints/` (`.pkl`)\n", "3. Repeat later for another agent without retraining previous ones\n", "4. Load all saved checkpoints before the final evaluation" ] }, { "cell_type": "code", "execution_count": 11, "id": "6f6ba8df", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)\n", "[Powered by Stella]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Observation shape : (4, 84, 84)\n", "Feature vector dim: 28224\n", "Number of actions : 18\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "objc[49878]: Class SDLApplication is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d2c8) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418890). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n", "objc[49878]: Class SDLAppDelegate is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d318) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x1244188e0). This may cause spurious casting failures and mysterious crashes. 
One of the duplicates must be removed or renamed.\n" ] } ], "source": [ "# Create environment\n", "env = create_env()\n", "obs, _info = env.reset()\n", "\n", "n_actions = int(env.action_space.n)\n", "n_features = int(np.prod(obs.shape))\n", "\n", "print(f\"Observation shape : {obs.shape}\")\n", "print(f\"Feature vector dim: {n_features}\")\n", "print(f\"Number of actions : {n_actions}\")\n", "\n", "# Instantiate agents\n", "agent_random = RandomAgent(seed=42, action_space=int(n_actions))\n", "agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n", "agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n", "agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n", "\n", "agents = {\n", "    \"Random\": agent_random,\n", "    \"SARSA\": agent_sarsa,\n", "    \"Q-Learning\": agent_q,\n", "    \"Monte Carlo\": agent_mc,\n", "}\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4d449701", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Selected agent: Monte Carlo\n", "Checkpoint path: checkpoints/monte_carlo.pkl\n", "\n", "============================================================\n", "Training: Monte Carlo (2500 episodes)\n", "============================================================\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0d5d098d18014fe6b736683e0b8b2488", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training Monte Carlo: 0%| | 0/2500 [00:00 weights loaded, training skipped.\")\n", "    training_histories[AGENT_TO_TRAIN] = []\n", "else:\n", "    print(f\"\\n{'='*60}\")\n", "    print(f\"Training: {AGENT_TO_TRAIN} ({TRAINING_EPISODES} episodes)\")\n", "    print(f\"{'='*60}\")\n", "\n", "    training_histories[AGENT_TO_TRAIN] = train_agent(\n", "        env=env,\n", "        agent=agent,\n", "        name=AGENT_TO_TRAIN,\n", "        episodes=TRAINING_EPISODES,\n", "        epsilon_start=1.0,\n", "        epsilon_end=0.05,\n", 
"        epsilon_decay=0.999,\n", "    )\n", "\n", "    avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n", "    print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n", "\n", "    agent.save(str(checkpoint_path))\n", "    print(\"Checkpoint saved.\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a13a65df", "metadata": {}, "outputs": [], "source": [ "plot_training_curves(\n", "    training_histories, f\"plots/{AGENT_TO_TRAIN}_training_curves.png\", window=100,\n", ")\n" ] }, { "cell_type": "markdown", "id": "0e5f2c49", "metadata": {}, "source": [ "## Final Evaluation\n", "\n", "Each agent plays 20 episodes against the built-in AI with no exploration (ε = 0).\n", "Performance is compared via mean reward and win rate." ] }, { "cell_type": "code", "execution_count": null, "id": "70f5d5cd", "metadata": {}, "outputs": [], "source": [ "# Build evaluation set: Random + agents with existing checkpoints\n", "eval_agents: dict[str, Agent] = {\"Random\": agents[\"Random\"]}\n", "missing_agents: list[str] = []\n", "\n", "for name, agent in agents.items():\n", "    if name == \"Random\":\n", "        continue\n", "\n", "    checkpoint_path = get_path(name)\n", "\n", "    if checkpoint_path.exists():\n", "        agent.load(str(checkpoint_path))\n", "        eval_agents[name] = agent\n", "    else:\n", "        missing_agents.append(name)\n", "\n", "print(f\"Agents evaluated: {list(eval_agents.keys())}\")\n", "if missing_agents:\n", "    print(f\"Skipped (no checkpoint yet): {missing_agents}\")\n", "\n", "if len(eval_agents) < 2:\n", "    raise RuntimeError(\"Train at least one non-random agent before final evaluation.\")\n", "\n", "results = evaluate_tournament(env, eval_agents, episodes_per_agent=20)\n", "plot_evaluation_comparison(results)\n", "\n", "# Print summary table\n", "print(f\"\\n{'Agent':<15} {'Mean Reward':>12} {'Std':>8} {'Win Rate':>10}\")\n", "print(\"-\" * 48)\n", "for name, res in results.items():\n", "    print(f\"{name:<15} {res['mean_reward']:>12.2f} 
{res['std_reward']:>8.2f} {res['win_rate']:>9.1%}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d04d37e0", "metadata": {}, "outputs": [], "source": [ "def create_tournament_env():\n", "    \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n", "    env = tennis_v3.env(obs_type=\"rgb_image\")\n", "    env = ss.color_reduction_v0(env, mode=\"full\")\n", "    env = ss.resize_v1(env, x_size=84, y_size=84)\n", "    return ss.frame_stack_v1(env, 4)\n", "\n", "\n", "def run_pz_match(\n", "    env,\n", "    agent_first: Agent,\n", "    agent_second: Agent,\n", "    episodes: int = 10,\n", "    max_steps: int = 4000,\n", ") -> dict[str, int]:\n", "    \"\"\"Run multiple PettingZoo episodes between two agents.\n", "\n", "    Returns wins for global labels {'first': ..., 'second': ..., 'draw': ...}.\n", "    \"\"\"\n", "    wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n", "\n", "    for _ep in range(episodes):\n", "        env.reset()\n", "        rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n", "\n", "        for step_idx, agent_id in enumerate(env.agent_iter()):\n", "            obs, reward, termination, truncation, _info = env.last()\n", "            done = termination or truncation\n", "            rewards[agent_id] += float(reward)\n", "\n", "            if done or step_idx >= max_steps:\n", "                action = None\n", "            else:\n", "                current_agent = agent_first if agent_id == \"first_0\" else agent_second\n", "                action = current_agent.get_action(np.asarray(obs), epsilon=0.0)\n", "\n", "            env.step(action)\n", "\n", "            if step_idx + 1 >= max_steps:\n", "                break\n", "\n", "        if rewards[\"first_0\"] > rewards[\"second_0\"]:\n", "            wins[\"first\"] += 1\n", "        elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n", "            wins[\"second\"] += 1\n", "        else:\n", "            wins[\"draw\"] += 1\n", "\n", "    return wins\n", "\n", "\n", "def run_pettingzoo_tournament(\n", "    agents: dict[str, Agent],\n", "    episodes_per_side: int = 10,\n", ") -> tuple[np.ndarray, list[str]]:\n", "    \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n", 
"    _ = itertools  # kept for notebook context consistency\n", "    candidate_names = [name for name in agents if name != \"Random\"]\n", "\n", "    # Keep only agents that have a checkpoint\n", "    ready_names: list[str] = []\n", "    for name in candidate_names:\n", "        checkpoint_path = get_path(name)\n", "        if checkpoint_path.exists():\n", "            agents[name].load(str(checkpoint_path))\n", "            ready_names.append(name)\n", "\n", "    if len(ready_names) < 2:\n", "        msg = \"Need at least 2 trained (checkpointed) non-random agents for PettingZoo tournament.\"\n", "        raise RuntimeError(msg)\n", "\n", "    n = len(ready_names)\n", "    win_matrix = np.full((n, n), np.nan)\n", "    np.fill_diagonal(win_matrix, 0.5)\n", "\n", "    for i in range(n):\n", "        for j in range(i + 1, n):\n", "            name_i = ready_names[i]\n", "            name_j = ready_names[j]\n", "\n", "            print(f\"Matchup: {name_i} vs {name_j}\")\n", "            env = create_tournament_env()\n", "\n", "            # Leg 1: i as first_0, j as second_0\n", "            leg1 = run_pz_match(\n", "                env,\n", "                agent_first=agents[name_i],\n", "                agent_second=agents[name_j],\n", "                episodes=episodes_per_side,\n", "            )\n", "\n", "            # Leg 2: swap seats\n", "            leg2 = run_pz_match(\n", "                env,\n", "                agent_first=agents[name_j],\n", "                agent_second=agents[name_i],\n", "                episodes=episodes_per_side,\n", "            )\n", "\n", "            env.close()\n", "\n", "            wins_i = leg1[\"first\"] + leg2[\"second\"]\n", "            wins_j = leg1[\"second\"] + leg2[\"first\"]\n", "\n", "            decisive = wins_i + wins_j\n", "            if decisive == 0:\n", "                wr_i = 0.5\n", "                wr_j = 0.5\n", "            else:\n", "                wr_i = wins_i / decisive\n", "                wr_j = wins_j / decisive\n", "\n", "            win_matrix[i, j] = wr_i\n", "            win_matrix[j, i] = wr_j\n", "\n", "            print(f\" -> {name_i}: {wins_i} wins | {name_j}: {wins_j} wins\\n\")\n", "\n", "    return win_matrix, ready_names\n", "\n", "\n", "# Run tournament (non-random agents only)\n", "win_matrix_pz, pz_names = run_pettingzoo_tournament(\n", "    agents=agents,\n", "    episodes_per_side=10,\n", ")\n", "\n", "# Plot win-rate matrix\n", 
"plt.figure(figsize=(8, 6))\n", "sns.heatmap(\n", "    win_matrix_pz,\n", "    annot=True,\n", "    fmt=\".2f\",\n", "    cmap=\"Blues\",\n", "    vmin=0.0,\n", "    vmax=1.0,\n", "    xticklabels=pz_names,\n", "    yticklabels=pz_names,\n", ")\n", "plt.xlabel(\"Opponent\")\n", "plt.ylabel(\"Agent\")\n", "plt.title(\"PettingZoo Tournament Win Rate Matrix (Non-random agents)\")\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Rank agents by mean win rate vs others (excluding diagonal)\n", "scores = {}\n", "for idx, name in enumerate(pz_names):\n", "    row = np.delete(win_matrix_pz[idx], idx)\n", "    scores[name] = float(np.mean(row))\n", "\n", "ranking = sorted(scores.items(), key=lambda x: x[1], reverse=True)\n", "print(\"Final ranking (PettingZoo tournament, non-random):\")\n", "for rank_idx, (name, score) in enumerate(ranking, start=1):\n", "    print(f\"{rank_idx}. {name:<12} | mean win rate: {score:.3f}\")\n", "\n", "print(f\"\\nBest agent: {ranking[0][0]}\")\n" ] }, { "cell_type": "markdown", "id": "3f8b300d", "metadata": {}, "source": [ "## PettingZoo Tournament (Agents vs Agents)\n", "\n", "This tournament uses the PettingZoo `tennis_v3` environment (`from pettingzoo.atari import tennis_v3`) so that trained agents play each other directly.\n", "\n", "- Checkpoints are loaded from `checkpoints/` (`.pkl`)\n", "- `Random` is **excluded** from ranking\n", "- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n", "- Win rates are computed over decisive games only; draws are ignored, and a pair with no decisive game scores 0.5 each\n", "- A win-rate matrix and final ranking are produced" ] }, { "cell_type": "markdown", "id": "150e6764", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "studies (3.13.9)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.9" } }, "nbformat": 4, "nbformat_minor": 5 }