mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-03-16 07:10:13 +01:00
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fef45687",
   "metadata": {},
   "source": [
    "# RL Project: Atari Tennis Tournament\n",
    "\n",
    "This notebook implements four Reinforcement Learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
    "\n",
    "1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
    "2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
    "3. **Monte Carlo** — Monte Carlo control with linear approximation (inspired by Lab 4)\n",
    "4. **DQN** — Deep Q-Network with experience replay and a target network (PyTorch)\n",
    "\n",
    "Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b50d7174",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from collections import deque\n",
    "from pathlib import Path\n",
    "\n",
    "import ale_py  # noqa: F401 — registers ALE environments\n",
    "import gymnasium as gym\n",
    "import supersuit as ss\n",
    "from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n",
    "from pettingzoo.atari import tennis_v3\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "DEVICE = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {DEVICE}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86047166",
   "metadata": {},
   "source": [
    "# Configuration & Checkpoints\n",
    "\n",
    "We use a **checkpoint** system (`pickle` serialization) to save and restore trained agent weights. This enables an incremental workflow:\n",
    "- Train one agent at a time and save its weights\n",
    "- Resume later without retraining previous agents\n",
    "- Load all checkpoints for the final evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "ff3486a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "CHECKPOINT_DIR = Path(\"checkpoints\")\n",
    "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "\n",
    "def get_path(name: str) -> Path:\n",
    "    \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n",
    "    base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n",
    "    return CHECKPOINT_DIR / (base + \".pkl\")\n"
   ]
  },
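  {
   "cell_type": "markdown",
   "id": "demo-get-path",
   "metadata": {},
   "source": [
    "A quick sanity check of `get_path` (illustrative only; the agent name below is arbitrary): the name is lower-cased, spaces and hyphens become underscores, and `.pkl` is appended under `checkpoints/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-get-path-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Q-Learning Agent\" -> checkpoints/q_learning_agent.pkl\n",
    "assert get_path(\"Q-Learning Agent\") == CHECKPOINT_DIR / \"q_learning_agent.pkl\"\n"
   ]
  },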
  {
   "cell_type": "markdown",
   "id": "ec691487",
   "metadata": {},
   "source": [
    "# Utility Functions\n",
    "\n",
    "## Observation Normalization\n",
    "\n",
    "The Tennis environment produces image observations of shape `(4, 84, 84)` after preprocessing (grayscale + resize + frame stack).\n",
    "We flatten them into 1D `float64` vectors and divide by 255, as in Lab 7 (continuous feature normalization).\n",
    "\n",
    "## ε-greedy Policy\n",
    "\n",
    "Follows the pattern from Lab 5B (`epsilon_greedy`) and Lab 7 (`epsilon_greedy_action`):\n",
    "- With probability ε: random action (exploration)\n",
    "- With probability 1−ε: action maximizing $\\hat{q}(s, a)$ with uniform tie-breaking (`np.flatnonzero`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "be85c130",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_obs(observation: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Flatten and normalize an observation to a 1D float64 vector.\n",
    "\n",
    "    Replicates the /255.0 normalization used in all agents from the original project.\n",
    "    For image observations of shape (4, 84, 84), this produces a vector of length 28_224.\n",
    "\n",
    "    Args:\n",
    "        observation: Raw observation array from the environment.\n",
    "\n",
    "    Returns:\n",
    "        1D numpy array of dtype float64, values in [0, 1].\n",
    "\n",
    "    \"\"\"\n",
    "    return observation.flatten().astype(np.float64) / 255.0\n",
    "\n",
    "\n",
    "def epsilon_greedy(\n",
    "    q_values: np.ndarray,\n",
    "    epsilon: float,\n",
    "    rng: np.random.Generator,\n",
    ") -> int:\n",
    "    \"\"\"Select an action using an ε-greedy policy with fair tie-breaking.\n",
    "\n",
    "    Follows the same logic as Lab 5B epsilon_greedy and Lab 7 epsilon_greedy_action:\n",
    "    - With probability epsilon: choose a random action (exploration).\n",
    "    - With probability 1-epsilon: choose the action with highest Q-value (exploitation).\n",
    "    - If multiple actions share the maximum Q-value, break ties uniformly at random.\n",
    "\n",
    "    Handles edge cases: empty q_values, NaN/Inf values.\n",
    "\n",
    "    Args:\n",
    "        q_values: Array of Q-values for each action, shape (n_actions,).\n",
    "        epsilon: Exploration probability in [0, 1].\n",
    "        rng: NumPy random number generator.\n",
    "\n",
    "    Returns:\n",
    "        Selected action index.\n",
    "\n",
    "    \"\"\"\n",
    "    q_values = np.asarray(q_values, dtype=np.float64).reshape(-1)\n",
    "\n",
    "    if q_values.size == 0:\n",
    "        msg = \"q_values is empty.\"\n",
    "        raise ValueError(msg)\n",
    "\n",
    "    if rng.random() < epsilon:\n",
    "        return int(rng.integers(0, q_values.size))\n",
    "\n",
    "    # Handle NaN/Inf values safely\n",
    "    finite_mask = np.isfinite(q_values)\n",
    "    if not np.any(finite_mask):\n",
    "        return int(rng.integers(0, q_values.size))\n",
    "\n",
    "    safe_q = q_values.copy()\n",
    "    safe_q[~finite_mask] = -np.inf\n",
    "    max_val = np.max(safe_q)\n",
    "    best = np.flatnonzero(safe_q == max_val)\n",
    "\n",
    "    if best.size == 0:\n",
    "        return int(rng.integers(0, q_values.size))\n",
    "\n",
    "    return int(rng.choice(best))\n"
   ]
  },
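  {
   "cell_type": "markdown",
   "id": "demo-eps-greedy",
   "metadata": {},
   "source": [
    "A minimal check of the ε-greedy helper (illustrative Q-values, not from the environment): with ε = 0 the non-finite entry is masked out, and only the tied maxima can ever be returned; ties are broken uniformly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-eps-greedy-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng_demo = np.random.default_rng(0)\n",
    "q_demo = np.array([0.0, 1.0, -np.inf, 1.0])\n",
    "picks = {epsilon_greedy(q_demo, 0.0, rng_demo) for _ in range(100)}\n",
    "assert picks <= {1, 3}  # greedy: only the tied maxima (indices 1 and 3) can be chosen\n"
   ]
  },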
  {
   "cell_type": "markdown",
   "id": "bb53da28",
   "metadata": {},
   "source": [
    "# Agent Definitions\n",
    "\n",
    "## Base Class `Agent`\n",
    "\n",
    "Common interface for all agents, same signatures: `get_action`, `update`, `save`, `load`.\n",
    "Serialization uses `pickle` (compatible with numpy arrays)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "ded9b1fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Agent:\n",
    "    \"\"\"Base class for reinforcement learning agents.\n",
    "\n",
    "    All agents share this interface so they are compatible with the tournament system.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, seed: int, action_space: int) -> None:\n",
    "        \"\"\"Initialize the agent with its action space and a reproducible RNG.\"\"\"\n",
    "        self.action_space = action_space\n",
    "        self.rng = np.random.default_rng(seed=seed)\n",
    "\n",
    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
    "        \"\"\"Select an action from the current observation.\"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def update(\n",
    "        self,\n",
    "        state: np.ndarray,\n",
    "        action: int,\n",
    "        reward: float,\n",
    "        next_state: np.ndarray,\n",
    "        done: bool,\n",
    "        next_action: int | None = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Update agent parameters from one transition.\"\"\"\n",
    "\n",
    "    def save(self, filename: str) -> None:\n",
    "        \"\"\"Save the agent state to disk using pickle.\"\"\"\n",
    "        with Path(filename).open(\"wb\") as f:\n",
    "            pickle.dump(self.__dict__, f)\n",
    "\n",
    "    def load(self, filename: str) -> None:\n",
    "        \"\"\"Load the agent state from disk.\"\"\"\n",
    "        with Path(filename).open(\"rb\") as f:\n",
    "            self.__dict__.update(pickle.load(f))  # noqa: S301\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a4eae79",
   "metadata": {},
   "source": [
    "## Random Agent (baseline)\n",
    "\n",
    "Serves as a reference to evaluate the performance of learning agents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "78bdc9d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomAgent(Agent):\n",
    "    \"\"\"A simple agent that selects actions uniformly at random (baseline).\"\"\"\n",
    "\n",
    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
    "        \"\"\"Select a random action, ignoring the observation and epsilon.\"\"\"\n",
    "        _ = observation, epsilon\n",
    "        return int(self.rng.integers(0, self.action_space))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f679032",
   "metadata": {},
   "source": [
    "## SARSA Agent — Linear Approximation (Semi-gradient)\n",
    "\n",
    "This agent combines:\n",
    "- **Linear approximation** from Lab 7 (`SarsaAgent`): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$\n",
    "- **On-policy SARSA update** from Lab 5B (`train_sarsa`): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
    "\n",
    "The semi-gradient update rule is:\n",
    "$$W_a \\leftarrow W_a + \\alpha \\cdot \\delta \\cdot \\phi(s)$$\n",
    "\n",
    "where $\\phi(s)$ is the normalized observation vector (analogous to tile coding features in Lab 7, but in dense form)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "c124ed9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SarsaAgent(Agent):\n",
    "    \"\"\"Semi-gradient SARSA agent with linear function approximation.\n",
    "\n",
    "    Inspired by:\n",
    "    - Lab 7 SarsaAgent: linear q(s,a) = W_a . phi(s), semi-gradient update\n",
    "    - Lab 5B train_sarsa: on-policy TD target using Q(s', a')\n",
    "\n",
    "    The weight matrix W has shape (n_actions, n_features).\n",
    "    For a given state s, q(s, a) = W[a] @ phi(s) is the dot product\n",
    "    of the action's weight row with the normalized observation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        n_features: int,\n",
    "        n_actions: int,\n",
    "        alpha: float = 0.001,\n",
    "        gamma: float = 0.99,\n",
    "        seed: int = 42,\n",
    "    ) -> None:\n",
    "        \"\"\"Initialize SARSA agent with linear weights.\n",
    "\n",
    "        Args:\n",
    "            n_features: Dimension of the feature vector phi(s).\n",
    "            n_actions: Number of discrete actions.\n",
    "            alpha: Learning rate (kept small for high-dim features).\n",
    "            gamma: Discount factor.\n",
    "            seed: RNG seed for reproducibility.\n",
    "\n",
    "        \"\"\"\n",
    "        super().__init__(seed, n_actions)\n",
    "        self.n_features = n_features\n",
    "        self.alpha = alpha\n",
    "        self.gamma = gamma\n",
    "        # Weight matrix: one row per action, analogous to Lab 7's self.w\n",
    "        # but organized as (n_actions, n_features) for dense features.\n",
    "        self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
    "\n",
    "    def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"Compute Q-values for all actions given feature vector phi(s).\n",
    "\n",
    "        Equivalent to Lab 7's self.q(s, a) = self.w[idx].sum()\n",
    "        but using dense linear approximation: q(s, a) = W[a] @ phi.\n",
    "\n",
    "        Args:\n",
    "            phi: Normalized feature vector, shape (n_features,).\n",
    "\n",
    "        Returns:\n",
    "            Array of Q-values, shape (n_actions,).\n",
    "\n",
    "        \"\"\"\n",
    "        return self.W @ phi  # shape (n_actions,)\n",
    "\n",
    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
    "        \"\"\"Select action using ε-greedy policy over linear Q-values.\n",
    "\n",
    "        Same pattern as Lab 7 SarsaAgent.eps_greedy:\n",
    "        compute q-values for all actions, then apply epsilon_greedy.\n",
    "        \"\"\"\n",
    "        phi = normalize_obs(observation)\n",
    "        q_vals = self._q_values(phi)\n",
    "        return epsilon_greedy(q_vals, epsilon, self.rng)\n",
    "\n",
    "    def update(\n",
    "        self,\n",
    "        state: np.ndarray,\n",
    "        action: int,\n",
    "        reward: float,\n",
    "        next_state: np.ndarray,\n",
    "        done: bool,\n",
    "        next_action: int | None = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Perform one semi-gradient SARSA update.\n",
    "\n",
    "        Follows the SARSA update from Lab 5B train_sarsa:\n",
    "            td_target = r + gamma * Q(s', a') * (0 if done else 1)\n",
    "            Q(s, a) += alpha * (td_target - Q(s, a))\n",
    "\n",
    "        In continuous form with linear approximation (Lab 7 SarsaAgent.update):\n",
    "            delta = target - q(s, a)\n",
    "            W[a] += alpha * delta * phi(s)\n",
    "\n",
    "        Args:\n",
    "            state: Current observation.\n",
    "            action: Action taken.\n",
    "            reward: Reward received.\n",
    "            next_state: Next observation.\n",
    "            done: Whether the episode ended.\n",
    "            next_action: Action chosen in next state (required for SARSA).\n",
    "\n",
    "        \"\"\"\n",
    "        phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
    "        q_sa = float(self.W[action] @ phi)  # current estimate q(s, a)\n",
    "        if not np.isfinite(q_sa):\n",
    "            q_sa = 0.0\n",
    "\n",
    "        if done:\n",
    "            # Terminal: no future value (Lab 5B: gamma * Q[s2, a2] * 0)\n",
    "            target = reward\n",
    "        else:\n",
    "            # On-policy: use q(s', a') where a' is the actual next action\n",
    "            # This is the key SARSA property (Lab 5B)\n",
    "            phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
    "            if next_action is None:\n",
    "                next_action = 0  # fallback, should not happen in practice\n",
    "            q_sp_ap = float(self.W[next_action] @ phi_next)\n",
    "            if not np.isfinite(q_sp_ap):\n",
    "                q_sp_ap = 0.0\n",
    "            target = float(reward) + self.gamma * q_sp_ap\n",
    "\n",
    "        # Semi-gradient update: W[a] += alpha * delta * phi(s)\n",
    "        # Analogous to Lab 7: self.w[idx] += self.alpha * delta\n",
    "        if not np.isfinite(target):\n",
    "            return\n",
    "\n",
    "        delta = float(target - q_sa)\n",
    "        if not np.isfinite(delta):\n",
    "            return\n",
    "\n",
    "        td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
    "        self.W[action] += self.alpha * td_step * phi\n",
    "        self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
   ]
  },
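  {
   "cell_type": "markdown",
   "id": "demo-sarsa-smoke",
   "metadata": {},
   "source": [
    "A tiny smoke test of the SARSA update (random toy features, not real game frames): starting from zero weights, one transition with a positive reward gives a positive TD error, so $\\hat{q}(s, a)$ for the updated action must increase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-sarsa-smoke-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "_rng = np.random.default_rng(0)\n",
    "_sarsa = SarsaAgent(n_features=8, n_actions=3, alpha=0.1)\n",
    "_s, _s2 = _rng.random(8) * 255.0, _rng.random(8) * 255.0\n",
    "_q_before = float(_sarsa.W[1] @ normalize_obs(_s))\n",
    "_sarsa.update(_s, action=1, reward=1.0, next_state=_s2, done=False, next_action=2)\n",
    "_q_after = float(_sarsa.W[1] @ normalize_obs(_s))\n",
    "assert _q_after > _q_before  # positive TD error increases q(s, a)\n"
   ]
  },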
  {
   "cell_type": "markdown",
   "id": "d4e18536",
   "metadata": {},
   "source": [
    "## Q-Learning Agent — Linear Approximation (Off-policy)\n",
    "\n",
    "Same architecture as SARSA but with the **off-policy update** from Lab 5B (`train_q_learning`):\n",
    "\n",
    "$$\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$$\n",
    "\n",
    "The key difference from SARSA: we use $\\max_{a'} Q(s', a')$ instead of $Q(s', a')$ where $a'$ is the action actually chosen. This allows learning the optimal policy independently of the exploration policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f5b5b9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "class QLearningAgent(Agent):\n",
    "    \"\"\"Q-Learning agent with linear function approximation (off-policy).\n",
    "\n",
    "    Inspired by:\n",
    "    - Lab 5B train_q_learning: off-policy TD target using max_a' Q(s', a')\n",
    "    - Lab 7 SarsaAgent: linear approximation q(s,a) = W[a] @ phi(s)\n",
    "\n",
    "    The only difference from SarsaAgent is the TD target:\n",
    "    SARSA uses Q(s', a') (on-policy), Q-Learning uses max_a' Q(s', a') (off-policy).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        n_features: int,\n",
    "        n_actions: int,\n",
    "        alpha: float = 0.001,\n",
    "        gamma: float = 0.99,\n",
    "        seed: int = 42,\n",
    "    ) -> None:\n",
    "        \"\"\"Initialize Q-Learning agent with linear weights.\n",
    "\n",
    "        Args:\n",
    "            n_features: Dimension of the feature vector phi(s).\n",
    "            n_actions: Number of discrete actions.\n",
    "            alpha: Learning rate.\n",
    "            gamma: Discount factor.\n",
    "            seed: RNG seed.\n",
    "\n",
    "        \"\"\"\n",
    "        super().__init__(seed, n_actions)\n",
    "        self.n_features = n_features\n",
    "        self.alpha = alpha\n",
    "        self.gamma = gamma\n",
    "        self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
    "\n",
    "    def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"Compute Q-values for all actions: q(s, a) = W[a] @ phi for each a.\"\"\"\n",
    "        return self.W @ phi\n",
    "\n",
    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
    "        \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n",
    "        phi = normalize_obs(observation)\n",
    "        q_vals = self._q_values(phi)\n",
    "        return epsilon_greedy(q_vals, epsilon, self.rng)\n",
    "\n",
    "    def update(\n",
    "        self,\n",
    "        state: np.ndarray,\n",
    "        action: int,\n",
    "        reward: float,\n",
    "        next_state: np.ndarray,\n",
    "        done: bool,\n",
    "        next_action: int | None = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Perform one Q-learning update.\n",
    "\n",
    "        Follows Lab 5B train_q_learning:\n",
    "            td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
    "            Q[s, a] += alpha * (td_target - Q[s, a])\n",
    "\n",
    "        In continuous form with linear approximation:\n",
    "            delta = target - q(s, a)\n",
    "            W[a] += alpha * delta * phi(s)\n",
    "        \"\"\"\n",
    "        _ = next_action  # Q-learning is off-policy: next_action is not used\n",
    "        phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
    "        q_sa = float(self.W[action] @ phi)\n",
    "        if not np.isfinite(q_sa):\n",
    "            q_sa = 0.0\n",
    "\n",
    "        if done:\n",
    "            # Terminal state: no future value\n",
    "            # Lab 5B: gamma * np.max(Q[s2]) * (0 if terminated else 1)\n",
    "            target = reward\n",
    "        else:\n",
    "            # Off-policy: use max over all actions in next state\n",
    "            # This is the key Q-learning property (Lab 5B)\n",
    "            phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
    "            q_next_all = self._q_values(phi_next)  # q(s', a') for all a'\n",
    "            q_next_max = float(np.max(q_next_all))\n",
    "            if not np.isfinite(q_next_max):\n",
    "                q_next_max = 0.0\n",
    "            target = float(reward) + self.gamma * q_next_max\n",
    "\n",
    "        if not np.isfinite(target):\n",
    "            return\n",
    "\n",
    "        delta = float(target - q_sa)\n",
    "        if not np.isfinite(delta):\n",
    "            return\n",
    "\n",
    "        td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
    "        self.W[action] += self.alpha * td_step * phi\n",
    "        self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
   ]
  },
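  {
   "cell_type": "markdown",
   "id": "demo-ql-offpolicy",
   "metadata": {},
   "source": [
    "A concrete check of the off-policy property (toy features, not real frames): two identically-initialized Q-Learning agents given the *same* transition but *different* `next_action` values must end up with identical weights, since the target uses the max over next-state actions rather than the action actually taken."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-ql-offpolicy-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "_rng = np.random.default_rng(1)\n",
    "_ql_a = QLearningAgent(n_features=8, n_actions=3, alpha=0.1)\n",
    "_ql_b = QLearningAgent(n_features=8, n_actions=3, alpha=0.1)\n",
    "_s, _s2 = _rng.random(8) * 255.0, _rng.random(8) * 255.0\n",
    "_ql_a.update(_s, 0, 1.0, _s2, done=False, next_action=1)\n",
    "_ql_b.update(_s, 0, 1.0, _s2, done=False, next_action=2)\n",
    "assert np.array_equal(_ql_a.W, _ql_b.W)  # off-policy: the actual next action plays no role\n"
   ]
  },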
  {
   "cell_type": "markdown",
   "id": "79e6b39f",
   "metadata": {},
   "source": [
    "## Monte Carlo Agent — Linear Approximation (Every-visit)\n",
    "\n",
    "This agent is inspired by Lab 4 (`mc_control_epsilon_soft`):\n",
    "- Accumulates transitions in an episode buffer `(state, action, reward)`\n",
    "- At the end of the episode (`done=True`), computes **discounted returns** by traversing the buffer backward:\n",
    "  $$G \\leftarrow \\gamma \\cdot G + r$$\n",
    "- Updates weights with the semi-gradient rule (applied at every visited step, hence every-visit):\n",
    "  $$W_a \\leftarrow W_a + \\alpha \\cdot (G - \\hat{q}(s, a)) \\cdot \\phi(s)$$\n",
    "\n",
    "Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating.\n",
    "\n",
    "> **Note**: This agent currently has **checkpoint loading issues** — the saved weights fail to restore properly, causing the agent to behave as if untrained during evaluation. The training code itself works correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "7a3aa454",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MonteCarloAgent(Agent):\n",
    "    \"\"\"Monte Carlo control agent with linear function approximation.\n",
    "\n",
    "    Inspired by Lab 4 mc_control_epsilon_soft:\n",
    "    - Accumulates transitions in an episode buffer\n",
    "    - At episode end (done=True), computes discounted returns backward:\n",
    "      G = gamma * G + r (same as Lab 4's reversed loop)\n",
    "    - Updates weights with semi-gradient: W[a] += alpha * (G - q(s,a)) * phi(s)\n",
    "\n",
    "    Unlike TD methods (SARSA, Q-Learning), no update occurs until the episode ends.\n",
    "\n",
    "    Performance optimizations over naive per-step implementation:\n",
    "    - float32 weights & features (halves memory bandwidth, faster SIMD)\n",
    "    - Raw observations stored compactly as uint8, batch-normalized at episode end\n",
    "    - Vectorized return computation & chunk-based weight updates via einsum\n",
    "    - Single weight sanitization per episode instead of per-step\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        n_features: int,\n",
    "        n_actions: int,\n",
    "        alpha: float = 0.001,\n",
    "        gamma: float = 0.99,\n",
    "        seed: int = 42,\n",
    "    ) -> None:\n",
    "        \"\"\"Initialize Monte Carlo agent with linear weights and episode buffers.\"\"\"\n",
    "        super().__init__(seed, n_actions)\n",
    "        self.n_features = n_features\n",
    "        self.alpha = alpha\n",
    "        self.gamma = gamma\n",
    "        self.W = np.zeros((n_actions, n_features), dtype=np.float32)\n",
    "        self._obs_buf: list[np.ndarray] = []\n",
    "        self._act_buf: list[int] = []\n",
    "        self._rew_buf: list[float] = []\n",
    "\n",
    "    def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
    "        return self.W @ phi\n",
    "\n",
    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
    "        \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n",
    "        phi = observation.flatten().astype(np.float32) / np.float32(255.0)\n",
    "        q_vals = self._q_values(phi)\n",
    "        return epsilon_greedy(q_vals, epsilon, self.rng)\n",
    "\n",
    "    def update(\n",
    "        self,\n",
    "        state: np.ndarray,\n",
    "        action: int,\n",
    "        reward: float,\n",
    "        next_state: np.ndarray,\n",
    "        done: bool,\n",
    "        next_action: int | None = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Accumulate episode transitions and update weights at episode end.\"\"\"\n",
    "        _ = next_state, next_action\n",
    "\n",
    "        self._obs_buf.append(state)\n",
    "        self._act_buf.append(action)\n",
    "        self._rew_buf.append(reward)\n",
    "\n",
    "        if not done:\n",
    "            return\n",
    "\n",
    "        n = len(self._rew_buf)\n",
    "        actions = np.array(self._act_buf, dtype=np.intp)\n",
    "\n",
    "        returns = np.empty(n, dtype=np.float32)\n",
    "        G = np.float32(0.0)\n",
    "        gamma32 = np.float32(self.gamma)\n",
    "        for i in range(n - 1, -1, -1):\n",
    "            G = gamma32 * G + np.float32(self._rew_buf[i])\n",
    "            returns[i] = G\n",
    "\n",
    "        alpha32 = np.float32(self.alpha)\n",
    "        chunk_size = 500\n",
    "        for start in range(0, n, chunk_size):\n",
    "            end = min(start + chunk_size, n)\n",
    "            cs = end - start\n",
    "\n",
    "            raw = np.array(self._obs_buf[start:end])\n",
    "            phi = raw.reshape(cs, -1).astype(np.float32)\n",
    "            phi /= np.float32(255.0)\n",
    "\n",
    "            ca = actions[start:end]\n",
    "            q_sa = np.einsum(\"ij,ij->i\", self.W[ca], phi)\n",
    "\n",
    "            deltas = np.clip(returns[start:end] - q_sa, -1000.0, 1000.0)\n",
    "\n",
    "            for a in range(self.action_space):\n",
    "                mask = ca == a\n",
    "                if not np.any(mask):\n",
    "                    continue\n",
    "                self.W[a] += alpha32 * (deltas[mask] @ phi[mask])\n",
    "\n",
    "        self.W = np.nan_to_num(self.W, nan=0.0, posinf=1e6, neginf=-1e6)\n",
    "\n",
    "        self._obs_buf.clear()\n",
    "        self._act_buf.clear()\n",
    "        self._rew_buf.clear()\n"
   ]
  },
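  {
   "cell_type": "markdown",
   "id": "demo-mc-episode",
   "metadata": {},
   "source": [
    "A minimal end-of-episode check (toy observations, arbitrary rewards): no weights change while transitions are only buffered, and the buffered returns are applied once `done=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demo-mc-episode-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "_mc = MonteCarloAgent(n_features=8, n_actions=2, alpha=0.1)\n",
    "_o = np.full(8, 255.0)\n",
    "_mc.update(_o, 0, 0.0, _o, done=False)\n",
    "assert not _mc.W.any()  # transition buffered only, no update yet\n",
    "_mc.update(_o, 0, 1.0, _o, done=True)\n",
    "assert _mc.W[0].any()  # episode ended: discounted returns applied to W\n"
   ]
  },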
  {
   "cell_type": "markdown",
   "id": "d5766fe9",
   "metadata": {},
   "source": [
    "## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n",
    "\n",
    "This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon).\n",
    "\n",
    "**Network architecture** (same structure as before, now as `torch.nn.Module`):\n",
    "$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
    "\n",
    "**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
    "- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n",
    "- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n",
    "- **Gradient clipping**: prevents exploding gradients in deep networks\n",
    "- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes\n",
    "\n",
    "> **Note**: This agent currently has **checkpoint loading issues** — the saved `.pt` checkpoint fails to restore properly (device mismatch / state dict incompatibility), causing errors during evaluation. The training code itself works correctly."
   ]
  },
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 51,
|
||
"id": "9c777493",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class QNetwork(nn.Module):\n",
|
||
" \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n",
|
||
"\n",
|
||
" def __init__(self, n_features: int, n_actions: int) -> None:\n",
|
||
" \"\"\"Initialize the Q-network architecture.\"\"\"\n",
|
||
" super().__init__()\n",
|
||
" self.net = nn.Sequential(\n",
|
||
" nn.Linear(n_features, 256),\n",
|
||
" nn.ReLU(),\n",
|
||
" nn.Linear(256, 256),\n",
|
||
" nn.ReLU(),\n",
|
||
" nn.Linear(256, n_actions),\n",
|
||
" )\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
|
||
" \"\"\"Compute Q-values for input state tensor x.\"\"\"\n",
|
||
" return self.net(x)\n",
|
||
"\n",
|
||
"\n",
|
||
"class ReplayBuffer:\n",
|
||
" \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n",
|
||
"\n",
|
||
" def __init__(self, capacity: int) -> None:\n",
|
||
" \"\"\"Initialize the replay buffer with a maximum capacity.\"\"\"\n",
|
||
" self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n",
|
||
"\n",
|
||
" def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n",
|
||
" \"\"\"Add a transition to the replay buffer.\"\"\"\n",
|
||
" self.buffer.append((state, action, reward, next_state, done))\n",
|
||
"\n",
|
||
" def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n",
|
||
" \"\"\"Sample a random minibatch of transitions from the buffer.\"\"\"\n",
|
||
" indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n",
|
||
" batch = [self.buffer[i] for i in indices]\n",
|
||
" states = np.array([t[0] for t in batch])\n",
|
||
" actions = np.array([t[1] for t in batch])\n",
|
||
" rewards = np.array([t[2] for t in batch])\n",
|
||
" next_states = np.array([t[3] for t in batch])\n",
|
||
" dones = np.array([t[4] for t in batch], dtype=np.float32)\n",
|
||
" return states, actions, rewards, next_states, dones\n",
|
||
"\n",
|
||
" def __len__(self) -> int:\n",
|
||
" \"\"\"Return the current number of transitions stored in the buffer.\"\"\"\n",
|
||
" return len(self.buffer)\n",
|
||
"\n",
|
||
"\n",
|
||
"class DQNAgent(Agent):\n",
|
||
" \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n",
|
||
"\n",
|
||
" Inspired by:\n",
|
||
" - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n",
|
||
" - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n",
|
||
"\n",
|
||
" Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n",
|
||
" \"\"\"\n",
|
||
"\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" n_features: int,\n",
|
||
" n_actions: int,\n",
|
||
" lr: float = 1e-4,\n",
|
||
" gamma: float = 0.99,\n",
|
||
" buffer_size: int = 50_000,\n",
|
||
" batch_size: int = 128,\n",
|
||
" target_update_freq: int = 1000,\n",
|
||
" seed: int = 42,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Initialize DQN agent.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" n_features: Input feature dimension.\n",
|
||
" n_actions: Number of discrete actions.\n",
|
||
" lr: Learning rate for Adam optimizer.\n",
|
||
" gamma: Discount factor.\n",
|
||
" buffer_size: Maximum replay buffer capacity.\n",
|
||
" batch_size: Minibatch size for updates.\n",
|
||
" target_update_freq: Steps between target network syncs.\n",
|
||
" seed: RNG seed.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" super().__init__(seed, n_actions)\n",
|
||
" self.n_features = n_features\n",
|
||
" self.lr = lr\n",
|
||
" self.gamma = gamma\n",
|
||
" self.batch_size = batch_size\n",
|
||
" self.target_update_freq = target_update_freq\n",
|
||
" self.update_step = 0\n",
|
||
"\n",
|
||
" # Q-network and target network on GPU\n",
|
||
" torch.manual_seed(seed)\n",
|
||
" self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
|
||
" self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
|
||
" self.target_net.load_state_dict(self.q_net.state_dict())\n",
"        self.target_net.eval()\n",
"\n",
"        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n",
"        self.loss_fn = nn.SmoothL1Loss()  # Huber loss — more robust than MSE\n",
"\n",
"        # Experience replay buffer\n",
"        self.replay_buffer = ReplayBuffer(buffer_size)\n",
"\n",
"    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
"        \"\"\"Select action using epsilon-greedy policy over Q-network outputs.\"\"\"\n",
"        if self.rng.random() < epsilon:\n",
"            return int(self.rng.integers(0, self.action_space))\n",
"\n",
"        phi = normalize_obs(observation)\n",
"        with torch.no_grad():\n",
"            state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n",
"            q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n",
"        return epsilon_greedy(q_vals, 0.0, self.rng)\n",
"\n",
"    def update(\n",
"        self,\n",
"        state: np.ndarray,\n",
"        action: int,\n",
"        reward: float,\n",
"        next_state: np.ndarray,\n",
"        done: bool,\n",
"        next_action: int | None = None,\n",
"    ) -> None:\n",
"        \"\"\"Store transition and perform a minibatch DQN update.\"\"\"\n",
"        _ = next_action  # DQN is off-policy\n",
"\n",
"        # Store transition\n",
"        phi_s = normalize_obs(state)\n",
"        phi_sp = normalize_obs(next_state)\n",
"        self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n",
"\n",
"        if len(self.replay_buffer) < self.batch_size:\n",
"            return\n",
"\n",
"        # Sample minibatch\n",
"        states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n",
"            self.batch_size, self.rng,\n",
"        )\n",
"\n",
"        # Convert to tensors on device\n",
"        states_t = torch.from_numpy(states_b).float().to(DEVICE)\n",
"        actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n",
"        rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n",
"        next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n",
"        dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n",
"\n",
"        # Current Q-values for taken actions\n",
"        q_values = self.q_net(states_t)\n",
"        q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n",
"\n",
"        # Target Q-values (off-policy: max over actions in next state)\n",
"        with torch.no_grad():\n",
"            q_next = self.target_net(next_states_t).max(dim=1).values\n",
"            targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n",
"\n",
"        # Compute loss and update\n",
"        loss = self.loss_fn(q_curr, targets)\n",
"        self.optimizer.zero_grad()\n",
"        loss.backward()\n",
"        nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n",
"        self.optimizer.step()\n",
"\n",
"        # Sync target network periodically\n",
"        self.update_step += 1\n",
"        if self.update_step % self.target_update_freq == 0:\n",
"            self.target_net.load_state_dict(self.q_net.state_dict())\n",
"\n",
"    def save(self, filename: str) -> None:\n",
"        \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n",
"        torch.save(\n",
"            {\n",
"                \"q_net\": self.q_net.state_dict(),\n",
"                \"target_net\": self.target_net.state_dict(),\n",
"                \"optimizer\": self.optimizer.state_dict(),\n",
"                \"update_step\": self.update_step,\n",
"                \"n_features\": self.n_features,\n",
"                \"action_space\": self.action_space,\n",
"            },\n",
"            filename,\n",
"        )\n",
"\n",
"    def load(self, filename: str) -> None:\n",
"        \"\"\"Load agent state from a torch checkpoint.\"\"\"\n",
"        checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n",
"        self.q_net.load_state_dict(checkpoint[\"q_net\"])\n",
"        self.target_net.load_state_dict(checkpoint[\"target_net\"])\n",
"        self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n",
"        self.update_step = checkpoint[\"update_step\"]\n",
"        self.q_net.to(DEVICE)\n",
"        self.target_net.to(DEVICE)\n",
"        self.target_net.eval()\n"
]
},
{
"cell_type": "markdown",
"id": "91e51dc8",
"metadata": {},
"source": [
"## Tennis Environment\n",
"\n",
"Creation of the Atari Tennis environment via Gymnasium (`ALE/Tennis-v5`) with standard wrappers:\n",
"- **Grayscale**: `obs_type=\"grayscale\"` — single-channel observations\n",
"- **Resize**: `ResizeObservation(84, 84)` — downscale to 84×84\n",
"- **Frame stack**: `FrameStackObservation(4)` — stack 4 consecutive frames\n",
"\n",
"The final observation is an array of shape `(4, 84, 84)`, which flattens to 28,224 features.\n",
"\n",
"The agent plays against the **built-in Atari AI opponent**."
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "f9a973dd",
"metadata": {},
"outputs": [],
"source": [
"def create_env() -> gym.Env:\n",
"    \"\"\"Create the ALE/Tennis-v5 environment with preprocessing wrappers.\n",
"\n",
"    Applies:\n",
"    - obs_type=\"grayscale\": grayscale observation (210, 160)\n",
"    - ResizeObservation(84, 84): downscale to 84x84\n",
"    - FrameStackObservation(4): stack 4 consecutive frames -> (4, 84, 84)\n",
"\n",
"    Returns:\n",
"        Gymnasium environment ready for training.\n",
"\n",
"    \"\"\"\n",
"    env = gym.make(\"ALE/Tennis-v5\", obs_type=\"grayscale\")\n",
"    env = ResizeObservation(env, shape=(84, 84))\n",
"    return FrameStackObservation(env, stack_size=4)\n"
]
},
{
"cell_type": "markdown",
"id": "18cb28d8",
"metadata": {},
"source": [
"## Training & Evaluation Infrastructure\n",
"\n",
"Functions for training and evaluating agents in the single-agent Gymnasium environment:\n",
"\n",
"1. **`train_agent`** — Pre-trains an agent against the built-in AI for a given number of episodes with ε-greedy exploration\n",
"2. **`plot_training_curves`** — Plots the training reward history (moving average) for all agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06b91580",
"metadata": {},
"outputs": [],
"source": [
"AGENT_COLORS = {\n",
"    \"Random\": \"#95a5a6\",\n",
"    \"SARSA\": \"#3498db\",\n",
"    \"Q-Learning\": \"#e74c3c\",\n",
"    \"MonteCarlo\": \"#2ecc71\",\n",
"    \"DQN\": \"#9b59b6\",\n",
"}\n",
"\n",
"\n",
"def train_agent(\n",
"    env: gym.Env,\n",
"    agent: Agent,\n",
"    name: str,\n",
"    *,\n",
"    episodes: int = 5000,\n",
"    epsilon_start: float = 1.0,\n",
"    epsilon_end: float = 0.05,\n",
"    epsilon_decay: float = 0.999,\n",
"    max_steps: int = 5000,\n",
") -> list[float]:\n",
"    \"\"\"Pre-train an agent against the built-in Atari AI opponent.\n",
"\n",
"    Each agent learns independently by playing full episodes against the\n",
"    environment's built-in opponent, updating its parameters after each\n",
"    transition.\n",
"\n",
"    Args:\n",
"        env: Gymnasium ALE/Tennis-v5 environment.\n",
"        agent: Agent instance to train.\n",
"        name: Display name for the progress bar.\n",
"        episodes: Number of training episodes.\n",
"        epsilon_start: Initial exploration rate.\n",
"        epsilon_end: Minimum exploration rate.\n",
"        epsilon_decay: Multiplicative decay per episode.\n",
"        max_steps: Maximum steps per episode.\n",
"\n",
"    Returns:\n",
"        List of total rewards per episode.\n",
"\n",
"    \"\"\"\n",
"    rewards_history: list[float] = []\n",
"    epsilon = epsilon_start\n",
"\n",
"    pbar = tqdm(range(episodes), desc=f\"Training {name}\", leave=True)\n",
"\n",
"    for _ep in pbar:\n",
"        obs, _info = env.reset()\n",
"        obs = np.asarray(obs)\n",
"        total_reward = 0.0\n",
"\n",
"        action = agent.get_action(obs, epsilon=epsilon)\n",
"\n",
"        for _step in range(max_steps):\n",
"            next_obs, reward, terminated, truncated, _info = env.step(action)\n",
"            next_obs = np.asarray(next_obs)\n",
"            done = terminated or truncated\n",
"            reward = float(reward)\n",
"            total_reward += reward\n",
"\n",
"            next_action = agent.get_action(next_obs, epsilon=epsilon) if not done else None\n",
"\n",
"            agent.update(\n",
"                state=obs,\n",
"                action=action,\n",
"                reward=reward,\n",
"                next_state=next_obs,\n",
"                done=done,\n",
"                next_action=next_action,\n",
"            )\n",
"\n",
"            if done:\n",
"                break\n",
"\n",
"            obs = next_obs\n",
"            action = next_action\n",
"\n",
"        rewards_history.append(total_reward)\n",
"        epsilon = max(epsilon_end, epsilon * epsilon_decay)\n",
"\n",
"        recent_window = 50\n",
"        if len(rewards_history) >= recent_window:\n",
"            recent_avg = np.mean(rewards_history[-recent_window:])\n",
"            pbar.set_postfix(\n",
"                avg50=f\"{recent_avg:.1f}\",\n",
"                eps=f\"{epsilon:.3f}\",\n",
"                rew=f\"{total_reward:.0f}\",\n",
"            )\n",
"\n",
"    return rewards_history\n",
"\n",
"\n",
"def plot_training_curves(\n",
"    training_histories: dict[str, list[float]],\n",
"    path: str,\n",
"    window: int = 100,\n",
") -> None:\n",
"    \"\"\"Plot training reward curves for all agents on a single figure.\n",
"\n",
"    Uses a moving average to smooth the curves.\n",
"\n",
"    Args:\n",
"        training_histories: Dict mapping agent names to reward lists.\n",
"        path: File path to save the plot image.\n",
"        window: Moving average window size.\n",
"\n",
"    \"\"\"\n",
"    fig, ax = plt.subplots(figsize=(13, 6))\n",
"\n",
"    for name, rewards in training_histories.items():\n",
"        color = AGENT_COLORS.get(name, \"#7f8c8d\")\n",
"        if len(rewards) >= window:\n",
"            ma = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n",
"            x = np.arange(window - 1, len(rewards))\n",
"            ax.plot(x, ma, label=name, color=color, linewidth=2)\n",
"            ax.fill_between(x, ma - np.std(rewards) * 0.3, ma + np.std(rewards) * 0.3,\n",
"                            color=color, alpha=0.1)\n",
"        else:\n",
"            ax.plot(rewards, label=f\"{name} (raw)\", color=color, linewidth=1.5, alpha=0.7)\n",
"\n",
"    ax.set_xlabel(\"Episodes\", fontsize=12)\n",
"    ax.set_ylabel(f\"Average Reward (window = {window})\", fontsize=12)\n",
"    ax.set_title(\"Training Curves (vs Built-in AI)\", fontsize=14, fontweight=\"bold\")\n",
"    ax.legend(fontsize=10, framealpha=0.9)\n",
"    ax.grid(visible=True, alpha=0.3, linestyle=\"--\")\n",
"    ax.axhline(y=0, color=\"gray\", linewidth=0.8, linestyle=\":\", alpha=0.5)\n",
"    ax.spines[\"top\"].set_visible(False)\n",
"    ax.spines[\"right\"].set_visible(False)\n",
"    fig.tight_layout()\n",
"    plt.savefig(path, dpi=150, bbox_inches=\"tight\")\n",
"    plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "9605e9c4",
"metadata": {},
"source": [
"## Agent Instantiation & Incremental Training (One Agent at a Time)\n",
"\n",
"**Environment**: `ALE/Tennis-v5` (grayscale, 84×84×4 frames → 28,224 features, 18 actions).\n",
"\n",
"**Agents**:\n",
"- **Random** — random baseline (no training needed)\n",
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
"- **Q-Learning** — linear approximation, off-policy\n",
"- **Monte Carlo** — first-visit MC with linear weights (⚠️ checkpoint loading issues)\n",
"- **DQN** — deep Q-network with experience replay and target network (⚠️ `.pt` checkpoint loading issues)\n",
"\n",
"**Workflow**:\n",
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
"2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
"3. Repeat later for another agent without retraining previous ones\n",
"4. Load all saved checkpoints before the final evaluation"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "6f6ba8df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Observation shape : (4, 84, 84)\n",
"Feature vector dim: 28224\n",
"Number of actions : 18\n"
]
}
],
"source": [
"# Create environment\n",
"env = create_env()\n",
"obs, _info = env.reset()\n",
"\n",
"n_actions = int(env.action_space.n)\n",
"n_features = int(np.prod(obs.shape))\n",
"\n",
"print(f\"Observation shape : {obs.shape}\")\n",
"print(f\"Feature vector dim: {n_features}\")\n",
"print(f\"Number of actions : {n_actions}\")\n",
"\n",
"# Instantiate agents\n",
"agent_random = RandomAgent(seed=42, action_space=int(n_actions))\n",
"agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions)\n",
"\n",
"agents = {\n",
"    \"Random\": agent_random,\n",
"    \"SARSA\": agent_sarsa,\n",
"    \"Q-Learning\": agent_q,\n",
"    \"MonteCarlo\": agent_mc,\n",
"    \"DQN\": agent_dqn,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "4d449701",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected agent: Q-Learning\n",
"Checkpoint path: checkpoints/q_learning.pkl\n",
"\n",
"============================================================\n",
"Training: Q-Learning (3000 episodes)\n",
"============================================================\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c3a441131c648749a7d938732265143",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training Q-Learning:   0%|          | 0/3000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[55]\u001b[39m\u001b[32m, line 28\u001b[39m\n\u001b[32m     25\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mTraining: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mAGENT_TO_TRAIN\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mTRAINING_EPISODES\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m episodes)\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m     26\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m=\u001b[39m\u001b[33m'\u001b[39m*\u001b[32m60\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m28\u001b[39m training_histories[AGENT_TO_TRAIN] = \u001b[43mtrain_agent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m     29\u001b[39m \u001b[43m    \u001b[49m\u001b[43menv\u001b[49m\u001b[43m=\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     30\u001b[39m \u001b[43m    \u001b[49m\u001b[43magent\u001b[49m\u001b[43m=\u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     31\u001b[39m \u001b[43m    \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[43mAGENT_TO_TRAIN\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     32\u001b[39m \u001b[43m    \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mTRAINING_EPISODES\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     33\u001b[39m \u001b[43m    \u001b[49m\u001b[43mepsilon_start\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m     34\u001b[39m \u001b[43m    \u001b[49m\u001b[43mepsilon_end\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m     35\u001b[39m \u001b[43m    \u001b[49m\u001b[43mepsilon_decay\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.999\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m     36\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     38\u001b[39m avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-\u001b[32m100\u001b[39m:])\n\u001b[32m     39\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m-> \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mAGENT_TO_TRAIN\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m avg reward (last 100 eps): \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_last_100\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[53]\u001b[39m\u001b[32m, line 45\u001b[39m, in \u001b[36mtrain_agent\u001b[39m\u001b[34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[39m\n\u001b[32m     42\u001b[39m action = agent.get_action(obs, epsilon=epsilon)\n\u001b[32m     44\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m next_obs, reward, terminated, truncated, _info = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     46\u001b[39m next_obs = np.asarray(next_obs)\n\u001b[32m     47\u001b[39m done = terminated \u001b[38;5;129;01mor\u001b[39;00m truncated\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/stateful_observation.py:425\u001b[39m, in \u001b[36mFrameStackObservation.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m    414\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m    415\u001b[39m     \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m    416\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m    417\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"Steps through the environment, appending the observation to the frame buffer.\u001b[39;00m\n\u001b[32m    418\u001b[39m \n\u001b[32m    419\u001b[39m \u001b[33;03m    Args:\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m    423\u001b[39m \u001b[33;03m        Stacked observations, reward, terminated, truncated, and info from the environment\u001b[39;00m\n\u001b[32m    424\u001b[39m \u001b[33;03m    \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m425\u001b[39m obs, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    426\u001b[39m \u001b[38;5;28mself\u001b[39m.obs_queue.append(obs)\n\u001b[32m    428\u001b[39m updated_obs = deepcopy(\n\u001b[32m    429\u001b[39m     concatenate(\u001b[38;5;28mself\u001b[39m.env.observation_space, \u001b[38;5;28mself\u001b[39m.obs_queue, \u001b[38;5;28mself\u001b[39m.stacked_obs)\n\u001b[32m    430\u001b[39m )\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:560\u001b[39m, in \u001b[36mObservationWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m    556\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m    557\u001b[39m     \u001b[38;5;28mself\u001b[39m, action: ActType\n\u001b[32m    558\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m    559\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m560\u001b[39m observation, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    561\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.observation(observation), reward, terminated, truncated, info\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:393\u001b[39m, in \u001b[36mOrderEnforcing.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m    391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m    392\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m ResetNeeded(\u001b[33m\"\u001b[39m\u001b[33mCannot call env.step() before calling env.reset()\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:327\u001b[39m, in \u001b[36mWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m    323\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m    324\u001b[39m     \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m    325\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m    326\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:285\u001b[39m, in \u001b[36mPassiveEnvChecker.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m    283\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m env_step_passive_checker(\u001b[38;5;28mself\u001b[39m.env, action)\n\u001b[32m    284\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/ale_py/env.py:305\u001b[39m, in \u001b[36mAtariEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m    303\u001b[39m reward = \u001b[32m0.0\u001b[39m\n\u001b[32m    304\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(frameskip):\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m     reward += \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43male\u001b[49m\u001b[43m.\u001b[49m\u001b[43mact\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrength\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    307\u001b[39m is_terminal = \u001b[38;5;28mself\u001b[39m.ale.game_over(with_truncation=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m    308\u001b[39m is_truncated = \u001b[38;5;28mself\u001b[39m.ale.game_truncated()\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"AGENT_TO_TRAIN = \"Q-Learning\"\n",
"TRAINING_EPISODES = 3000\n",
"FORCE_RETRAIN = True\n",
"\n",
"if AGENT_TO_TRAIN not in agents:\n",
"    msg = f\"Unknown agent '{AGENT_TO_TRAIN}'. Available: {list(agents)}\"\n",
"    raise ValueError(msg)\n",
"\n",
"training_histories: dict[str, list[float]] = {}\n",
"agent = agents[AGENT_TO_TRAIN]\n",
"checkpoint_path = get_path(AGENT_TO_TRAIN)\n",
"\n",
"print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n",
"print(f\"Checkpoint path: {checkpoint_path}\")\n",
"\n",
"if AGENT_TO_TRAIN == \"Random\":\n",
"    print(\"Random is a baseline and is not trained.\")\n",
"    training_histories[AGENT_TO_TRAIN] = []\n",
"elif checkpoint_path.exists() and not FORCE_RETRAIN:\n",
"    agent.load(str(checkpoint_path))\n",
"    print(\"Checkpoint found -> weights loaded, training skipped.\")\n",
"    training_histories[AGENT_TO_TRAIN] = []\n",
"else:\n",
"    print(f\"\\n{'='*60}\")\n",
"    print(f\"Training: {AGENT_TO_TRAIN} ({TRAINING_EPISODES} episodes)\")\n",
"    print(f\"{'='*60}\")\n",
"\n",
"    training_histories[AGENT_TO_TRAIN] = train_agent(\n",
"        env=env,\n",
"        agent=agent,\n",
"        name=AGENT_TO_TRAIN,\n",
"        episodes=TRAINING_EPISODES,\n",
"        epsilon_start=1.0,\n",
"        epsilon_end=0.05,\n",
"        epsilon_decay=0.999,\n",
"    )\n",
"\n",
"    avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n",
"    print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n",
"\n",
"    agent.save(str(checkpoint_path))\n",
"    print(\"Checkpoint saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a13a65df",
"metadata": {},
"outputs": [],
"source": [
"plot_training_curves(\n",
"    training_histories, f\"plots/{AGENT_TO_TRAIN}_training_curves.png\", window=100,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "56959ccd",
"metadata": {},
"source": [
"# Multi-Agent Evaluation & Tournament"
]
},
{
"cell_type": "markdown",
"id": "0698f673",
"metadata": {},
"source": [
"### Loading Checkpoints & Running Evaluation\n",
"\n",
"Before evaluation, we load the saved weights for each trained agent from the `checkpoints/` directory. If a checkpoint is missing, the agent will play with its initial weights (zeros), which is effectively random behavior."
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "5315eb02",
"metadata": {},
"outputs": [],
"source": [
"for name in [\"SARSA\", \"Q-Learning\"]:\n",
"    checkpoint_path = get_path(name)\n",
"    if checkpoint_path.exists():\n",
"        agents[name].load(str(checkpoint_path))\n",
"    else:\n",
"        print(f\"Warning: Missing checkpoint for {name}\")\n"
]
},
{
"cell_type": "markdown",
"id": "3047c4b0",
"metadata": {},
"source": [
"After individual pre-training against Atari's built-in AI, we move on to **head-to-head evaluation** between our agents.\n",
"\n",
"**PettingZoo Environment**: unlike Gymnasium (single player vs built-in AI), PettingZoo (`tennis_v3`) allows **two Python agents** to play against each other. The same preprocessing wrappers are applied (grayscale, resize to 84×84, 4-frame stacking) via **SuperSuit**.\n",
"\n",
"**Match Protocol**: each matchup is played over **two legs** with swapped positions (`first_0` / `second_0`) to eliminate any side-of-court advantage. Results are tallied as wins, losses, and draws."
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "d04d37e0",
"metadata": {},
"outputs": [],
"source": [
"def create_tournament_env() -> gym.Env:\n",
"    \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n",
"    env = tennis_v3.env(obs_type=\"grayscale_image\")\n",
"    env = ss.resize_v1(env, x_size=84, y_size=84)\n",
"    env = ss.frame_stack_v1(env, 4)\n",
"    return ss.observation_lambda_v0(env, lambda obs, _: np.transpose(obs, (2, 0, 1)))\n",
"\n",
"\n",
"def run_match(\n",
"    env: gym.Env,\n",
"    agent_first: Agent,\n",
"    agent_second: Agent,\n",
"    episodes: int = 10,\n",
"    max_steps: int = 4000,\n",
") -> dict[str, int]:\n",
"    \"\"\"Run multiple PettingZoo episodes between two agents.\n",
"\n",
"    Returns wins for global labels {'first': ..., 'second': ..., 'draw': ...}.\n",
"    \"\"\"\n",
"    wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n",
"\n",
"    for _ep in range(episodes):\n",
"        env.reset()\n",
"        rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n",
"\n",
"        for step_idx, agent_id in enumerate(env.agent_iter()):\n",
"            obs, reward, termination, truncation, _info = env.last()\n",
"            done = termination or truncation\n",
"            rewards[agent_id] += float(reward)\n",
"\n",
"            if done or step_idx >= max_steps:\n",
"                action = None\n",
"            else:\n",
"                current_agent = agent_first if agent_id == \"first_0\" else agent_second\n",
"                action = current_agent.get_action(np.asarray(obs), epsilon=0.0)\n",
"\n",
"            env.step(action)\n",
"\n",
"            if step_idx + 1 >= max_steps:\n",
"                break\n",
"\n",
"        # Determine winner for the current episode\n",
"        if rewards[\"first_0\"] > rewards[\"second_0\"]:\n",
"            wins[\"first\"] += 1\n",
"        elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n",
"            wins[\"second\"] += 1\n",
"        else:\n",
"            wins[\"draw\"] += 1\n",
"\n",
"    return wins\n"
]
},
{
"cell_type": "markdown",
"id": "8fd8215a",
"metadata": {},
"source": [
"### Observation Space Alignment Check: Gymnasium vs PettingZoo"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "fd235ac7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gymnasium: (4, 84, 84) uint8 0 214\n",
"PettingZoo: (4, 84, 84) uint8 0 214\n"
]
}
],
"source": [
"env_gym = create_env()\n",
"obs_gym, _ = env_gym.reset()\n",
"print(\n",
"    \"Gymnasium:\",\n",
"    np.asarray(obs_gym).shape,\n",
"    np.asarray(obs_gym).dtype,\n",
"    np.asarray(obs_gym).min(),\n",
"    np.asarray(obs_gym).max(),\n",
")\n",
"\n",
"env_pz = create_tournament_env()\n",
"env_pz.reset()\n",
"for agent_id in env_pz.agent_iter():\n",
"    obs_pz, *_ = env_pz.last()\n",
"    print(\n",
"        \"PettingZoo:\",\n",
"        np.asarray(obs_pz).shape,\n",
"        np.asarray(obs_pz).dtype,\n",
"        np.asarray(obs_pz).min(),\n",
"        np.asarray(obs_pz).max(),\n",
"    )\n",
"    break\n"
]
},
{
"cell_type": "markdown",
"id": "5cfcfdb0",
"metadata": {},
"source": [
"## Evaluation against the Built-in Atari AI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bee640bb",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_vs_atari(\n",
"    env: gym.Env,\n",
"    agents: dict[str, Agent],\n",
"    episodes: int = 20,\n",
"    max_steps: int = 5000,\n",
") -> dict[str, float]:\n",
"    \"\"\"Evaluate each agent against the built-in Atari AI opponent over multiple episodes.\"\"\"\n",
"    results = {}\n",
"    for name, agent in agents.items():\n",
"        total_rewards = []\n",
"        for _ in range(episodes):\n",
"            obs, _info = env.reset()\n",
"            obs = np.asarray(obs)\n",
"            ep_reward = 0.0\n",
"\n",
"            for _ in range(max_steps):\n",
"                action = agent.get_action(obs, epsilon=0.0)\n",
"                next_obs, reward, terminated, truncated, _info = env.step(action)\n",
"                ep_reward += float(reward)\n",
"\n",
"                if terminated or truncated:\n",
"                    break\n",
"\n",
"                obs = np.asarray(next_obs)\n",
"\n",
"            total_rewards.append(ep_reward)\n",
"\n",
"        avg_reward = float(np.mean(total_rewards))\n",
"        results[name] = avg_reward\n",
"        print(\n",
"            f\"{name} vs Atari AI: average score = {avg_reward:.2f} over {episodes} episodes\",\n",
"        )\n",
"\n",
"    return results\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94a804cc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation against built-in Atari AI opponent:\n",
"Random vs Atari AI: Average score = -23.90 over 10 episodes\n",
"SARSA vs Atari AI: Average score = 0.00 over 10 episodes\n",
"Q-Learning vs Atari AI: Average score = 0.00 over 10 episodes\n",
"MonteCarlo vs Atari AI: Average score = -23.80 over 10 episodes\n",
"DQN vs Atari AI: Average score = -24.00 over 10 episodes\n"
]
}
],
"source": [
"print(\"Evaluation against built-in Atari AI opponent:\")\n",
"atari_results = evaluate_vs_atari(env, agents, episodes=10)\n"
]
},
{
"cell_type": "markdown",
"id": "150e6764",
"metadata": {},
"source": [
"## Evaluation against the Random Agent (Baseline)\n",
"\n",
"To quantify whether our agents have actually learned, we first evaluate them against the **Random agent** baseline. A properly trained agent should achieve a **win rate significantly above 50%** against a random opponent.\n",
"\n",
"Each agent plays **two legs** (one in each position) for a total of 20 episodes. Only decisive matches (excluding draws) are counted in the win rate calculation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b85a88f",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_vs_random(\n",
" agents: dict[str, Agent],\n",
" random_agent_name: str = \"Random\",\n",
" episodes_per_leg: int = 10,\n",
") -> dict[str, float]:\n",
" \"\"\"Evaluate each agent against the specified random agent and compute win rates.\"\"\"\n",
" win_rates = {}\n",
" env = create_tournament_env()\n",
" agent_random = agents[random_agent_name]\n",
"\n",
" for name, agent in agents.items():\n",
" if name == random_agent_name:\n",
" continue\n",
"\n",
" leg1 = run_match(env, agent, agent_random, episodes=episodes_per_leg)\n",
" leg2 = run_match(env, agent_random, agent, episodes=episodes_per_leg)\n",
"\n",
" total_wins = leg1[\"first\"] + leg2[\"second\"]\n",
" total_matches = (episodes_per_leg * 2) - (leg1[\"draw\"] + leg2[\"draw\"])\n",
"\n",
" if total_matches == 0:\n",
" win_rates[name] = 0.5\n",
" else:\n",
" win_rates[name] = total_wins / total_matches\n",
"\n",
" print(\n",
" f\"{name} vs {random_agent_name}: {total_wins} wins out of {total_matches} decisive matches (Win rate: {win_rates[name]:.1%})\",\n",
" )\n",
"\n",
" env.close()\n",
" return win_rates\n",
"\n",
"def plot_evaluation_results(\n",
" atari_results: dict[str, float],\n",
" random_win_rates: dict[str, float],\n",
" save_path: str = \"plots/evaluation_results.png\",\n",
") -> None:\n",
" \"\"\"Plot evaluation results as two side-by-side subplots.\"\"\"\n",
" agent_names = list(atari_results.keys())\n",
" colors = [AGENT_COLORS.get(n, \"#7f8c8d\") for n in agent_names]\n",
"\n",
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
" # --- Left panel: Average reward vs Atari AI ---\n",
" atari_scores = [atari_results[n] for n in agent_names]\n",
" x = np.arange(len(agent_names))\n",
" bars1 = ax1.bar(x, atari_scores, color=colors, edgecolor=\"white\", width=0.6)\n",
" ax1.set_xticks(x)\n",
" ax1.set_xticklabels(agent_names, rotation=30, ha=\"right\", fontsize=10)\n",
" ax1.set_ylabel(\"Average Reward\", fontsize=11)\n",
" ax1.set_title(\"Performance vs Atari Built-in AI\", fontsize=13, fontweight=\"bold\")\n",
" ax1.axhline(y=0, color=\"gray\", linestyle=\"--\", linewidth=0.8, alpha=0.5)\n",
" ax1.grid(axis=\"y\", alpha=0.3, linestyle=\"--\")\n",
" ax1.spines[\"top\"].set_visible(False)\n",
" ax1.spines[\"right\"].set_visible(False)\n",
" for bar, score in zip(bars1, atari_scores):\n",
" va = \"bottom\" if score >= 0 else \"top\"\n",
" ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),\n",
" f\"{score:+.1f}\", ha=\"center\", va=va, fontsize=9, fontweight=\"bold\")\n",
"\n",
" # --- Right panel: Win rate vs Random ---\n",
" trained = [n for n in agent_names if n in random_win_rates]\n",
" trained_colors = [AGENT_COLORS.get(n, \"#7f8c8d\") for n in trained]\n",
" rates_pct = [random_win_rates[n] * 100 for n in trained]\n",
" x2 = np.arange(len(trained))\n",
" bars2 = ax2.bar(x2, rates_pct, color=trained_colors, edgecolor=\"white\", width=0.6)\n",
" ax2.set_xticks(x2)\n",
" ax2.set_xticklabels(trained, rotation=30, ha=\"right\", fontsize=10)\n",
" ax2.set_ylabel(\"Win Rate (%)\", fontsize=11)\n",
" ax2.set_title(\"Win Rate vs Random Baseline\", fontsize=13, fontweight=\"bold\")\n",
" ax2.set_ylim(0, 110)\n",
" ax2.axhline(y=50, color=\"gray\", linestyle=\"--\", linewidth=0.8, alpha=0.5, label=\"50 % baseline\")\n",
" ax2.legend(fontsize=8, loc=\"upper right\")\n",
" ax2.grid(axis=\"y\", alpha=0.3, linestyle=\"--\")\n",
" ax2.spines[\"top\"].set_visible(False)\n",
" ax2.spines[\"right\"].set_visible(False)\n",
" for bar, wr in zip(bars2, rates_pct):\n",
" ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.5,\n",
" f\"{wr:.0f} %\", ha=\"center\", va=\"bottom\", fontsize=9, fontweight=\"bold\")\n",
"\n",
" fig.suptitle(\"Agent Evaluation Summary\", fontsize=15, fontweight=\"bold\", y=1.02)\n",
" fig.tight_layout()\n",
" plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a053644e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation against the Random agent\n",
"SARSA vs Random: 0 wins out of 20 decisive matches (Win rate: 0.0%)\n",
"Q-Learning vs Random: 0 wins out of 20 decisive matches (Win rate: 0.0%)\n",
"MonteCarlo vs Random: 8 wins out of 18 decisive matches (Win rate: 44.4%)\n",
"DQN vs Random: 6 wins out of 12 decisive matches (Win rate: 50.0%)\n"
]
}
],
"source": [
"print(\"Evaluation against the Random agent\")\n",
"win_rates_vs_random = evaluate_vs_random(\n",
" agents, random_agent_name=\"Random\", episodes_per_leg=10,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fe97c3d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABHsAAAGyCAYAAAB0jsg1AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIQZJREFUeJzt3X2MFtXZB+CzgICmgloKCEWpWkWLgoJsQY2xoZJosPzRlKoBSvyo1RoLaQVEwW+srxqSukpErf5RC2rEGCFYpRJjpSGCJNoKRlGhRhao5aOooDBvZprdsrhYF3efnb33upIRZnZmn7PPkZ17f3vmnKosy7IEAAAAQAgdWrsBAAAAADQfYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAew57XnrppTR69OjUp0+fVFVVlZ5++un/ec3SpUvTaaedlrp06ZKOO+649MgjjxxoewEA2hS1EwBQ+rBnx44dadCgQammpuYrnf/uu++m888/P51zzjlp1apV6Ve/+lW69NJL03PPPXcg7QUAaFPUTgBApVVlWZYd8MVVVWnBggVpzJgx+z1nypQpaeHChemNN96oP/bTn/40bdmyJS1evPhAXxoAoM1ROwEAldCppV9g2bJlaeTIkQ2OjRo1qhjhsz87d+4stjp79uxJH330UfrmN79ZFEkAQDnlv0Pavn178bh3hw6mBqxU7ZRTPwFA25S1QP3U4mHPhg0bUq9evRocy/e3bduWPvnkk3TwwQd/4ZpZs2alm266qaWbBgC0kPXr16dvf/vb3t8K1U459RMAtG3rm7F+avGw50BMmzYtTZ48uX5/69at6aijjiq+8G7durVq2wCA/csDiX79+qVDDz3U21Rh6icAaJu2tUD91OJhT+/evVNtbW2DY/l+Htrs7zdT+apd+bav/BphDwCUn8euK1s75dRPANC2VTXjtDUt/jD98OHD05IlSxoce/7554vjAAConQCA5tXksOff//53sYR6vtUtrZ7/fd26dfVDiMePH19//hVXXJHWrl2brr322rR69ep03333pccffzxNmjSpOb8OAIBSUjsBAKUPe1599dV06qmnFlsun1sn//uMGTOK/Q8//LA++Ml95zvfKZZez0fzDBo0KN19993pwQcfLFaVAACITu0EAFRaVZav8dUGJivq3r17MVGzOXsAoLzcs8tDXwBA+71nt/icPQAAAABUjrAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAoL2HPTU1Nal///6pa9euqbq6Oi1fvvxLz589e3Y64YQT0sEHH5z69euXJk2alD799NMDbTMAQJujfgIAShv2zJ8/P02ePDnNnDkzrVy5Mg
0aNCiNGjUqbdy4sdHzH3vssTR16tTi/DfffDM99NBDxee47rrrmqP9AAClp34CAEod9txzzz3psssuSxMnTkwnnXRSmjNnTjrkkEPSww8/3Oj5r7zySjrjjDPSRRddVIwGOvfcc9OFF174P0cDAQBEoX4CAEob9uzatSutWLEijRw58r+foEOHYn/ZsmWNXjNixIjimrpwZ+3atWnRokXpvPPO2+/r7Ny5M23btq3BBgDQFqmfAIBK69SUkzdv3px2796devXq1eB4vr969epGr8lH9OTXnXnmmSnLsvT555+nK6644ksf45o1a1a66aabmtI0AIBSUj8BAOFW41q6dGm6/fbb03333VfM8fPUU0+lhQsXpltuuWW/10ybNi1t3bq1flu/fn1LNxMAoDTUTwBAxUb29OjRI3Xs2DHV1tY2OJ7v9+7du9FrbrjhhjRu3Lh06aWXFvsnn3xy2rFjR7r88svT9OnTi8fA9tWlS5diAwBo69RPAECpR/Z07tw5DRkyJC1ZsqT+2J49e4r94cOHN3rNxx9//IVAJw+McvljXQAAkamfAIBSj+zJ5cuuT5gwIQ0dOjQNGzYszZ49uxipk6/OlRs/fnzq27dvMe9ObvTo0cUKFKeeemqqrq5Ob7/9djHaJz9eF/oAAESmfgIASh32jB07Nm3atCnNmDEjbdiwIQ0ePDgtXry4ftLmdevWNRjJc/3116eqqqrizw8++CB961vfKoKe2267rXm/EgCAklI/AQCVVJW1gWep8qXXu3fvXkzW3K1bt9ZuDgCwH+7Z5aEvAKD93rNbfDUuAAAAACpH2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAADQ3sOempqa1L9//9S1a9dUXV2dli9f/qXnb9myJV111VXpyCOPTF26dEnHH398WrRo0YG2GQCgzVE/AQCV0qmpF8yfPz9Nnjw5zZkzpwh6Zs+enUaNGpXWrFmTevbs+YXzd+3alX74wx8WH3vyySdT37590/vvv58OO+yw5voaAABKTf0EAFRSVZZlWVMuyAOe008/Pd17773F/p49e1K/fv3S1VdfnaZOnfqF8/NQ6P/+7//S6tWr00EHHXRAjdy2bVvq3r172rp1a+rWrdsBfQ4AoOW5ZzdO/QQAVLJ+atJjXPkonRUrVqSRI0f+9xN06FDsL1u2rNFrnnnmmTR8+PDiMa5evXqlgQMHpttvvz3t3r17v6+zc+fO4ovdewMAaIvUTwBApTUp7Nm8eXMR0uShzd7y/Q0bNjR6zdq1a4vHt/Lr8nl6brjhhnT33XenW2+9db+vM2vWrCLVqtvykUMAAG2R+gkACLcaV/6YVz5fzwMPPJCGDBmSxo4dm6ZPn1483rU/06ZNK4Yv1W3r169v6WYCAJSG+gkAqNgEzT169EgdO3
ZMtbW1DY7n+7179270mnwFrnyunvy6OieeeGIxEigf1ty5c+cvXJOv2JVvAABtnfoJACj1yJ48mMlH5yxZsqTBb57y/XxensacccYZ6e233y7Oq/PWW28VIVBjQQ8AQCTqJwCg9I9x5cuuz507Nz366KPpzTffTL/4xS/Sjh070sSJE4uPjx8/vngMq07+8Y8++ihdc801RcizcOHCYoLmfMJmAID2QP0EAJT2Ma5cPufOpk2b0owZM4pHsQYPHpwWL15cP2nzunXrihW66uSTKz/33HNp0qRJ6ZRTTkl9+/Ytgp8pU6Y071cCAFBS6icAoJKqsizLUjtccx4AaH7u2eWhLwCg/d6zW3w1LgAAAAAqR9gDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAA0N7DnpqamtS/f//UtWvXVF1dnZYvX/6Vrps3b16qqqpKY8aMOZCXBQBos9RPAEBpw5758+enyZMnp5kzZ6aVK1emQYMGpVGjRqWNGzd+6XXvvfde+vWvf53OOuusr9NeAIA2R/0EAJQ67LnnnnvSZZddliZOnJhOOumkNGfOnHTIIYekhx9+eL/X7N69O1188cXppptuSsccc8zXbTMAQJuifgIAShv27Nq1K61YsSKNHDnyv5+gQ4dif9myZfu97uabb049e/ZMl1xyyVd6nZ07d6Zt27Y12AAA2iL1EwBQ6rBn8+bNxSidXr16NTie72/YsKHRa15++eX00EMPpblz537l15k1a1bq3r17/davX7+mNBMAoDTUTwBAqNW4tm/fnsaNG1cEPT169PjK102bNi1t3bq1flu/fn1LNhMAoDTUTwDA19WpKSfngU3Hjh1TbW1tg+P5fu/evb9w/jvvvFNMzDx69Oj6Y3v27PnPC3fqlNasWZOOPfbYL1zXpUuXYgMAaOvUTwBAqUf2dO7cOQ0ZMiQtWbKkQXiT7w8fPvwL5w8YMCC9/vrradWqVfXbBRdckM4555zi7x7PAgCiUz8BAKUe2ZPLl12fMGFCGjp0aBo2bFiaPXt22rFjR7E6V278+PGpb9++xbw7Xbt2TQMHDmxw/WGHHVb8ue9xAICo1E8AQKnDnrFjx6ZNmzalGTNmFJMyDx48OC1evLh+0uZ169YVK3QBAKB+AgAqryrLsiyVXL70er4qVz5Zc7du3Vq7OQDAfrhnl4e+AID2e882BAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAw
AAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQHsPe2pqalL//v1T165dU3V1dVq+fPl+z507d24666yz0uGHH15sI0eO/NLzAQAiUj8BAKUNe+bPn58mT56cZs6cmVauXJkGDRqURo0alTZu3Njo+UuXLk0XXnhhevHFF9OyZctSv3790rnnnps++OCD5mg/AEDpqZ8AgEqqyrIsa8oF+Uie008/Pd17773F/p49e4oA5+qrr05Tp079n9fv3r27GOGTXz9+/Piv9Jrbtm1L3bt3T1u3bk3dunVrSnMBgApyz26c+gkAqGT91KSRPbt27UorVqwoHsWq/wQdOhT7+aidr+Ljjz9On332WTriiCP2e87OnTuLL3bvDQCgLVI/AQCV1qSwZ/PmzcXInF69ejU4nu9v2LDhK32OKVOmpD59+jQIjPY1a9asItWq2/KRQwAAbZH6CQAIvRrXHXfckebNm5cWLFhQTO68P9OmTSuGL9Vt69evr2QzAQBKQ/0EADRVp6ac3KNHj9SxY8dUW1vb4Hi+37t37y+99q677iqKlRdeeCGdcsopX3puly5dig0AoK1TPwEApR7Z07lz5zRkyJC0ZMmS+mP5BM35/vDhw/d73Z133pluueWWtHjx4jR06NCv12IAgDZE/QQAlHpkTy5fdn3ChAlFaDNs2LA0e/bstGPHjjRx4sTi4/kKW3379i3m3cn99re/TTNmzEiPPfZY6t+/f/3cPt/4xjeKDQAgOvUTAFDqsGfs2LFp06ZNRYCTBzeDBw8uRuzUTdq8bt26YoWuOvfff3+xCsWPf/zjBp9n5syZ6cYbb2yOrwEAoNTUTwBAJVVlWZaldrjmPADQ/Nyzy0NfAED7vWdXdDUuAAAAAFqWsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAACgvYc9NTU1qX///qlr166puro6LV++/EvPf+KJJ9KAAQOK808++eS0aNGiA20vAECbpH4CAEob9syfPz9Nnjw5zZw5M61cuTINGjQojRo1Km3cuLHR81955ZV04YUXpksuuS
S99tpracyYMcX2xhtvNEf7AQBKT/0EAFRSVZZlWVMuyEfynH766enee+8t9vfs2ZP69euXrr766jR16tQvnD927Ni0Y8eO9Oyzz9Yf+/73v58GDx6c5syZ85Vec9u2bal79+5p69atqVu3bk1pLgBQQe7ZjVM/AQCVrJ86NeXkXbt2pRUrVqRp06bVH+vQoUMaOXJkWrZsWaPX5MfzkUB7y0cCPf300/t9nZ07dxZbnfwLrnsDAIDyqrtXN/F3SaGpnwCAStdPTQp7Nm/enHbv3p169erV4Hi+v3r16kav2bBhQ6Pn58f3Z9asWemmm276wvF8BBEAUH7//Oc/i99QoX4CACpfPzUp7KmUfOTQ3qOBtmzZko4++ui0bt06hWMrp4154LZ+/XqP07UyfVEe+qIc9EN55KNxjzrqqHTEEUe0dlPaHfVTOfn+VB76ohz0Q3noi9j1U5PCnh49eqSOHTum2traBsfz/d69ezd6TX68KefnunTpUmz7yhMuc/a0vrwP9EM56Ivy0BfloB/KI3/Mm/9QP5Hz/ak89EU56Ify0Bcx66cmfabOnTunIUOGpCVLltQfyydozveHDx/e6DX58b3Pzz3//PP7PR8AIBL1EwBQaU1+jCt/vGrChAlp6NChadiwYWn27NnFalsTJ04sPj5+/PjUt2/fYt6d3DXXXJPOPvvsdPfdd6fzzz8/zZs3L7366qvpgQceaP6vBgCghNRPAECpw558KfVNmzalGTNmFJMs50uoL168uH4S5nxenb2HHo0YMSI99thj6frrr0/XXXdd+u53v1usxDVw4MCv/Jr5I10zZ85s9NEuKkc/lIe+KA99UQ76oTz0RePUT+2XfxPloS/KQT+Uh76I3RdVmbVRAQAAAMIweyIAAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAilN2FNTU5P69++funbtmqqrq9Py5cu/9PwnnngiDRgwoDj/5JNPTosWLapYWyNrSj/MnTs3nXXWWenwww8vtpEjR/7PfqNl+mJv8+bNS1VVVWnMmDHe7lbqiy1btqSrrroqHXnkkcWM+scff7zvUa3QD7Nnz04nnHBCOvjgg1O/fv3SpEmT0qefftocTWm3XnrppTR69OjUp0+f4vtMvrrm/7J06dJ02mmnFf8WjjvuuPTII49UpK3tgdqpPNRP5aF+Kge1U3mon9px/ZSVwLx587LOnTtnDz/8cPa3v/0tu+yyy7LDDjssq62tbfT8v/zlL1nHjh2zO++8M/v73/+eXX/99dlBBx2Uvf766xVveyRN7YeLLrooq6mpyV577bXszTffzH72s59l3bt3z/7xj39UvO3tvS/qvPvuu1nfvn2zs846K/vRj35UsfZG1tS+2LlzZzZ06NDsvPPOy15++eWiT5YuXZqtWrWq4m1vz/3whz/8IevSpUvxZ94Hzz33XHbkkUdmkyZNqnjbI1m0aFE2ffr07KmnnsryEmLBggVfev7atWuzQw45JJs8eXJxv/7d735X3L8XL15csTZHpXYqD/VTeaifykHtVB7qp/ZdP5Ui7Bk2bFh21VVX1e/v3r0769OnTzZr1qxGz//JT36SnX/++Q2OVVdXZz//+c9bvK2RNbUf9vX5559nhx56aPboo4+2YCvbhwPpi/z9HzFiRPbggw9mEyZMEPa0Ul/cf//92THHHJPt2rWruZrAAfRDfu4PfvCDBsfyG+YZZ5zh/WwmX6VYufbaa7Pvfe97DY6NHTs2GzVqlH74mtRO5aF+Kg/1UzmoncpD/dS+66dWf4xr165dacWKFcUjQHU6dOhQ7C9btqzRa/Lje5+fGzVq1H7Pp2X6YV8ff/xx+uyzz9IRRxzhLW+Fvrj55ptTz5490yWXXOL9b8W+eOaZZ9Lw4cOLx7h69eqVBg4cmG6//fa0e/du/VLBfhgxYkRxTd2jXmvXri0epTvvvPP0QwW5X7cMtVN5qJ/KQ/1UDmqn8lA/tV3NVT91Sq1s8+
bNxQ9B+Q9Fe8v3V69e3eg1GzZsaPT8/DiV64d9TZkypXgOcd//MWn5vnj55ZfTQw89lFatWuXtbuW+yEOFP//5z+niiy8uwoW33347XXnllUUQOnPmTP1ToX646KKLiuvOPPPMfARr+vzzz9MVV1yRrrvuOn1QQfu7X2/bti198sknxXxKNJ3aqTzUT+WhfioHtVN5qJ/aruaqn1p9ZA8x3HHHHcXEwAsWLCgmT6Vytm/fnsaNG1dMmN2jRw9vfSvbs2dPMcLqgQceSEOGDEljx45N06dPT3PmzGntprUr+aR2+Yiq++67L61cuTI99dRTaeHChemWW25p7aYB1FM/tR71U3moncpD/RRLq4/syX847dixY6qtrW1wPN/v3bt3o9fkx5tyPi3TD3Xuuuuuolh54YUX0imnnOLtrnBfvPPOO+m9994rZnjf+6aZ69SpU1qzZk069thj9UsF+iKXr8B10EEHFdfVOfHEE4uEPh9O27lzZ31RgX644YYbihD00ksvLfbzVRt37NiRLr/88iJ8yx8Do+Xt737drVs3o3q+BrVTeaifykP9VA5qp/JQP7VdzVU/tXq1m//gk//2e8mSJQ1+UM3383kvGpMf3/v83PPPP7/f82mZfsjdeeedxW/KFy9enIYOHeqtboW+GDBgQHr99deLR7jqtgsuuCCdc845xd/zJaepTF/kzjjjjOLRrbrALffWW28VIZCgp3L9kM8htm+gUxfA/WduPCrB/bplqJ3KQ/1UHuqnclA7lYf6qe1qtvopK8mScPkSuY888kixtNjll19eLKm7YcOG4uPjxo3Lpk6d2mDp9U6dOmV33XVXseT3zJkzLb3eCv1wxx13FEshP/nkk9mHH35Yv23fvr05mtOuNbUv9mU1rtbri3Xr1hWr0v3yl7/M1qxZkz377LNZz549s1tvvbUZW9X+NLUf8vtC3g9//OMfi+Ur//SnP2XHHntssZojBy7//v7aa68VW15C3HPPPcXf33///eLjeR/kfbHv0qG/+c1vivt1TU2NpdebidqpPNRP5aF+Kge1U3mon9p3/VSKsCeXrx1/1FFHFeFBvkTcX//61/qPnX322cUPr3t7/PHHs+OPP744P1+WbOHCha3Q6nia0g9HH3108T/rvlv+QxaV7Yt9CXtaty9eeeWVrLq6uggn8mXYb7vttuzzzz9v5la1P03ph88++yy78cYbi4Cna9euWb9+/bIrr7wy+9e//tVKrY/hxRdfbPT7ft17n/+Z98W+1wwePLjot/zfw+9///tWan08aqfyUD+Vh/qpHNRO5aF+ar/1U1X+n+YddAQAAABAa2n1OXsAAAAAaD7CHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAFIc/w/JLyUiKOJTlAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 1400x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_evaluation_results(\n",
" atari_results, win_rates_vs_random, save_path=\"plots/evaluation_results.png\",\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "b9cb30f5",
"metadata": {},
"source": [
"## Championship Match: SARSA vs Q-Learning\n",
"\n",
"The final showdown pits the two trained agents against each other: **SARSA** (on-policy) versus **Q-Learning** (off-policy).\n",
"\n",
"This match directly compares the two TD learning strategies:\n",
"- **SARSA** updates its weights following the policy it actually executes (on-policy): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
"- **Q-Learning** learns the optimal policy independently of exploration (off-policy): $\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
"\n",
"The championship is played as a full round-robin: every ordered pair of agents meets for 20 episodes, so each matchup is contested over **2 × 20 episodes** with swapped positions to ensure fairness."
]
},
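{
"cell_type": "code",
"execution_count": null,
"id": "td-delta-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only (toy feature vector and weights, not the notebook's agents):\n",
"# the two TD errors defined above differ only in how the next action-value is chosen.\n",
"phi = np.ones(4) / 2  # toy feature vector phi(s, a)\n",
"w = np.zeros((3, 4))  # one weight row per action: q_hat(s, a) = w[a] @ phi\n",
"gamma, alpha, r = 0.99, 0.1, 1.0\n",
"q_next = w @ phi  # action values in the next state (same toy features)\n",
"a_next = 0  # action actually selected by the behaviour policy\n",
"delta_sarsa = r + gamma * q_next[a_next] - w[0] @ phi  # on-policy (SARSA) target\n",
"delta_q = r + gamma * q_next.max() - w[0] @ phi  # off-policy (greedy) target\n",
"w[0] += alpha * delta_sarsa * phi  # semi-gradient update for the taken action\n"
]
},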
{
"cell_type": "code",
"execution_count": null,
"id": "4031bde5",
"metadata": {},
"outputs": [],
"source": [
"def run_all_matchups(\n",
" agents: dict[str, Agent],\n",
" episodes_per_leg: int = 20,\n",
") -> np.ndarray:\n",
" \"\"\"Run a full round-robin tournament and return an n*n win matrix.\n",
"\n",
" matrix[i][j] = wins of agent_i (as first_0) vs agent_j (as second_0).\n",
" Diagonal entries are NaN (no self-play).\n",
" \"\"\"\n",
" env = create_tournament_env()\n",
" names = list(agents.keys())\n",
" n = len(names)\n",
" matrix = np.full((n, n), np.nan)\n",
"\n",
" for i in range(n):\n",
" for j in range(n):\n",
" if i == j:\n",
" continue\n",
" result = run_match(\n",
" env, agents[names[i]], agents[names[j]], episodes=episodes_per_leg,\n",
" )\n",
" matrix[i, j] = result[\"first\"]\n",
"\n",
" env.close()\n",
" return matrix\n",
"\n",
"\n",
"def plot_matchup_matrix(\n",
" matrix: np.ndarray,\n",
" agent_names: list[str],\n",
" episodes_per_leg: int = 20,\n",
" save_path: str = \"plots/championship_matrix.png\",\n",
") -> None:\n",
" \"\"\"Plot the n*n matchup matrix as an annotated heatmap.\"\"\"\n",
" n = len(agent_names)\n",
" pct_matrix = matrix / episodes_per_leg * 100\n",
"\n",
" fig, ax = plt.subplots(figsize=(8, 7))\n",
" masked = np.ma.array(pct_matrix, mask=np.isnan(pct_matrix))\n",
" cmap = plt.cm.RdYlGn.copy()\n",
" cmap.set_bad(color=\"#f0f0f0\")\n",
"\n",
" im = ax.imshow(masked, cmap=cmap, vmin=0, vmax=100, aspect=\"equal\")\n",
"\n",
" ax.set_xticks(np.arange(n))\n",
" ax.set_yticks(np.arange(n))\n",
" ax.set_xticklabels(agent_names, rotation=45, ha=\"right\", fontsize=10)\n",
" ax.set_yticklabels(agent_names, fontsize=10)\n",
" ax.set_xlabel(\"Opponent (second_0)\", fontsize=11)\n",
" ax.set_ylabel(\"Agent (first_0)\", fontsize=11)\n",
" ax.set_title(\"Championship Matchup Matrix\", fontsize=14, fontweight=\"bold\", pad=12)\n",
"\n",
" for i in range(n):\n",
" for j in range(n):\n",
" if i == j:\n",
" ax.text(j, i, \"—\", ha=\"center\", va=\"center\", color=\"#aaa\", fontsize=14)\n",
" else:\n",
" wins = int(matrix[i, j])\n",
" pct = pct_matrix[i, j]\n",
" txt_color = \"white\" if pct > 70 or pct < 30 else \"black\"\n",
" ax.text(j, i, f\"{wins}\\n({pct:.0f} %)\",\n",
" ha=\"center\", va=\"center\", color=txt_color,\n",
" fontsize=9, fontweight=\"bold\", linespacing=1.4)\n",
"\n",
" cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n",
" cbar.set_label(f\"Win Rate (%) — {episodes_per_leg} eps per matchup\", fontsize=10)\n",
" cbar.set_ticks([0, 25, 50, 75, 100])\n",
" cbar.set_ticklabels([\"0 %\", \"25 %\", \"50 %\", \"75 %\", \"100 %\"])\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b07b403c",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mTypeError\u001b[39m                                 Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/lambda_wrappers/observation_lambda.py:68\u001b[39m, in \u001b[36maec_observation_lambda._modify_observation\u001b[39m\u001b[34m(self, agent, observation)\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m68\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mchange_observation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mold_obs_space\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n",
"\u001b[31mTypeError\u001b[39m: basic_obs_wrapper.<locals>.change_obs() takes 2 positional arguments but 3 were given",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[41]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m matchup_matrix = \u001b[43mrun_all_matchups\u001b[49m\u001b[43m(\u001b[49m\u001b[43magents\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes_per_leg\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 19\u001b[39m, in \u001b[36mrun_all_matchups\u001b[39m\u001b[34m(agents, episodes_per_leg)\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i == j:\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m result = \u001b[43mrun_match\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 20\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magents\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnames\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magents\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnames\u001b[49m\u001b[43m[\u001b[49m\u001b[43mj\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepisodes_per_leg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 22\u001b[39m matrix[i, j] = result[\u001b[33m\"\u001b[39m\u001b[33mfirst\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 24\u001b[39m env.close()\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[28]\u001b[39m\u001b[32m, line 37\u001b[39m, in \u001b[36mrun_match\u001b[39m\u001b[34m(env, agent_first, agent_second, episodes, max_steps)\u001b[39m\n\u001b[32m 34\u001b[39m current_agent = agent_first \u001b[38;5;28;01mif\u001b[39;00m agent_id == \u001b[33m\"\u001b[39m\u001b[33mfirst_0\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m agent_second\n\u001b[32m 35\u001b[39m action = current_agent.get_action(np.asarray(obs), epsilon=\u001b[32m0.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m37\u001b[39m \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m step_idx + \u001b[32m1\u001b[39m >= max_steps:\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/base_aec_wrapper.py:46\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 43\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m.terminations[agent] \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.truncations[agent]):\n\u001b[32m 44\u001b[39m action = \u001b[38;5;28mself\u001b[39m._modify_action(agent, action)\n\u001b[32m---> \u001b[39m\u001b[32m46\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 48\u001b[39m \u001b[38;5;28mself\u001b[39m._update_step(\u001b[38;5;28mself\u001b[39m.agent_selection)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:70\u001b[39m, in \u001b[36mOrderEnforcingWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 69\u001b[39m \u001b[38;5;28mself\u001b[39m._has_updated = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:47\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 46\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, action: ActionType) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m47\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/generic_wrappers/utils/shared_wrapper_util.py:69\u001b[39m, in \u001b[36mshared_wrapper_aec.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 66\u001b[39m \u001b[38;5;28msuper\u001b[39m().step(action)\n\u001b[32m 67\u001b[39m \u001b[38;5;28mself\u001b[39m.add_modifiers(\u001b[38;5;28mself\u001b[39m.agents)\n\u001b[32m 68\u001b[39m \u001b[38;5;28mself\u001b[39m.modifiers[\u001b[38;5;28mself\u001b[39m.agent_selection].modify_obs(\n\u001b[32m---> \u001b[39m\u001b[32m69\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43magent_selection\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 70\u001b[39m )\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:75\u001b[39m, in \u001b[36mOrderEnforcingWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 73\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 74\u001b[39m EnvLogger.error_observe_before_reset()\n\u001b[32m---> \u001b[39m\u001b[32m75\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:41\u001b[39m, in \u001b[36mBaseWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mobserve\u001b[39m(\u001b[38;5;28mself\u001b[39m, agent: AgentID) -> ObsType | \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m41\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/base_aec_wrapper.py:38\u001b[39m, in \u001b[36mBaseWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 34\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mobserve\u001b[39m(\u001b[38;5;28mself\u001b[39m, agent):\n\u001b[32m 35\u001b[39m obs = \u001b[38;5;28msuper\u001b[39m().observe(\n\u001b[32m 36\u001b[39m agent\n\u001b[32m 37\u001b[39m ) \u001b[38;5;66;03m# problem is in this line, the obs is sometimes a different size from the obs space\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m38\u001b[39m observation = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_modify_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m observation\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/lambda_wrappers/observation_lambda.py:70\u001b[39m, in \u001b[36maec_observation_lambda._modify_observation\u001b[39m\u001b[34m(self, agent, observation)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.change_observation_fn(observation, old_obs_space, agent)\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mchange_observation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mold_obs_space\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/generic_wrappers/basic_wrappers.py:19\u001b[39m, in \u001b[36mbasic_obs_wrapper.<locals>.change_obs\u001b[39m\u001b[34m(obs, obs_space)\u001b[39m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_obs\u001b[39m(obs, obs_space):\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodule\u001b[49m\u001b[43m.\u001b[49m\u001b[43mchange_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobs_space\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/basic_transforms/resize.py:21\u001b[39m, in \u001b[36mchange_observation\u001b[39m\u001b[34m(obs, obs_space, resize)\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_obs_space\u001b[39m(obs_space, param):\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m convert_box(\u001b[38;5;28;01mlambda\u001b[39;00m obs: change_observation(obs, obs_space, param), obs_space)\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_observation\u001b[39m(obs, obs_space, resize):\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtinyscaler\u001b[39;00m\n\u001b[32m 24\u001b[39m xsize, ysize, linear_interp = resize\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"matchup_matrix = run_all_matchups(agents, episodes_per_leg=20)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9668071a",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'matchup_matrix' is not defined",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[42]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m plot_matchup_matrix(\u001b[43mmatchup_matrix\u001b[49m, episodes_per_leg=EPISODES_PER_LEG)\n",
"\u001b[31mNameError\u001b[39m: name 'matchup_matrix' is not defined"
]
}
],
"source": [
"plot_matchup_matrix(matchup_matrix, list(agents.keys()), episodes_per_leg=20)\n"
]
},
{
"cell_type": "markdown",
"id": "32c37e5d",
"metadata": {},
"source": [
"# Conclusion\n",
"\n",
"This project implemented and compared five agents on Atari Tennis (four Reinforcement Learning algorithms plus a random baseline):\n",
"\n",
"| Agent | Type | Policy | Update Rule |\n",
"|-------|------|--------|-------------|\n",
"| **Random** | Baseline | Uniform random | None |\n",
"| **SARSA** | TD(0), on-policy | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
"| **Q-Learning** | TD(0), off-policy | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
"| **Monte Carlo** | First-visit MC | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (G_t - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
"| **DQN** | Deep Q-Network | ε-greedy | Neural network (MLP 256→256) with experience replay and target network |\n",
"\n",
"**Architecture**:\n",
"- **Linear agents** (SARSA, Q-Learning, Monte Carlo): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$ with $\\phi(s) \\in \\mathbb{R}^{28\\,224}$ (4 grayscale 84×84 frames, normalized)\n",
"- **DQN**: MLP network (28,224 → 256 → 256 → 18) trained with Adam optimizer, Huber loss, and periodic target network sync\n",
"\n",
"**Methodology**:\n",
"1. **Pre-training** each agent individually against Atari's built-in AI (5,000 episodes, ε decaying from 1.0 to 0.05)\n",
"2. **Evaluation vs Random** to validate learning (expected win rate > 50%)\n",
"3. **Head-to-head tournament** between all agent pairs via PettingZoo (2 legs × 20 episodes, sides swapped between legs)\n",
"\n",
"> ⚠️ **Known Issue**: Monte Carlo and DQN agent checkpoints have loading issues. Their code is preserved here for reference."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "studies (3.13.9)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}