mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-03-16 07:10:13 +01:00
1459 lines
58 KiB
Plaintext
1459 lines
58 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "fef45687",
|
||
"metadata": {},
|
||
"source": [
|
||
"# RL Project: Atari Tennis Tournament (NumPy Agents)\n",
|
||
"\n",
|
||
"This notebook implements four Reinforcement Learning algorithms **without PyTorch** to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
|
||
"\n",
|
||
"1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
|
||
"2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
|
||
"3. **DQN** — Deep Q-Network with pure numpy MLP, experience replay and target network (inspired by Lab 6A + classic DQN)\n",
|
||
"4. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
|
||
"\n",
|
||
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "b50d7174",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import pickle\n",
|
||
"from pathlib import Path\n",
|
||
"\n",
|
||
"import ale_py # noqa: F401 — registers ALE environments\n",
|
||
"import gymnasium as gym\n",
|
||
"from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n",
|
||
"from tqdm.auto import tqdm\n",
|
||
"\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"import seaborn as sns\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "ff3486a4",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
|
||
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "ec691487",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Utility Functions\n",
|
||
"\n",
|
||
"## Observation Normalization\n",
|
||
"\n",
|
||
"The Tennis environment produces image observations of shape `(4, 84, 84)` after preprocessing (grayscale + resize + frame stack).\n",
|
||
"We flatten them into 1D `float64` vectors and divide by 255, as in Lab 7 (continuous feature normalization).\n",
|
||
"\n",
|
||
"## ε-greedy Policy\n",
|
||
"\n",
|
||
"Follows the pattern from Lab 5B (`epsilon_greedy`) and Lab 7 (`epsilon_greedy_action`):\n",
|
||
"- With probability ε: random action (exploration)\n",
|
||
"- With probability 1−ε: action maximizing $\\hat{q}(s, a)$ with uniform tie-breaking (`np.flatnonzero`)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "be85c130",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def normalize_obs(observation: np.ndarray) -> np.ndarray:\n",
|
||
" \"\"\"Flatten and normalize an observation to a 1D float64 vector.\n",
|
||
"\n",
|
||
" Replicates the /255.0 normalization used in all agents from the original project.\n",
|
||
" For image observations of shape (4, 84, 84), this produces a vector of length 28_224.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" observation: Raw observation array from the environment.\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" 1D numpy array of dtype float64, values in [0, 1].\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" return observation.flatten().astype(np.float64) / 255.0\n",
|
||
"\n",
|
||
"\n",
|
||
"def epsilon_greedy(\n",
|
||
" q_values: np.ndarray,\n",
|
||
" epsilon: float,\n",
|
||
" rng: np.random.Generator,\n",
|
||
") -> int:\n",
|
||
" \"\"\"Select an action using an ε-greedy policy with fair tie-breaking.\n",
|
||
"\n",
|
||
" Follows the same logic as Lab 5B epsilon_greedy and Lab 7 epsilon_greedy_action:\n",
|
||
" - With probability epsilon: choose a random action (exploration).\n",
|
||
" - With probability 1-epsilon: choose the action with highest Q-value (exploitation).\n",
|
||
" - If multiple actions share the maximum Q-value, break ties uniformly at random.\n",
|
||
"\n",
|
||
" Handles edge cases: empty q_values, NaN/Inf values.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" q_values: Array of Q-values for each action, shape (n_actions,).\n",
|
||
" epsilon: Exploration probability in [0, 1].\n",
|
||
" rng: NumPy random number generator.\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" Selected action index.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" q_values = np.asarray(q_values, dtype=np.float64).reshape(-1)\n",
|
||
"\n",
|
||
" if q_values.size == 0:\n",
|
||
" msg = \"q_values is empty.\"\n",
|
||
" raise ValueError(msg)\n",
|
||
"\n",
|
||
" if rng.random() < epsilon:\n",
|
||
" return int(rng.integers(0, q_values.size))\n",
|
||
"\n",
|
||
" # Handle NaN/Inf values safely\n",
|
||
" finite_mask = np.isfinite(q_values)\n",
|
||
" if not np.any(finite_mask):\n",
|
||
" return int(rng.integers(0, q_values.size))\n",
|
||
"\n",
|
||
" safe_q = q_values.copy()\n",
|
||
" safe_q[~finite_mask] = -np.inf\n",
|
||
" max_val = np.max(safe_q)\n",
|
||
" best = np.flatnonzero(safe_q == max_val)\n",
|
||
"\n",
|
||
" if best.size == 0:\n",
|
||
" return int(rng.integers(0, q_values.size))\n",
|
||
"\n",
|
||
" return int(rng.choice(best))\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bb53da28",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Agent Definitions\n",
|
||
"\n",
|
||
"## Base Class `Agent`\n",
|
||
"\n",
|
||
"Common interface shared by all agents, with identical signatures: `get_action`, `update`, `save`, and `load`.\n",
|
||
"Serialization uses `pickle` (which handles numpy arrays natively); only load checkpoint files you trust, since unpickling can execute arbitrary code."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"id": "ded9b1fb",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class Agent:\n",
|
||
" \"\"\"Base class for reinforcement learning agents.\n",
|
||
"\n",
|
||
" All agents share this interface so they are compatible with the tournament system.\n",
|
||
" \"\"\"\n",
|
||
"\n",
|
||
" def __init__(self, seed: int, action_space: int) -> None:\n",
|
||
" \"\"\"Initialize the agent with its action space and a reproducible RNG.\"\"\"\n",
|
||
" self.action_space = action_space\n",
|
||
" self.rng = np.random.default_rng(seed=seed)\n",
|
||
"\n",
|
||
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
|
||
" \"\"\"Select an action from the current observation.\"\"\"\n",
|
||
" raise NotImplementedError\n",
|
||
"\n",
|
||
" def update(\n",
|
||
" self,\n",
|
||
" state: np.ndarray,\n",
|
||
" action: int,\n",
|
||
" reward: float,\n",
|
||
" next_state: np.ndarray,\n",
|
||
" done: bool,\n",
|
||
" next_action: int | None = None,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Update agent parameters from one transition.\"\"\"\n",
|
||
"\n",
|
||
" def save(self, filename: str) -> None:\n",
|
||
" \"\"\"Save the agent state to disk using pickle.\"\"\"\n",
|
||
" with Path(filename).open(\"wb\") as f:\n",
|
||
" pickle.dump(self.__dict__, f)\n",
|
||
"\n",
|
||
" def load(self, filename: str) -> None:\n",
|
||
" \"\"\"Load the agent state from disk.\"\"\"\n",
|
||
" with Path(filename).open(\"rb\") as f:\n",
|
||
" self.__dict__.update(pickle.load(f)) # noqa: S301\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8a4eae79",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Random Agent (baseline)\n",
|
||
"\n",
|
||
"Serves as a reference to evaluate the performance of learning agents."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"id": "78bdc9d2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class RandomAgent(Agent):\n",
|
||
" \"\"\"A simple agent that selects actions uniformly at random (baseline).\"\"\"\n",
|
||
"\n",
|
||
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
|
||
" \"\"\"Select a random action, ignoring the observation and epsilon.\"\"\"\n",
|
||
" _ = observation, epsilon\n",
|
||
" return int(self.rng.integers(0, self.action_space))\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "5f679032",
|
||
"metadata": {},
|
||
"source": [
|
||
"## SARSA Agent — Linear Approximation (Semi-gradient)\n",
|
||
"\n",
|
||
"This agent combines:\n",
|
||
"- **Linear approximation** from Lab 7 (`SarsaAgent`): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$\n",
|
||
"- **On-policy SARSA update** from Lab 5B (`train_sarsa`): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
|
||
"\n",
|
||
"The semi-gradient update rule is:\n",
|
||
"$$W_a \\leftarrow W_a + \\alpha \\cdot \\delta \\cdot \\phi(s)$$\n",
|
||
"\n",
|
||
"where $\\phi(s)$ is the normalized observation vector (analogous to tile coding features in Lab 7, but in dense form)."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"id": "c124ed9a",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class SarsaAgent(Agent):\n",
|
||
" \"\"\"Semi-gradient SARSA agent with linear function approximation.\n",
|
||
"\n",
|
||
" Inspired by:\n",
|
||
" - Lab 7 SarsaAgent: linear q(s,a) = W_a . phi(s), semi-gradient update\n",
|
||
" - Lab 5B train_sarsa: on-policy TD target using Q(s', a')\n",
|
||
"\n",
|
||
" The weight matrix W has shape (n_actions, n_features).\n",
|
||
" For a given state s, q(s, a) = W[a] @ phi(s) is the dot product\n",
|
||
" of the action's weight row with the normalized observation.\n",
|
||
" \"\"\"\n",
|
||
"\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" n_features: int,\n",
|
||
" n_actions: int,\n",
|
||
" alpha: float = 0.001,\n",
|
||
" gamma: float = 0.99,\n",
|
||
" seed: int = 42,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Initialize SARSA agent with linear weights.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" n_features: Dimension of the feature vector phi(s).\n",
|
||
" n_actions: Number of discrete actions.\n",
|
||
" alpha: Learning rate (kept small for high-dim features).\n",
|
||
" gamma: Discount factor.\n",
|
||
" seed: RNG seed for reproducibility.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" super().__init__(seed, n_actions)\n",
|
||
" self.n_features = n_features\n",
|
||
" self.alpha = alpha\n",
|
||
" self.gamma = gamma\n",
|
||
" # Weight matrix: one row per action, analogous to Lab 7's self.w\n",
|
||
" # but organized as (n_actions, n_features) for dense features.\n",
|
||
" self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
|
||
"\n",
|
||
" def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
|
||
" \"\"\"Compute Q-values for all actions given feature vector phi(s).\n",
|
||
"\n",
|
||
" Equivalent to Lab 7's self.q(s, a) = self.w[idx].sum()\n",
|
||
" but using dense linear approximation: q(s, a) = W[a] @ phi.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" phi: Normalized feature vector, shape (n_features,).\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" Array of Q-values, shape (n_actions,).\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" return self.W @ phi # shape (n_actions,)\n",
|
||
"\n",
|
||
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
|
||
" \"\"\"Select action using ε-greedy policy over linear Q-values.\n",
|
||
"\n",
|
||
" Same pattern as Lab 7 SarsaAgent.eps_greedy:\n",
|
||
" compute q-values for all actions, then apply epsilon_greedy.\n",
|
||
" \"\"\"\n",
|
||
" phi = normalize_obs(observation)\n",
|
||
" q_vals = self._q_values(phi)\n",
|
||
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
|
||
"\n",
|
||
" def update(\n",
|
||
" self,\n",
|
||
" state: np.ndarray,\n",
|
||
" action: int,\n",
|
||
" reward: float,\n",
|
||
" next_state: np.ndarray,\n",
|
||
" done: bool,\n",
|
||
" next_action: int | None = None,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Perform one semi-gradient SARSA update.\n",
|
||
"\n",
|
||
" Follows the SARSA update from Lab 5B train_sarsa:\n",
|
||
" td_target = r + gamma * Q(s', a') * (0 if done else 1)\n",
|
||
" Q(s, a) += alpha * (td_target - Q(s, a))\n",
|
||
"\n",
|
||
" In continuous form with linear approximation (Lab 7 SarsaAgent.update):\n",
|
||
" delta = target - q(s, a)\n",
|
||
" W[a] += alpha * delta * phi(s)\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" state: Current observation.\n",
|
||
" action: Action taken.\n",
|
||
" reward: Reward received.\n",
|
||
" next_state: Next observation.\n",
|
||
" done: Whether the episode ended.\n",
|
||
" next_action: Action chosen in next state (required for SARSA).\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
|
||
" q_sa = float(self.W[action] @ phi) # current estimate q(s, a)\n",
|
||
" if not np.isfinite(q_sa):\n",
|
||
" q_sa = 0.0\n",
|
||
"\n",
|
||
" if done:\n",
|
||
" # Terminal: no future value (Lab 5B: gamma * Q[s2, a2] * 0)\n",
|
||
" target = reward\n",
|
||
" else:\n",
|
||
" # On-policy: use q(s', a') where a' is the actual next action\n",
|
||
" # This is the key SARSA property (Lab 5B)\n",
|
||
" phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
|
||
" if next_action is None:\n",
|
||
" next_action = 0 # fallback, should not happen in practice\n",
|
||
" q_sp_ap = float(self.W[next_action] @ phi_next)\n",
|
||
" if not np.isfinite(q_sp_ap):\n",
|
||
" q_sp_ap = 0.0\n",
|
||
" target = float(reward) + self.gamma * q_sp_ap\n",
|
||
"\n",
|
||
" # Semi-gradient update: W[a] += alpha * delta * phi(s)\n",
|
||
" # Analogous to Lab 7: self.w[idx] += self.alpha * delta\n",
|
||
" if not np.isfinite(target):\n",
|
||
" return\n",
|
||
"\n",
|
||
" delta = float(target - q_sa)\n",
|
||
" if not np.isfinite(delta):\n",
|
||
" return\n",
|
||
"\n",
|
||
" td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
|
||
" self.W[action] += self.alpha * td_step * phi\n",
|
||
" self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d4e18536",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Q-Learning Agent — Linear Approximation (Off-policy)\n",
|
||
"\n",
|
||
"Same architecture as SARSA but with the **off-policy update** from Lab 5B (`train_q_learning`):\n",
|
||
"\n",
|
||
"$$\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$$\n",
|
||
"\n",
|
||
"The key difference from SARSA: we use $\\max_{a'} \\hat{q}(s', a')$ instead of $\\hat{q}(s', a')$, where $a'$ is the action actually chosen in $s'$. This lets the agent learn the greedy (target) policy independently of the ε-greedy behavior policy it uses to explore."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"id": "f5b5b9ea",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class QLearningAgent(Agent):\n",
|
||
" \"\"\"Q-Learning agent with linear function approximation (off-policy).\n",
|
||
"\n",
|
||
" Inspired by:\n",
|
||
" - Lab 5B train_q_learning: off-policy TD target using max_a' Q(s', a')\n",
|
||
" - Lab 7 SarsaAgent: linear approximation q(s,a) = W[a] @ phi(s)\n",
|
||
"\n",
|
||
" The only difference from SarsaAgent is the TD target:\n",
|
||
" SARSA uses Q(s', a') (on-policy), Q-Learning uses max_a' Q(s', a') (off-policy).\n",
|
||
" \"\"\"\n",
|
||
"\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" n_features: int,\n",
|
||
" n_actions: int,\n",
|
||
" alpha: float = 0.001,\n",
|
||
" gamma: float = 0.99,\n",
|
||
" seed: int = 42,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Initialize Q-Learning agent with linear weights.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" n_features: Dimension of the feature vector phi(s).\n",
|
||
" n_actions: Number of discrete actions.\n",
|
||
" alpha: Learning rate.\n",
|
||
" gamma: Discount factor.\n",
|
||
" seed: RNG seed.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" super().__init__(seed, n_actions)\n",
|
||
" self.n_features = n_features\n",
|
||
" self.alpha = alpha\n",
|
||
" self.gamma = gamma\n",
|
||
" self.W = np.zeros((n_actions, n_features), dtype=np.float64)\n",
|
||
"\n",
|
||
" def _q_values(self, phi: np.ndarray) -> np.ndarray:\n",
|
||
" \"\"\"Compute Q-values for all actions: q(s, a) = W[a] @ phi for each a.\"\"\"\n",
|
||
" return self.W @ phi\n",
|
||
"\n",
|
||
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
|
||
" \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n",
|
||
" phi = normalize_obs(observation)\n",
|
||
" q_vals = self._q_values(phi)\n",
|
||
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
|
||
"\n",
|
||
" def update(\n",
|
||
" self,\n",
|
||
" state: np.ndarray,\n",
|
||
" action: int,\n",
|
||
" reward: float,\n",
|
||
" next_state: np.ndarray,\n",
|
||
" done: bool,\n",
|
||
" next_action: int | None = None,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Perform one Q-learning update.\n",
|
||
"\n",
|
||
" Follows Lab 5B train_q_learning:\n",
|
||
" td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
|
||
" Q[s, a] += alpha * (td_target - Q[s, a])\n",
|
||
"\n",
|
||
" In continuous form with linear approximation:\n",
|
||
" delta = target - q(s, a)\n",
|
||
" W[a] += alpha * delta * phi(s)\n",
|
||
" \"\"\"\n",
|
||
" _ = next_action # Q-learning is off-policy: next_action is not used\n",
|
||
" phi = np.nan_to_num(normalize_obs(state), nan=0.0, posinf=0.0, neginf=0.0)\n",
|
||
" q_sa = float(self.W[action] @ phi)\n",
|
||
" if not np.isfinite(q_sa):\n",
|
||
" q_sa = 0.0\n",
|
||
"\n",
|
||
" if done:\n",
|
||
" # Terminal state: no future value\n",
|
||
" # Lab 5B: gamma * np.max(Q[s2]) * (0 if terminated else 1)\n",
|
||
" target = reward\n",
|
||
" else:\n",
|
||
" # Off-policy: use max over all actions in next state\n",
|
||
" # This is the key Q-learning property (Lab 5B)\n",
|
||
" phi_next = np.nan_to_num(normalize_obs(next_state), nan=0.0, posinf=0.0, neginf=0.0)\n",
|
||
" q_next_all = self._q_values(phi_next) # q(s', a') for all a'\n",
|
||
" q_next_max = float(np.max(q_next_all))\n",
|
||
" if not np.isfinite(q_next_max):\n",
|
||
" q_next_max = 0.0\n",
|
||
" target = float(reward) + self.gamma * q_next_max\n",
|
||
"\n",
|
||
" if not np.isfinite(target):\n",
|
||
" return\n",
|
||
"\n",
|
||
" delta = float(target - q_sa)\n",
|
||
" if not np.isfinite(delta):\n",
|
||
" return\n",
|
||
"\n",
|
||
" td_step = float(np.clip(delta, -1_000.0, 1_000.0))\n",
|
||
" self.W[action] += self.alpha * td_step * phi\n",
|
||
" self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "f644e2ef",
|
||
"metadata": {},
|
||
"source": [
|
||
"## DQN Agent — NumPy MLP with Experience Replay and Target Network\n",
|
||
"\n",
|
||
"This agent implements the Deep Q-Network (DQN) **entirely in numpy**, without PyTorch.\n",
|
||
"\n",
|
||
"**Network architecture** (identical to the original DQNAgent structure):\n",
|
||
"$$\\text{Input}(n\\_features) \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
|
||
"\n",
|
||
"**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
|
||
"- **Experience Replay**: like the `self.model` dict in Dyna-Q (Lab 6A) that stores past transitions, we keep recent transitions in a fixed-size FIFO buffer (the oldest one is discarded when it is full) and sample random minibatches from it for updates.\n",
|
||
"- **Target Network**: periodically synchronized copy of the network, stabilizes learning.\n",
|
||
"- **Manual backpropagation**: MSE loss gradient backpropagated layer by layer."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"id": "ec090a9a",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def _relu(x: np.ndarray) -> np.ndarray:\n",
|
||
" \"\"\"ReLU activation: max(0, x).\"\"\"\n",
|
||
" return np.maximum(0, x)\n",
|
||
"\n",
|
||
"\n",
|
||
"def _relu_grad(x: np.ndarray) -> np.ndarray:\n",
|
||
" \"\"\"Compute the ReLU derivative: return 1 where x > 0, else 0.\"\"\"\n",
|
||
" return (x > 0).astype(np.float64)\n",
|
||
"\n",
|
||
"\n",
|
||
"def _init_layer(\n",
|
||
" fan_in: int, fan_out: int, rng: np.random.Generator,\n",
|
||
") -> tuple[np.ndarray, np.ndarray]:\n",
|
||
" \"\"\"Initialize a linear layer with He initialization.\n",
|
||
"\n",
|
||
" He init: W ~ N(0, sqrt(2/fan_in)) — standard for ReLU networks.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" fan_in: Number of input features.\n",
|
||
" fan_out: Number of output features.\n",
|
||
" rng: Random number generator.\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" Weight matrix W of shape (fan_in, fan_out) and bias vector b of shape (fan_out,).\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" scale = np.sqrt(2.0 / fan_in)\n",
|
||
" W = rng.normal(0, scale, size=(fan_in, fan_out)).astype(np.float64)\n",
|
||
" b = np.zeros(fan_out, dtype=np.float64)\n",
|
||
" return W, b\n",
|
||
"\n",
|
||
"\n",
|
||
"class DQNAgent(Agent):\n",
|
||
" \"\"\"Deep Q-Network agent with a 2-hidden-layer MLP implemented in pure numpy.\n",
|
||
"\n",
|
||
" Architecture: Input -> Linear(128) -> ReLU -> Linear(128) -> ReLU -> Linear(n_actions)\n",
|
||
" Matches the original PyTorch DQNAgent structure.\n",
|
||
"\n",
|
||
" Features:\n",
|
||
" - Experience replay buffer (like Dyna-Q model in Lab 6A: store past transitions, sample for updates)\n",
|
||
" - Target network (periodically synchronized copy for stable targets)\n",
|
||
" - Manual forward pass and backpropagation through the MLP\n",
|
||
" \"\"\"\n",
|
||
"\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" n_features: int,\n",
|
||
" n_actions: int,\n",
|
||
" lr: float = 0.0001,\n",
|
||
" gamma: float = 0.99,\n",
|
||
" buffer_size: int = 10000,\n",
|
||
" batch_size: int = 64,\n",
|
||
" target_update_freq: int = 1000,\n",
|
||
" seed: int = 42,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Initialize DQN agent.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" n_features: Input feature dimension.\n",
|
||
" n_actions: Number of discrete actions.\n",
|
||
" lr: Learning rate for gradient descent.\n",
|
||
" gamma: Discount factor.\n",
|
||
" buffer_size: Maximum size of the replay buffer.\n",
|
||
" batch_size: Minibatch size for updates.\n",
|
||
" target_update_freq: Steps between target network syncs.\n",
|
||
" seed: RNG seed.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" super().__init__(seed, n_actions)\n",
|
||
" self.n_features = n_features\n",
|
||
" self.lr = lr\n",
|
||
" self.gamma = gamma\n",
|
||
" self.buffer_size = buffer_size\n",
|
||
" self.batch_size = batch_size\n",
|
||
" self.target_update_freq = target_update_freq\n",
|
||
" self.update_step = 0\n",
|
||
"\n",
|
||
" # Initialize network parameters: 3 layers\n",
|
||
" # Layer 1: n_features -> 128\n",
|
||
" # Layer 2: 128 -> 128\n",
|
||
" # Layer 3: 128 -> n_actions\n",
|
||
" self.params = self._init_params(self.rng)\n",
|
||
"\n",
|
||
" # Target network: deep copy of params (like Dyna-Q's model copy concept)\n",
|
||
" self.target_params = {k: v.copy() for k, v in self.params.items()}\n",
|
||
"\n",
|
||
" # Experience replay buffer: list of (s, a, r, s', done) tuples\n",
|
||
" # Analogous to Dyna-Q's self.model dict + self.observed_sa in Lab 6A,\n",
|
||
" # but storing raw transitions for off-policy sampling\n",
|
||
" self.replay_buffer: list[tuple[np.ndarray, int, float, np.ndarray, bool]] = []\n",
|
||
"\n",
|
||
" def _init_params(self, rng: np.random.Generator) -> dict[str, np.ndarray]:\n",
|
||
" \"\"\"Initialize all MLP parameters with He initialization.\"\"\"\n",
|
||
" W1, b1 = _init_layer(self.n_features, 128, rng)\n",
|
||
" W2, b2 = _init_layer(128, 128, rng)\n",
|
||
" W3, b3 = _init_layer(128, self.action_space, rng)\n",
|
||
" return {\"W1\": W1, \"b1\": b1, \"W2\": W2, \"b2\": b2, \"W3\": W3, \"b3\": b3}\n",
|
||
"\n",
|
||
" def _forward(\n",
|
||
" self, x: np.ndarray, params: dict[str, np.ndarray],\n",
|
||
" ) -> tuple[np.ndarray, dict[str, np.ndarray]]:\n",
|
||
" \"\"\"Forward pass through the MLP. Returns output and cached activations for backprop.\n",
|
||
"\n",
|
||
" Architecture: x -> Linear -> ReLU -> Linear -> ReLU -> Linear -> output\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" x: Input array of shape (batch, n_features) or (n_features,).\n",
|
||
" params: Dictionary of network parameters.\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" output: Q-values of shape (batch, n_actions) or (n_actions,).\n",
|
||
" cache: Dictionary of intermediate values for backpropagation.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" # Layer 1\n",
|
||
" z1 = x @ params[\"W1\"] + params[\"b1\"]\n",
|
||
" h1 = _relu(z1)\n",
|
||
" # Layer 2\n",
|
||
" z2 = h1 @ params[\"W2\"] + params[\"b2\"]\n",
|
||
" h2 = _relu(z2)\n",
|
||
" # Layer 3 (output, no activation)\n",
|
||
" out = h2 @ params[\"W3\"] + params[\"b3\"]\n",
|
||
" cache = {\"x\": x, \"z1\": z1, \"h1\": h1, \"z2\": z2, \"h2\": h2}\n",
|
||
" return out, cache\n",
|
||
"\n",
|
||
" def _backward(\n",
|
||
" self,\n",
|
||
" d_out: np.ndarray,\n",
|
||
" cache: dict[str, np.ndarray],\n",
|
||
" params: dict[str, np.ndarray],\n",
|
||
" ) -> dict[str, np.ndarray]:\n",
|
||
" \"\"\"Backward pass: compute gradients of loss w.r.t. all parameters.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" d_out: Gradient of loss w.r.t. output, shape (batch, n_actions).\n",
|
||
" cache: Cached activations from forward pass.\n",
|
||
" params: Current network parameters.\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" Dictionary of gradients for each parameter.\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" batch = d_out.shape[0] if d_out.ndim > 1 else 1\n",
|
||
" if d_out.ndim == 1:\n",
|
||
" d_out = d_out.reshape(1, -1)\n",
|
||
"\n",
|
||
" x = cache[\"x\"] if cache[\"x\"].ndim > 1 else cache[\"x\"].reshape(1, -1)\n",
|
||
" h1 = cache[\"h1\"] if cache[\"h1\"].ndim > 1 else cache[\"h1\"].reshape(1, -1)\n",
|
||
" h2 = cache[\"h2\"] if cache[\"h2\"].ndim > 1 else cache[\"h2\"].reshape(1, -1)\n",
|
||
" z1 = cache[\"z1\"] if cache[\"z1\"].ndim > 1 else cache[\"z1\"].reshape(1, -1)\n",
|
||
" z2 = cache[\"z2\"] if cache[\"z2\"].ndim > 1 else cache[\"z2\"].reshape(1, -1)\n",
|
||
"\n",
|
||
" # Layer 3 gradients\n",
|
||
" dW3 = h2.T @ d_out / batch\n",
|
||
" db3 = d_out.mean(axis=0)\n",
|
||
" dh2 = d_out @ params[\"W3\"].T\n",
|
||
"\n",
|
||
" # Layer 2 gradients (through ReLU)\n",
|
||
" dz2 = dh2 * _relu_grad(z2)\n",
|
||
" dW2 = h1.T @ dz2 / batch\n",
|
||
" db2 = dz2.mean(axis=0)\n",
|
||
" dh1 = dz2 @ params[\"W2\"].T\n",
|
||
"\n",
|
||
" # Layer 1 gradients (through ReLU)\n",
|
||
" dz1 = dh1 * _relu_grad(z1)\n",
|
||
" dW1 = x.T @ dz1 / batch\n",
|
||
" db1 = dz1.mean(axis=0)\n",
|
||
"\n",
|
||
" return {\"W1\": dW1, \"b1\": db1, \"W2\": dW2, \"b2\": db2, \"W3\": dW3, \"b3\": db3}\n",
|
||
"\n",
|
||
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
|
||
" \"\"\"Select action using ε-greedy policy over MLP Q-values.\"\"\"\n",
|
||
" phi = normalize_obs(observation)\n",
|
||
" q_vals, _ = self._forward(phi, self.params)\n",
|
||
" if q_vals.ndim > 1:\n",
|
||
" q_vals = q_vals.squeeze(0)\n",
|
||
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
|
||
"\n",
|
||
" def update(\n",
|
||
" self,\n",
|
||
" state: np.ndarray,\n",
|
||
" action: int,\n",
|
||
" reward: float,\n",
|
||
" next_state: np.ndarray,\n",
|
||
" done: bool,\n",
|
||
" next_action: int | None = None,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Store transition and perform a minibatch DQN update.\n",
|
||
"\n",
|
||
" Steps:\n",
|
||
" 1. Add transition to replay buffer (like Dyna-Q model update in Lab 6A)\n",
|
||
" 2. If buffer has enough samples, sample a minibatch\n",
|
||
" 3. Compute targets using target network (Q-learning: max_a' Q_target(s', a'))\n",
|
||
" 4. Backpropagate MSE loss through the MLP\n",
|
||
" 5. Update weights with gradient descent\n",
|
||
" 6. Periodically sync target network\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" _ = next_action # DQN is off-policy\n",
|
||
"\n",
|
||
" # Store transition in replay buffer\n",
|
||
" # Analogous to Lab 6A: self.model[(s,a)] = (r, sp)\n",
|
||
" phi_s = normalize_obs(state)\n",
|
||
" phi_sp = normalize_obs(next_state)\n",
|
||
" self.replay_buffer.append((phi_s, action, reward, phi_sp, done))\n",
|
||
" if len(self.replay_buffer) > self.buffer_size:\n",
|
||
" self.replay_buffer.pop(0)\n",
|
||
"\n",
|
||
" # Don't update until we have enough samples\n",
|
||
" if len(self.replay_buffer) < self.batch_size:\n",
|
||
" return\n",
|
||
"\n",
|
||
" # Sample minibatch from replay buffer\n",
|
||
" # Like Dyna-Q planning: sample from observed transitions (Lab 6A)\n",
|
||
" idx = self.rng.choice(len(self.replay_buffer), size=self.batch_size, replace=False)\n",
|
||
" batch = [self.replay_buffer[i] for i in idx]\n",
|
||
"\n",
|
||
" states_b = np.array([t[0] for t in batch]) # (batch, n_features)\n",
|
||
" actions_b = np.array([t[1] for t in batch]) # (batch,)\n",
|
||
" rewards_b = np.array([t[2] for t in batch]) # (batch,)\n",
|
||
" next_states_b = np.array([t[3] for t in batch]) # (batch, n_features)\n",
|
||
" dones_b = np.array([t[4] for t in batch], dtype=np.float64)\n",
|
||
"\n",
|
||
" # Forward pass on current states\n",
|
||
" q_all, cache = self._forward(states_b, self.params)\n",
|
||
" # Gather Q-values for the taken actions\n",
|
||
" q_curr = q_all[np.arange(self.batch_size), actions_b] # (batch,)\n",
|
||
"\n",
|
||
" # Compute targets using target network (off-policy: max over actions)\n",
|
||
" # Follows Lab 5B Q-learning: td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
|
||
" q_next_all, _ = self._forward(next_states_b, self.target_params)\n",
|
||
" q_next_max = np.max(q_next_all, axis=1) # (batch,)\n",
|
||
" targets = rewards_b + (1.0 - dones_b) * self.gamma * q_next_max\n",
|
||
"\n",
|
||
" # MSE loss gradient: d_loss/d_q_curr = 2 * (q_curr - targets) / batch\n",
|
||
" # We only backprop through the action that was taken\n",
|
||
" d_out = np.zeros_like(q_all)\n",
|
||
" d_out[np.arange(self.batch_size), actions_b] = 2.0 * (q_curr - targets) / self.batch_size\n",
|
||
"\n",
|
||
" # Backward pass\n",
|
||
" grads = self._backward(d_out, cache, self.params)\n",
|
||
"\n",
|
||
" # Gradient descent update\n",
|
||
" for key in self.params:\n",
|
||
" self.params[key] -= self.lr * grads[key]\n",
|
||
"\n",
|
||
" # Sync target network periodically\n",
|
||
" self.update_step += 1\n",
|
||
" if self.update_step % self.target_update_freq == 0:\n",
|
||
" self.target_params = {k: v.copy() for k, v in self.params.items()}\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b7b63455",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Monte Carlo Agent — Linear Approximation (First-visit)\n",
|
||
"\n",
|
||
"This agent is inspired by Lab 4 (`mc_control_epsilon_soft`):\n",
|
||
"- Accumulates transitions in an episode buffer `(state, action, reward)`\n",
|
||
"- At the end of the episode (`done=True`), computes **cumulative returns** by traversing the buffer backward:\n",
|
||
" $$G \\leftarrow \\gamma \\cdot G + r$$\n",
|
||
"- Updates weights with the semi-gradient rule:\n",
|
||
" $$W_a \\leftarrow W_a + \\alpha \\cdot (G - \\hat{q}(s, a)) \\cdot \\phi(s)$$\n",
|
||
"\n",
|
||
"Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"id": "3c9d74be",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
class MonteCarloAgent(Agent):
    """Monte Carlo control with a linear action-value function.

    Adapted from the tabular ``mc_control_epsilon_soft`` routine of Lab 4:

    * every transition ``(state, action, reward)`` is appended to an episode buffer;
    * nothing is learned until ``update`` receives ``done=True``;
    * the buffer is then walked from the last step to the first, accumulating the
      discounted return ``G = gamma * G + r`` (the Lab 4 reversed loop);
    * every buffered pair moves its weights toward ``G`` with the semi-gradient
      rule ``W[a] += alpha * (G - q(s, a)) * phi(s)``.

    In contrast to the TD agents (SARSA, Q-Learning), no weight changes happen
    in the middle of an episode.
    """

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        alpha: float = 0.001,
        gamma: float = 0.99,
        seed: int = 42,
    ) -> None:
        """Create the agent with all-zero weights and an empty episode buffer.

        Args:
            n_features: Length of the feature vector phi(s).
            n_actions: Size of the discrete action space.
            alpha: Constant step size of the semi-gradient update.
            gamma: Discount factor used when accumulating returns.
            seed: Seed passed to the base ``Agent`` constructor.

        """
        super().__init__(seed, n_actions)
        self.n_features = n_features
        self.alpha = alpha
        self.gamma = gamma
        # One weight row per action: q(s, a) = W[a] . phi(s)
        self.W = np.zeros((n_actions, n_features), dtype=np.float64)
        # Transitions of the episode in progress, oldest first
        # (plays the role of the episode list built by Lab 4's generate_episode).
        self.episode_buffer: list[tuple[np.ndarray, int, float]] = []

    def _q_values(self, phi: np.ndarray) -> np.ndarray:
        """Return ``W @ phi``: the estimated value of every action for features ``phi``."""
        return self.W @ phi

    def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:
        """Choose an action ε-greedily with respect to the linear action values."""
        return epsilon_greedy(self._q_values(normalize_obs(observation)), epsilon, self.rng)

    def update(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        next_action: int | None = None,
    ) -> None:
        """Buffer one transition and learn from the whole episode once it ends.

        ``next_state`` and ``next_action`` are part of the shared agent
        interface but play no role in Monte Carlo targets.

        While ``done`` is false the transition is only stored. On the final
        transition the buffer is replayed backward, exactly like Lab 4::

            G = 0
            for s, a, r in reversed(episode_buffer):
                G = gamma * G + r
                # move q(s, a) toward G

        and the buffer is then replaced by a fresh empty list.
        """
        del next_state, next_action  # irrelevant for Monte Carlo

        self.episode_buffer.append((state, action, reward))
        if done:
            self._learn_from_episode()
            self.episode_buffer = []

    def _learn_from_episode(self) -> None:
        """Apply one clipped semi-gradient step per buffered transition, newest first."""
        g = 0.0
        for obs, act, rew in reversed(self.episode_buffer):
            g = self.gamma * g + rew

            features = np.nan_to_num(normalize_obs(obs), nan=0.0, posinf=0.0, neginf=0.0)
            estimate = float(self.W[act] @ features)
            if not np.isfinite(estimate):
                estimate = 0.0

            # Lab 4 does Q[(s,a)] += (G - Q[(s,a)]) / N[(s,a)]; here the target is
            # fitted through linear features with the constant step size alpha.
            if np.isfinite(g):
                error = float(g - estimate)
                if np.isfinite(error):
                    step = float(np.clip(error, -1_000.0, 1_000.0))
                    self.W[act] += self.alpha * step * features
                    # Divergence guard: NaN -> 0, +/-inf -> +/-1e6.
                    self.W[act] = np.nan_to_num(self.W[act], nan=0.0, posinf=1e6, neginf=-1e6)
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "91e51dc8",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Tennis Environment\n",
|
||
"\n",
|
||
"Creation of the Atari Tennis environment via Gymnasium (`ALE/Tennis-v5`) with standard wrappers:\n",
|
||
"- **Grayscale**: `obs_type=\"grayscale\"` — single-channel observations\n",
|
||
"- **Resize**: `ResizeObservation(84, 84)` — downscale to 84×84\n",
|
||
"- **Frame stack**: `FrameStackObservation(4)` — stack 4 consecutive frames\n",
|
||
"\n",
|
||
"The final observation is an array of shape `(4, 84, 84)`, which flattens to 28,224 features.\n",
|
||
"\n",
|
||
"The agent plays against the **built-in Atari AI opponent**."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"id": "f9a973dd",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
def create_env() -> gym.Env:
    """Build ``ALE/Tennis-v5`` with the preprocessing used throughout this notebook.

    Wrapper chain, innermost first:
      1. ``obs_type="grayscale"`` -> single-channel (210, 160) frames
      2. ``ResizeObservation(84, 84)`` -> frames downscaled to 84x84
      3. ``FrameStackObservation(4)`` -> last 4 frames stacked, shape (4, 84, 84)

    Returns:
        The wrapped Gymnasium environment, ready for training.

    """
    base_env = gym.make("ALE/Tennis-v5", obs_type="grayscale")
    resized_env = ResizeObservation(base_env, shape=(84, 84))
    stacked_env = FrameStackObservation(resized_env, stack_size=4)
    return stacked_env
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "18cb28d8",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Training & Evaluation Infrastructure\n",
|
||
"\n",
|
||
"Functions for training and evaluating agents in the single-agent Gymnasium environment:\n",
|
||
"\n",
|
||
"1. **`train_agent`** — Pre-trains an agent against the built-in AI for a given number of episodes with ε-greedy exploration\n",
|
||
"2. **`evaluate_agent`** — Evaluates a trained agent (no exploration, ε = 0) and returns performance metrics\n",
|
||
"3. **`plot_training_curves`** — Plots the training reward history (moving average) for all agents\n",
|
||
"4. **`plot_evaluation_comparison`** — Bar chart comparing final evaluation scores across agents\n",
|
||
"5. **`evaluate_tournament`** — Evaluates all agents and produces a summary comparison"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"id": "06b91580",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
def train_agent(
    env: gym.Env,
    agent: Agent,
    name: str,
    *,
    episodes: int = 5000,
    epsilon_start: float = 1.0,
    epsilon_end: float = 0.05,
    epsilon_decay: float = 0.999,
    max_steps: int = 5000,
) -> list[float]:
    """Pre-train an agent against the built-in Atari AI opponent.

    Each agent learns independently by playing full episodes against the
    environment's built-in opponent (not against itself). ``agent.update`` is
    called after every transition, and exploration follows an ε-greedy schedule
    that decays multiplicatively once per episode.

    An episode ends when the environment reports ``terminated`` or
    ``truncated``, or when ``max_steps`` transitions have been played. In the
    step-cap case the last transition is also passed with ``done=True``,
    mirroring how environment truncation is already handled. Without this,
    agents that only learn at episode end (Monte Carlo) would never flush their
    episode buffer, and the transitions would leak into the next episode.

    Args:
        env: Gymnasium ALE/Tennis-v5 environment.
        agent: Agent instance to train.
        name: Display name for the progress bar.
        episodes: Number of training episodes.
        epsilon_start: Initial exploration rate.
        epsilon_end: Minimum exploration rate.
        epsilon_decay: Multiplicative decay per episode.
        max_steps: Maximum steps per episode.

    Returns:
        List of total rewards per episode.

    """
    rewards_history: list[float] = []
    epsilon = epsilon_start
    recent_window = 50  # episodes averaged in the progress-bar postfix

    pbar = tqdm(range(episodes), desc=f"Training {name}", leave=True)

    for _ep in pbar:
        obs, _info = env.reset()
        obs = np.asarray(obs)
        total_reward = 0.0

        # Select first action
        action = agent.get_action(obs, epsilon=epsilon)

        for step in range(max_steps):
            next_obs, reward, terminated, truncated, _info = env.step(action)
            next_obs = np.asarray(next_obs)
            reward = float(reward)
            total_reward += reward

            # The step cap ends the episode exactly like env truncation, so
            # episode-based learners (Monte Carlo) always see done=True.
            hit_step_cap = step == max_steps - 1
            done = bool(terminated or truncated or hit_step_cap)

            # Select next action (needed for SARSA's on-policy update)
            next_action = None if done else agent.get_action(next_obs, epsilon=epsilon)

            # Update agent with the transition
            agent.update(
                state=obs,
                action=action,
                reward=reward,
                next_state=next_obs,
                done=done,
                next_action=next_action,
            )

            if done:
                break

            obs = next_obs
            action = next_action

        rewards_history.append(total_reward)
        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        # Update progress bar once a full averaging window is available
        if len(rewards_history) >= recent_window:
            recent_avg = np.mean(rewards_history[-recent_window:])
            pbar.set_postfix(
                avg50=f"{recent_avg:.1f}",
                eps=f"{epsilon:.3f}",
                rew=f"{total_reward:.0f}",
            )

    return rewards_history
|
||
"\n",
|
||
"\n",
|
||
def evaluate_agent(
    env: gym.Env,
    agent: Agent,
    name: str,
    *,
    episodes: int = 20,
    max_steps: int = 5000,
) -> dict[str, object]:
    """Evaluate a trained agent with no exploration (ε = 0).

    An episode counts as a win when its total reward is strictly positive.
    For Tennis, that presumably means more points won than lost. Episodes cut
    off at ``max_steps`` are scored the same way.

    Args:
        env: Gymnasium ALE/Tennis-v5 environment.
        agent: Trained agent to evaluate.
        name: Display name for the progress bar.
        episodes: Number of evaluation episodes (must be >= 1).
        max_steps: Maximum steps per episode.

    Returns:
        Dictionary with rewards list, mean, std, wins, and win rate.

    Raises:
        ValueError: If ``episodes`` < 1. In that case the mean and the win rate
            would be undefined (NaN / division by zero).

    """
    if episodes < 1:
        msg = f"episodes must be >= 1, got {episodes}"
        raise ValueError(msg)

    rewards: list[float] = []
    wins = 0

    for _ep in tqdm(range(episodes), desc=f"Evaluating {name}", leave=False):
        obs, _info = env.reset()
        total_reward = 0.0

        for _step in range(max_steps):
            action = agent.get_action(np.asarray(obs), epsilon=0.0)
            obs, reward, terminated, truncated, _info = env.step(action)
            total_reward += float(reward)
            if terminated or truncated:
                break

        rewards.append(total_reward)
        if total_reward > 0:
            wins += 1

    return {
        "rewards": rewards,
        "mean_reward": float(np.mean(rewards)),
        "std_reward": float(np.std(rewards)),
        "wins": wins,
        "win_rate": wins / episodes,
    }
|
||
"\n",
|
||
"\n",
|
||
def plot_training_curves(
    training_histories: dict[str, list[float]],
    window: int = 100,
) -> None:
    """Plot training reward curves for all agents on a single figure.

    Histories with at least ``window`` episodes are smoothed with a moving
    average. Shorter non-empty histories are drawn raw, with a "(raw)" suffix
    in the legend. Empty histories (e.g. agents restored from a checkpoint)
    are skipped rather than adding an empty legend entry.

    Args:
        training_histories: Dict mapping agent names to reward lists.
        window: Moving average window size (must be >= 1).

    Raises:
        ValueError: If ``window`` < 1. Otherwise ``np.convolve`` would fail
            later with an empty kernel.

    """
    if window < 1:
        msg = f"window must be >= 1, got {window}"
        raise ValueError(msg)

    fig, ax = plt.subplots(figsize=(12, 6))
    kernel = np.ones(window) / window

    for name, rewards in training_histories.items():
        if not rewards:
            continue  # nothing to draw for this agent
        if len(rewards) >= window:
            ma = np.convolve(rewards, kernel, mode="valid")
            # Point i of the 'valid' convolution averages episodes i..i+window-1,
            # so it is placed at the last episode of its window.
            ax.plot(np.arange(window - 1, len(rewards)), ma, label=name)
        else:
            ax.plot(rewards, label=f"{name} (raw)")

    ax.set_xlabel("Episodes")
    ax.set_ylabel(f"Average Reward (Window={window})")
    ax.set_title("Training Curves (vs built-in AI)")
    # Only draw a legend when at least one labelled curve exists.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(visible=True)
    fig.tight_layout()
    plt.show()
|
||
"\n",
|
||
"\n",
|
||
def plot_evaluation_comparison(results: dict[str, dict[str, object]]) -> None:
    """Draw side-by-side bar charts of every agent's evaluation metrics.

    Left panel: mean episode reward with the standard deviation as error bars,
    plus a dashed zero line. Right panel: win rate on a [0, 1] axis with a
    dashed 50% reference line.

    Args:
        results: Agent name -> result dict as produced by ``evaluate_agent``
            (read keys: ``mean_reward``, ``std_reward``, ``win_rate``).

    """
    labels = list(results.keys())
    mean_col, std_col, win_col = (
        [results[label][metric] for label in labels]
        for metric in ("mean_reward", "std_reward", "win_rate")
    )

    _fig, (ax_reward, ax_win) = plt.subplots(1, 2, figsize=(14, 5))
    palette = sns.color_palette("husl", len(labels))

    # Left: mean reward with error bars
    ax_reward.bar(labels, mean_col, yerr=std_col, capsize=5, color=palette, edgecolor="black")
    ax_reward.set_ylabel("Mean Reward")
    ax_reward.set_title("Evaluation: Mean Reward per Agent (vs built-in AI)")
    ax_reward.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
    ax_reward.grid(axis="y", alpha=0.3)

    # Right: win rate against a coin-flip reference
    ax_win.bar(labels, win_col, color=palette, edgecolor="black")
    ax_win.set_ylabel("Win Rate")
    ax_win.set_title("Evaluation: Win Rate per Agent (vs built-in AI)")
    ax_win.set_ylim(0, 1)
    ax_win.axhline(y=0.5, color="gray", linestyle="--", alpha=0.5, label="50% baseline")
    ax_win.legend()
    ax_win.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.show()
|
||
"\n",
|
||
"\n",
|
||
def evaluate_tournament(
    env: gym.Env,
    agents: dict[str, Agent],
    episodes_per_agent: int = 20,
) -> dict[str, dict[str, object]]:
    """Evaluate every agent against the built-in AI, logging a summary per agent.

    Agents are evaluated one after another, in dictionary order, with
    ``evaluate_agent``. A progress header and a mean-reward / win-rate line
    are printed for each one.

    Args:
        env: Gymnasium ALE/Tennis-v5 environment.
        agents: Dictionary mapping agent names to Agent instances.
        episodes_per_agent: Number of evaluation episodes per agent.

    Returns:
        Dict mapping agent names to their evaluation results.

    """
    total = len(agents)
    results: dict[str, dict[str, object]] = {}

    for rank, name in enumerate(agents, start=1):
        print(f"[Evaluation {rank}/{total}] {name}")
        outcome = evaluate_agent(env, agents[name], name, episodes=episodes_per_agent)
        results[name] = outcome
        mean_r, wr = outcome["mean_reward"], outcome["win_rate"]
        print(f" -> Mean reward: {mean_r:.2f} | Win rate: {wr:.1%}\n")

    return results
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9605e9c4",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Agent Instantiation & Incremental Training (One Agent at a Time)\n",
|
||
"\n",
|
||
"**Environment**: `ALE/Tennis-v5` (grayscale, 84×84×4 frames → 28,224 features, 18 actions).\n",
|
||
"\n",
|
||
"**Agents**:\n",
|
||
"- **Random** — random baseline (no training needed)\n",
|
||
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
|
||
"- **Q-Learning** — linear approximation, off-policy\n",
|
||
"- **DQN** — NumPy MLP (2 hidden layers of 128, experience replay, target network)\n",
|
||
"- **Monte Carlo** — linear approximation, every-visit returns\n",
|
||
"\n",
|
||
"**Workflow**:\n",
|
||
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
|
||
"2. Save its weights to `checkpoints/<agent>.pkl`\n",
|
||
"3. Repeat later for another agent without retraining previous ones\n",
|
||
"4. Load all saved checkpoints before the final evaluation"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"id": "6f6ba8df",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Observation shape : (4, 84, 84)\n",
|
||
"Feature vector dim: 28224\n",
|
||
"Number of actions : 18\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
# Create environment.
# Seeding the first reset makes the environment's RNG reproducible across runs;
# later unseeded env.reset() calls continue from that seeded state (Gymnasium's
# recommended pattern: seed once, right after creation).
ENV_SEED = 42
env = create_env()
obs, _info = env.reset(seed=ENV_SEED)

n_actions = int(env.action_space.n)
n_features = int(np.prod(obs.shape))  # flattened (4, 84, 84) stack

print(f"Observation shape : {obs.shape}")
print(f"Feature vector dim: {n_features}")
print(f"Number of actions : {n_actions}")

# Instantiate agents
agent_random = RandomAgent(seed=42, action_space=n_actions)
agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)
agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)
agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions, lr=1e-4)
agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)

agents = {
    "Random": agent_random,
    "SARSA": agent_sarsa,
    "Q-Learning": agent_q,
    "DQN": agent_dqn,
    "Monte Carlo": agent_mc,
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "4d449701",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Selected agent: SARSA\n",
|
||
"Checkpoint path: checkpoints/sarsa.pkl\n",
|
||
"\n",
|
||
"============================================================\n",
|
||
"Training: SARSA (5000 episodes)\n",
|
||
"============================================================\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Selected agent: SARSA\n",
|
||
"Checkpoint path: checkpoints/sarsa.pkl\n",
|
||
"\n",
|
||
"============================================================\n",
|
||
"Training: SARSA (5000 episodes)\n",
|
||
"============================================================\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "2809e1c28b15458daaeba87066ad1c74",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Training SARSA: 0%| | 0/5000 [00:00<?, ?it/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
AGENT_TO_TRAIN = "SARSA"  # change to: "Q-Learning", "DQN", "Monte Carlo", "Random"
TRAINING_EPISODES = 5000
FORCE_RETRAIN = False

# Fail fast on a typo in AGENT_TO_TRAIN.
if AGENT_TO_TRAIN not in agents:
    msg = f"Unknown agent '{AGENT_TO_TRAIN}'. Available: {list(agents)}"
    raise ValueError(msg)

training_histories: dict[str, list[float]] = {}
agent = agents[AGENT_TO_TRAIN]
# Checkpoint file: lower-case name with spaces and hyphens mapped to underscores.
ckpt_name = AGENT_TO_TRAIN.lower().translate(str.maketrans(" -", "__")) + ".pkl"
ckpt_path = CHECKPOINT_DIR / ckpt_name

print(f"Selected agent: {AGENT_TO_TRAIN}")
print(f"Checkpoint path: {ckpt_path}")

if AGENT_TO_TRAIN == "Random":
    # The random baseline has no parameters to learn.
    print("Random is a baseline and is not trained.")
    training_histories[AGENT_TO_TRAIN] = []
elif not FORCE_RETRAIN and ckpt_path.exists():
    # Reuse saved weights instead of training again.
    agent.load(str(ckpt_path))
    print("Checkpoint found -> weights loaded, training skipped.")
    training_histories[AGENT_TO_TRAIN] = []
else:
    print(f"\n{'=' * 60}")
    print(f"Training: {AGENT_TO_TRAIN} ({TRAINING_EPISODES} episodes)")
    print(f"{'=' * 60}")

    training_histories[AGENT_TO_TRAIN] = train_agent(
        env=env,
        agent=agent,
        name=AGENT_TO_TRAIN,
        episodes=TRAINING_EPISODES,
        epsilon_start=1.0,
        epsilon_end=0.05,
        epsilon_decay=0.999,
    )

    avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])
    print(f"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}")

    agent.save(str(ckpt_path))
    print("Checkpoint saved.")

# Only plot when this run actually produced a training history.
if training_histories.get(AGENT_TO_TRAIN):
    plot_training_curves(training_histories, window=100)
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "0fc0c643",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
# Load all available checkpoints before final evaluation.
# CHECKPOINT_DIR comes from the configuration cell at the top of the notebook.
# It is deliberately NOT re-assigned here: re-assigning it would silently
# override any change made in the config cell.
loaded_agents: list[str] = []
missing_agents: list[str] = []

for name, agent in agents.items():
    if name == "Random":
        continue  # the baseline has no weights to load

    ckpt_name = name.lower().replace(" ", "_").replace("-", "_") + ".pkl"
    ckpt_path = CHECKPOINT_DIR / ckpt_name

    if ckpt_path.exists():
        agent.load(str(ckpt_path))
        loaded_agents.append(name)
    else:
        missing_agents.append(name)

print(f"Loaded checkpoints: {loaded_agents}")
if missing_agents:
    print(f"Missing checkpoints (untrained/not saved yet): {missing_agents}")
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "0e5f2c49",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Final Evaluation\n",
|
||
"\n",
|
||
"Each agent plays 20 episodes against the built-in AI with no exploration (ε = 0); an episode counts as a win when the agent's total reward for that episode is positive.\n",
|
||
"Performance is compared via mean reward and win rate."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "70f5d5cd",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
# Build evaluation set: Random + agents with existing checkpoints
eval_agents: dict[str, Agent] = {"Random": agents["Random"]}
missing_agents: list[str] = []

for name, agent in agents.items():
    if name == "Random":
        continue

    ckpt_name = name.lower().replace(" ", "_").replace("-", "_") + ".pkl"
    ckpt_path = CHECKPOINT_DIR / ckpt_name

    # Agents without a saved checkpoint are left out of the comparison.
    if not ckpt_path.exists():
        missing_agents.append(name)
        continue

    agent.load(str(ckpt_path))
    eval_agents[name] = agent

print(f"Agents evaluated: {list(eval_agents.keys())}")
if missing_agents:
    print(f"Skipped (no checkpoint yet): {missing_agents}")

# Comparing Random only against itself would be meaningless.
if len(eval_agents) < 2:
    msg = "Train at least one non-random agent before final evaluation."
    raise RuntimeError(msg)

results = evaluate_tournament(env, eval_agents, episodes_per_agent=20)
plot_evaluation_comparison(results)

# Print summary table
print(f"\n{'Agent':<15} {'Mean Reward':>12} {'Std':>8} {'Win Rate':>10}")
print("-" * 48)
for name, res in results.items():
    print(f"{name:<15} {res['mean_reward']:>12.2f} {res['std_reward']:>8.2f} {res['win_rate']:>9.1%}")
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "studies (3.13.9)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.13.9"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|