diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb
similarity index 66%
rename from M2/Reinforcement Learning/project/Project.ipynb
rename to M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb
index 75136f3..3c9dbd1 100644
--- a/M2/Reinforcement Learning/project/Project.ipynb
+++ b/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb
@@ -11,20 +11,27 @@
"\n",
"1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
"2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
- "3. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
"\n",
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
]
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 12,
"id": "b50d7174",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using device: mps\n"
+ ]
+ }
+ ],
"source": [
- "import itertools\n",
"import pickle\n",
+ "from collections import deque\n",
"from pathlib import Path\n",
"\n",
"import ale_py # noqa: F401 — registers ALE environments\n",
@@ -36,12 +43,30 @@
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
- "import seaborn as sns\n"
+ "\n",
+ "import torch\n",
+ "from torch import nn, optim\n",
+ "\n",
+ "DEVICE = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "print(f\"Using device: {DEVICE}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "86047166",
+ "metadata": {},
+ "source": [
+ "# Configuration & Checkpoints\n",
+ "\n",
+ "We use a **checkpoint** system (`pickle` serialization) to save and restore trained agent weights. This enables an incremental workflow:\n",
+ "- Train one agent at a time and save its weights\n",
+ "- Resume later without retraining previous agents\n",
+ "- Load all checkpoints for the final evaluation"
]
},
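The checkpoint workflow above can be sketched in a few lines; a minimal illustration of pickle-based weight serialization, with hypothetical helper names (`save_checkpoint`, `load_checkpoint`) that are not part of the notebook:

```python
import pickle
from pathlib import Path

import numpy as np


def save_checkpoint(weights: dict, path: Path) -> None:
    """Serialize a dict of weight arrays to disk (hypothetical helper)."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(weights, f)


def load_checkpoint(path: Path) -> dict:
    """Restore a previously saved weight dict (hypothetical helper)."""
    with path.open("rb") as f:
        return pickle.load(f)


ckpt = Path("checkpoints/demo_agent.pkl")
save_checkpoint({"W": np.zeros((4, 2))}, ckpt)
restored = load_checkpoint(ckpt)
assert restored["W"].shape == (4, 2)
```

This is the incremental pattern described above: each trained agent writes its own file under `checkpoints/`, so a later session can restore it without retraining.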
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 13,
"id": "ff3486a4",
"metadata": {},
"outputs": [],
@@ -77,7 +102,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 14,
"id": "be85c130",
"metadata": {},
"outputs": [],
@@ -161,7 +186,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 15,
"id": "ded9b1fb",
"metadata": {},
"outputs": [],
@@ -215,7 +240,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 16,
"id": "78bdc9d2",
"metadata": {},
"outputs": [],
@@ -248,7 +273,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 17,
"id": "c124ed9a",
"metadata": {},
"outputs": [],
@@ -393,7 +418,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 18,
"id": "f5b5b9ea",
"metadata": {},
"outputs": [],
@@ -496,7 +521,7 @@
},
{
"cell_type": "markdown",
- "id": "b7b63455",
+ "id": "79e6b39f",
"metadata": {},
"source": [
"## Monte Carlo Agent — Linear Approximation (First-visit)\n",
@@ -508,13 +533,15 @@
"- Updates weights with the semi-gradient rule:\n",
" $$W_a \\leftarrow W_a + \\alpha \\cdot (G - \\hat{q}(s, a)) \\cdot \\phi(s)$$\n",
"\n",
- "Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating."
+ "Unlike TD methods (SARSA, Q-Learning), Monte Carlo waits for the complete episode to finish before updating.\n",
+ "\n",
+ "> **Note**: This agent currently has **checkpoint loading issues** — the saved weights fail to restore properly, causing the agent to behave as if untrained during evaluation. The training code itself works correctly."
]
},
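The first-visit return computation can be sketched independently of the agent class. A minimal illustration, assuming hashable state keys for simplicity (the notebook's agent uses feature vectors instead; the function name is hypothetical):

```python
def first_visit_returns(states, rewards, gamma=0.99):
    """Discounted return G_t at the first visit of each state in one episode."""
    first_visit_idx = {}
    for t, s in enumerate(states):
        if s not in first_visit_idx:
            first_visit_idx[s] = t

    returns = {}
    G = 0.0
    # Walk the episode backward, accumulating G = r_t + gamma * G
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        if first_visit_idx[states[t]] == t:
            returns[states[t]] = G
    return returns


rets = first_visit_returns(["a", "b", "a"], [0.0, 0.0, 1.0], gamma=1.0)
# With gamma = 1, the first visits of both "a" (t=0) and "b" (t=1) see G = 1.0
assert rets == {"a": 1.0, "b": 1.0}
```

These returns `G` are exactly the targets fed into the semi-gradient update $W_a \leftarrow W_a + \alpha (G - \hat{q}(s,a)) \phi(s)$, which is why the update must wait for the episode to terminate.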
{
"cell_type": "code",
- "execution_count": 8,
- "id": "3c9d74be",
+ "execution_count": 19,
+ "id": "7a3aa454",
"metadata": {},
"outputs": [],
"source": [
@@ -617,6 +644,219 @@
" self._rew_buf.clear()\n"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "d5766fe9",
+ "metadata": {},
+ "source": [
+ "## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n",
+ "\n",
+ "This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon).\n",
+ "\n",
+ "**Network architecture** (same structure as before, now as `torch.nn.Module`):\n",
+ "$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
+ "\n",
+ "**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
+ "- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n",
+ "- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n",
+ "- **Gradient clipping**: prevents exploding gradients in deep networks\n",
+ "- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes\n",
+ "\n",
+ "> **Note**: This agent currently has **checkpoint loading issues** — the saved `.pt` checkpoint fails to restore properly (device mismatch / state dict incompatibility), causing errors during evaluation. The training code itself works correctly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "9c777493",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class QNetwork(nn.Module):\n",
+ " \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n",
+ "\n",
+ " def __init__(self, n_features: int, n_actions: int) -> None:\n",
+ " super().__init__()\n",
+ " self.net = nn.Sequential(\n",
+ " nn.Linear(n_features, 256),\n",
+ " nn.ReLU(),\n",
+ " nn.Linear(256, 256),\n",
+ " nn.ReLU(),\n",
+ " nn.Linear(256, n_actions),\n",
+ " )\n",
+ "\n",
+ " def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+ " return self.net(x)\n",
+ "\n",
+ "\n",
+ "class ReplayBuffer:\n",
+ " \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n",
+ "\n",
+ " def __init__(self, capacity: int) -> None:\n",
+ " self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n",
+ "\n",
+ " def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n",
+ " self.buffer.append((state, action, reward, next_state, done))\n",
+ "\n",
+ " def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n",
+ " indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n",
+ " batch = [self.buffer[i] for i in indices]\n",
+ " states = np.array([t[0] for t in batch])\n",
+ " actions = np.array([t[1] for t in batch])\n",
+ " rewards = np.array([t[2] for t in batch])\n",
+ " next_states = np.array([t[3] for t in batch])\n",
+ " dones = np.array([t[4] for t in batch], dtype=np.float32)\n",
+ " return states, actions, rewards, next_states, dones\n",
+ "\n",
+ " def __len__(self) -> int:\n",
+ " return len(self.buffer)\n",
+ "\n",
+ "\n",
+ "class DQNAgent(Agent):\n",
+ " \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n",
+ "\n",
+ " Inspired by:\n",
+ " - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n",
+ " - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n",
+ "\n",
+ " Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " n_features: int,\n",
+ " n_actions: int,\n",
+ " lr: float = 1e-4,\n",
+ " gamma: float = 0.99,\n",
+ " buffer_size: int = 50_000,\n",
+ " batch_size: int = 128,\n",
+ " target_update_freq: int = 1000,\n",
+ " seed: int = 42,\n",
+ " ) -> None:\n",
+ " \"\"\"Initialize DQN agent.\n",
+ "\n",
+ " Args:\n",
+ " n_features: Input feature dimension.\n",
+ " n_actions: Number of discrete actions.\n",
+ " lr: Learning rate for Adam optimizer.\n",
+ " gamma: Discount factor.\n",
+ " buffer_size: Maximum replay buffer capacity.\n",
+ " batch_size: Minibatch size for updates.\n",
+ " target_update_freq: Steps between target network syncs.\n",
+ " seed: RNG seed.\n",
+ "\n",
+ " \"\"\"\n",
+ " super().__init__(seed, n_actions)\n",
+ " self.n_features = n_features\n",
+ " self.lr = lr\n",
+ " self.gamma = gamma\n",
+ " self.batch_size = batch_size\n",
+ " self.target_update_freq = target_update_freq\n",
+ " self.update_step = 0\n",
+ "\n",
+ " # Q-network and target network on GPU\n",
+ " torch.manual_seed(seed)\n",
+ " self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
+ " self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
+ " self.target_net.load_state_dict(self.q_net.state_dict())\n",
+ " self.target_net.eval()\n",
+ "\n",
+ " self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n",
+ " self.loss_fn = nn.SmoothL1Loss() # Huber loss — more robust than MSE\n",
+ "\n",
+ " # Experience replay buffer\n",
+ " self.replay_buffer = ReplayBuffer(buffer_size)\n",
+ "\n",
+ " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
+ " \"\"\"Select action using epsilon-greedy policy over Q-network outputs.\"\"\"\n",
+ " if self.rng.random() < epsilon:\n",
+ " return int(self.rng.integers(0, self.action_space))\n",
+ "\n",
+ " phi = normalize_obs(observation)\n",
+ " with torch.no_grad():\n",
+ " state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n",
+ " q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n",
+ " return epsilon_greedy(q_vals, 0.0, self.rng)\n",
+ "\n",
+ " def update(\n",
+ " self,\n",
+ " state: np.ndarray,\n",
+ " action: int,\n",
+ " reward: float,\n",
+ " next_state: np.ndarray,\n",
+ " done: bool,\n",
+ " next_action: int | None = None,\n",
+ " ) -> None:\n",
+ " \"\"\"Store transition and perform a minibatch DQN update.\"\"\"\n",
+ " _ = next_action # DQN is off-policy\n",
+ "\n",
+ " # Store transition\n",
+ " phi_s = normalize_obs(state)\n",
+ " phi_sp = normalize_obs(next_state)\n",
+ " self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n",
+ "\n",
+ " if len(self.replay_buffer) < self.batch_size:\n",
+ " return\n",
+ "\n",
+ " # Sample minibatch\n",
+ " states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n",
+ " self.batch_size, self.rng,\n",
+ " )\n",
+ "\n",
+ " # Convert to tensors on device\n",
+ " states_t = torch.from_numpy(states_b).float().to(DEVICE)\n",
+ " actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n",
+ " rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n",
+ " next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n",
+ " dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n",
+ "\n",
+ " # Current Q-values for taken actions\n",
+ " q_values = self.q_net(states_t)\n",
+ " q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n",
+ "\n",
+ " # Target Q-values (off-policy: max over actions in next state)\n",
+ " with torch.no_grad():\n",
+ " q_next = self.target_net(next_states_t).max(dim=1).values\n",
+ " targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n",
+ "\n",
+ " # Compute loss and update\n",
+ " loss = self.loss_fn(q_curr, targets)\n",
+ " self.optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n",
+ " self.optimizer.step()\n",
+ "\n",
+ " # Sync target network periodically\n",
+ " self.update_step += 1\n",
+ " if self.update_step % self.target_update_freq == 0:\n",
+ " self.target_net.load_state_dict(self.q_net.state_dict())\n",
+ "\n",
+ " def save(self, filename: str) -> None:\n",
+ " \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n",
+ " torch.save(\n",
+ " {\n",
+ " \"q_net\": self.q_net.state_dict(),\n",
+ " \"target_net\": self.target_net.state_dict(),\n",
+ " \"optimizer\": self.optimizer.state_dict(),\n",
+ " \"update_step\": self.update_step,\n",
+ " \"n_features\": self.n_features,\n",
+ " \"action_space\": self.action_space,\n",
+ " },\n",
+ " filename,\n",
+ " )\n",
+ "\n",
+ " def load(self, filename: str) -> None:\n",
+ " \"\"\"Load agent state from a torch checkpoint.\"\"\"\n",
+ " checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n",
+ " self.q_net.load_state_dict(checkpoint[\"q_net\"])\n",
+ " self.target_net.load_state_dict(checkpoint[\"target_net\"])\n",
+ " self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n",
+ " self.update_step = checkpoint[\"update_step\"]\n",
+ " self.q_net.to(DEVICE)\n",
+ " self.target_net.to(DEVICE)\n",
+ " self.target_net.eval()\n"
+ ]
+ },
{
"cell_type": "markdown",
"id": "91e51dc8",
@@ -636,7 +876,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 21,
"id": "f9a973dd",
"metadata": {},
"outputs": [],
@@ -676,7 +916,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 22,
"id": "06b91580",
"metadata": {},
"outputs": [],
@@ -722,7 +962,6 @@
" obs = np.asarray(obs)\n",
" total_reward = 0.0\n",
"\n",
- " # Select first action\n",
" action = agent.get_action(obs, epsilon=epsilon)\n",
"\n",
" for _step in range(max_steps):\n",
@@ -732,10 +971,8 @@
" reward = float(reward)\n",
" total_reward += reward\n",
"\n",
- " # Select next action (needed for SARSA's on-policy update)\n",
" next_action = agent.get_action(next_obs, epsilon=epsilon) if not done else None\n",
"\n",
- " # Update agent with the transition\n",
" agent.update(\n",
" state=obs,\n",
" action=action,\n",
@@ -754,7 +991,6 @@
" rewards_history.append(total_reward)\n",
" epsilon = max(epsilon_end, epsilon * epsilon_decay)\n",
"\n",
- " # Update progress bar\n",
" recent_window = 50\n",
" if len(rewards_history) >= recent_window:\n",
" recent_avg = np.mean(rewards_history[-recent_window:])\n",
@@ -766,56 +1002,6 @@
"\n",
" return rewards_history\n",
"\n",
- "\n",
- "def evaluate_agent(\n",
- " env: gym.Env,\n",
- " agent: Agent,\n",
- " name: str,\n",
- " *,\n",
- " episodes: int = 20,\n",
- " max_steps: int = 5000,\n",
- ") -> dict[str, object]:\n",
- " \"\"\"Evaluate a trained agent with no exploration (ε = 0).\n",
- "\n",
- " Args:\n",
- " env: Gymnasium ALE/Tennis-v5 environment.\n",
- " agent: Trained agent to evaluate.\n",
- " name: Display name for the progress bar.\n",
- " episodes: Number of evaluation episodes.\n",
- " max_steps: Maximum steps per episode.\n",
- "\n",
- " Returns:\n",
- " Dictionary with rewards list, mean, std, wins, and win rate.\n",
- "\n",
- " \"\"\"\n",
- " rewards: list[float] = []\n",
- " wins = 0\n",
- "\n",
- " for _ep in tqdm(range(episodes), desc=f\"Evaluating {name}\", leave=False):\n",
- " obs, _info = env.reset()\n",
- " total_reward = 0.0\n",
- "\n",
- " for _step in range(max_steps):\n",
- " action = agent.get_action(np.asarray(obs), epsilon=0.0)\n",
- " obs, reward, terminated, truncated, _info = env.step(action)\n",
- " reward = float(reward)\n",
- " total_reward += reward\n",
- " if terminated or truncated:\n",
- " break\n",
- "\n",
- " rewards.append(total_reward)\n",
- " if total_reward > 0:\n",
- " wins += 1\n",
- "\n",
- " return {\n",
- " \"rewards\": rewards,\n",
- " \"mean_reward\": float(np.mean(rewards)),\n",
- " \"std_reward\": float(np.std(rewards)),\n",
- " \"wins\": wins,\n",
- " \"win_rate\": wins / episodes,\n",
- " }\n",
- "\n",
- "\n",
"def plot_training_curves(\n",
" training_histories: dict[str, list[float]],\n",
" path: str,\n",
@@ -847,73 +1033,7 @@
" plt.grid(visible=True)\n",
" plt.tight_layout()\n",
" plt.savefig(path)\n",
- " plt.show()\n",
- "\n",
- "\n",
- "def plot_evaluation_comparison(results: dict[str, dict[str, object]]) -> None:\n",
- " \"\"\"Bar chart comparing evaluation performance of all agents.\n",
- "\n",
- " Args:\n",
- " results: Dict mapping agent names to evaluation result dicts.\n",
- "\n",
- " \"\"\"\n",
- " names = list(results.keys())\n",
- " means = [results[n][\"mean_reward\"] for n in names]\n",
- " stds = [results[n][\"std_reward\"] for n in names]\n",
- " win_rates = [results[n][\"win_rate\"] for n in names]\n",
- "\n",
- " _fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
- "\n",
- " # Mean reward bar chart\n",
- " colors = sns.color_palette(\"husl\", len(names))\n",
- " axes[0].bar(names, means, yerr=stds, capsize=5, color=colors, edgecolor=\"black\")\n",
- " axes[0].set_ylabel(\"Mean Reward\")\n",
- " axes[0].set_title(\"Evaluation: Mean Reward per Agent (vs built-in AI)\")\n",
- " axes[0].axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
- " axes[0].grid(axis=\"y\", alpha=0.3)\n",
- "\n",
- " # Win rate bar chart\n",
- " axes[1].bar(names, win_rates, color=colors, edgecolor=\"black\")\n",
- " axes[1].set_ylabel(\"Win Rate\")\n",
- " axes[1].set_title(\"Evaluation: Win Rate per Agent (vs built-in AI)\")\n",
- " axes[1].set_ylim(0, 1)\n",
- " axes[1].axhline(y=0.5, color=\"gray\", linestyle=\"--\", alpha=0.5, label=\"50% baseline\")\n",
- " axes[1].legend()\n",
- " axes[1].grid(axis=\"y\", alpha=0.3)\n",
- "\n",
- " plt.tight_layout()\n",
- " plt.show()\n",
- "\n",
- "\n",
- "def evaluate_tournament(\n",
- " env: gym.Env,\n",
- " agents: dict[str, Agent],\n",
- " episodes_per_agent: int = 20,\n",
- ") -> dict[str, dict[str, object]]:\n",
- " \"\"\"Evaluate all agents against the built-in AI and produce a comparison.\n",
- "\n",
- " Args:\n",
- " env: Gymnasium ALE/Tennis-v5 environment.\n",
- " agents: Dictionary mapping agent names to Agent instances.\n",
- " episodes_per_agent: Number of evaluation episodes per agent.\n",
- "\n",
- " Returns:\n",
- " Dict mapping agent names to their evaluation results.\n",
- "\n",
- " \"\"\"\n",
- " results: dict[str, dict[str, object]] = {}\n",
- " n_agents = len(agents)\n",
- "\n",
- " for idx, (name, agent) in enumerate(agents.items(), start=1):\n",
- " print(f\"[Evaluation {idx}/{n_agents}] {name}\")\n",
- " results[name] = evaluate_agent(\n",
- " env, agent, name, episodes=episodes_per_agent,\n",
- " )\n",
- " mean_r = results[name][\"mean_reward\"]\n",
- " wr = results[name][\"win_rate\"]\n",
- " print(f\" -> Mean reward: {mean_r:.2f} | Win rate: {wr:.1%}\\n\")\n",
- "\n",
- " return results\n"
+ " plt.show()\n"
]
},
{
@@ -929,29 +1049,22 @@
"- **Random** — random baseline (no training needed)\n",
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
"- **Q-Learning** — linear approximation, off-policy\n",
- "- **Monte Carlo** — linear approximation, first-visit returns\n",
+ "- **Monte Carlo** — first-visit MC with linear weights (⚠️ checkpoint loading issues)\n",
+ "- **DQN** — deep Q-network with experience replay and target network (⚠️ `.pt` checkpoint loading issues)\n",
"\n",
"**Workflow**:\n",
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
- "2. Save its weights to `checkpoints/` (`.pkl`)\n",
+ "2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
"3. Repeat later for another agent without retraining previous ones\n",
"4. Load all saved checkpoints before the final evaluation"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 23,
"id": "6f6ba8df",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)\n",
- "[Powered by Stella]\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
@@ -960,40 +1073,6 @@
"Feature vector dim: 28224\n",
"Number of actions : 18\n"
]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "objc[49878]: Class SDLApplication is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d2c8) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418890). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDLAppDelegate is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d318) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x1244188e0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDLTranslatorResponder is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d390) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418958). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDLMessageBoxPresenter is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d3b8) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418980). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_cocoametalview is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d408) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x1244189d0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDLOpenGLContext is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d458) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418a20). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_ShapeData is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d4d0) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418a98). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_CocoaClosure is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d520) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418ae8). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_VideoData is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d570) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418b38). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_WindowData is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d5c0) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418b88). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDLWindow is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d5e8) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418bb0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class Cocoa_WindowListener is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d610) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418bd8). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDLView is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d688) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418c50). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class METAL_RenderData is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d700) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418cc8). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class METAL_TextureData is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d750) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418d18). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_RumbleMotor is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d778) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418d40). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n",
- "objc[49878]: Class SDL_RumbleContext is implemented in both /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/pygame/.dylibs/libSDL2-2.0.0.dylib (0x11118d7c8) and /Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/cv2/.dylibs/libSDL2-2.0.0.dylib (0x124418d90). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.\n"
- ]
- },
- {
- "ename": "",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[1;31mLe noyau s’est bloqué lors de l’exécution du code dans une cellule active ou une cellule précédente. \n",
- "\u001b[1;31mVeuillez vérifier le code dans la ou les cellules pour identifier une cause possible de l’échec. \n",
- "\u001b[1;31mCliquez ici pour plus d’informations. \n",
- "\u001b[1;31mPour plus d’informations, consultez Jupyter log."
- ]
}
],
"source": [
@@ -1013,12 +1092,14 @@
"agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
+ "agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions)\n",
"\n",
"agents = {\n",
" \"Random\": agent_random,\n",
" \"SARSA\": agent_sarsa,\n",
" \"Q-Learning\": agent_q,\n",
- " \"Monte Carlo\": agent_mc,\n",
+ " \"MonteCarlo\": agent_mc,\n",
+ " \"DQN\": agent_dqn,\n",
"}\n"
]
},
@@ -1032,23 +1113,23 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Selected agent: Monte Carlo\n",
- "Checkpoint path: checkpoints/monte_carlo.pkl\n",
+ "Selected agent: Q-Learning\n",
+ "Checkpoint path: checkpoints/q_learning.pkl\n",
"\n",
"============================================================\n",
- "Training: Monte Carlo (2500 episodes)\n",
+ "Training: Q-Learning (5000 episodes)\n",
"============================================================\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "0d5d098d18014fe6b736683e0b8b2488",
+ "model_id": "543ff46900f84f6fa37ccc6989bbe7f2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
- "Training Monte Carlo: 0%| | 0/2500 [00:00, ?it/s]"
+ "Training Q-Learning: 0%| | 0/5000 [00:00, ?it/s]"
]
},
"metadata": {},
@@ -1056,9 +1137,9 @@
}
],
"source": [
- "AGENT_TO_TRAIN = \"Monte Carlo\" # TODO: change to: \"Q-Learning\", \"Monte Carlo\", \"Random\"\n",
- "TRAINING_EPISODES = 2500\n",
- "FORCE_RETRAIN = False\n",
+    "AGENT_TO_TRAIN = \"Q-Learning\"  # TODO: change to: \"SARSA\", \"MonteCarlo\", \"DQN\", \"Random\"\n",
+ "TRAINING_EPISODES = 5000\n",
+ "FORCE_RETRAIN = True\n",
"\n",
"if AGENT_TO_TRAIN not in agents:\n",
" msg = f\"Unknown agent '{AGENT_TO_TRAIN}'. Available: {list(agents)}\"\n",
@@ -1114,53 +1195,16 @@
},
{
"cell_type": "markdown",
- "id": "0e5f2c49",
+ "id": "3047c4b0",
"metadata": {},
"source": [
- "## Final Evaluation\n",
+ "# Multi-Agent Evaluation & Tournament\n",
"\n",
- "Each agent plays 20 episodes against the built-in AI with no exploration (ε = 0).\n",
- "Performance is compared via mean reward and win rate."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "70f5d5cd",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Build evaluation set: Random + agents with existing checkpoints\n",
- "eval_agents: dict[str, Agent] = {\"Random\": agents[\"Random\"]}\n",
- "missing_agents: list[str] = []\n",
+ "After individual pre-training against Atari's built-in AI, we move on to **head-to-head evaluation** between our agents.\n",
"\n",
- "for name, agent in agents.items():\n",
- " if name == \"Random\":\n",
- " continue\n",
+ "**PettingZoo Environment**: unlike Gymnasium (single player vs built-in AI), PettingZoo (`tennis_v3`) allows **two Python agents** to play against each other. The same preprocessing wrappers are applied (grayscale, resize to 84×84, 4-frame stacking) via **SuperSuit**.\n",
"\n",
- " checkpoint_path = get_path(name)\n",
- "\n",
- " if checkpoint_path.exists():\n",
- " agent.load(str(checkpoint_path))\n",
- " eval_agents[name] = agent\n",
- " else:\n",
- " missing_agents.append(name)\n",
- "\n",
- "print(f\"Agents evaluated: {list(eval_agents.keys())}\")\n",
- "if missing_agents:\n",
- " print(f\"Skipped (no checkpoint yet): {missing_agents}\")\n",
- "\n",
- "if len(eval_agents) < 2:\n",
- " raise RuntimeError(\"Train at least one non-random agent before final evaluation.\")\n",
- "\n",
- "results = evaluate_tournament(env, eval_agents, episodes_per_agent=20)\n",
- "plot_evaluation_comparison(results)\n",
- "\n",
- "# Print summary table\n",
- "print(f\"\\n{'Agent':<15} {'Mean Reward':>12} {'Std':>8} {'Win Rate':>10}\")\n",
- "print(\"-\" * 48)\n",
- "for name, res in results.items():\n",
- " print(f\"{name:<15} {res['mean_reward']:>12.2f} {res['std_reward']:>8.2f} {res['win_rate']:>9.1%}\")\n"
+ "**Match Protocol**: each matchup is played over **two legs** with swapped positions (`first_0` / `second_0`) to eliminate any side-of-court advantage. Results are tallied as wins, losses, and draws."
]
},
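The two-leg aggregation described above can be sketched on its own: decisive wins from both legs are pooled and the pairwise win rate is computed over decisive games only, with all-draw matchups treated as even. The helper name is illustrative, not from the notebook:

```python
def two_leg_win_rate(leg1: dict, leg2: dict) -> tuple[float, float]:
    """Combine two legs with swapped seats into (win_rate_A, win_rate_B).

    leg1: agent A seated as first_0; leg2: seats swapped.
    Each leg is a dict with "first", "second", and "draw" counts.
    """
    wins_a = leg1["first"] + leg2["second"]  # A's wins across both seats
    wins_b = leg1["second"] + leg2["first"]
    decisive = wins_a + wins_b
    if decisive == 0:
        return 0.5, 0.5  # all draws: treat the matchup as even
    return wins_a / decisive, wins_b / decisive


wr_a, wr_b = two_leg_win_rate(
    {"first": 6, "second": 3, "draw": 1},
    {"first": 4, "second": 5, "draw": 1},
)
# A: 6 + 5 = 11 decisive wins, B: 3 + 4 = 7, so wr_a = 11/18
assert abs(wr_a - 11 / 18) < 1e-9
```

Playing both seats cancels any systematic `first_0` / `second_0` advantage before the win matrix is filled in.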
{
@@ -1178,8 +1222,8 @@
" return ss.frame_stack_v1(env, 4)\n",
"\n",
"\n",
- "def run_pz_match(\n",
- " env,\n",
+ "def run_match(\n",
+    "    env,  # PettingZoo parallel env (not a gym.Env)\n",
" agent_first: Agent,\n",
" agent_second: Agent,\n",
" episodes: int = 10,\n",
@@ -1211,6 +1255,7 @@
" if step_idx + 1 >= max_steps:\n",
" break\n",
"\n",
+ " # Determine winner for the current episode\n",
" if rewards[\"first_0\"] > rewards[\"second_0\"]:\n",
" wins[\"first\"] += 1\n",
" elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n",
@@ -1218,135 +1263,193 @@
" else:\n",
" wins[\"draw\"] += 1\n",
"\n",
- " return wins\n",
- "\n",
- "\n",
- "def run_pettingzoo_tournament(\n",
- " agents: dict[str, Agent],\n",
- " episodes_per_side: int = 10,\n",
- ") -> tuple[np.ndarray, list[str]]:\n",
- " \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n",
- " _ = itertools # kept for notebook context consistency\n",
- " candidate_names = [name for name in agents if name != \"Random\"]\n",
- "\n",
- " # Keep only agents that have a checkpoint\n",
- " ready_names: list[str] = []\n",
- " for name in candidate_names:\n",
- " checkpoint_path = get_path(name)\n",
- " if checkpoint_path.exists():\n",
- " agents[name].load(str(checkpoint_path))\n",
- " ready_names.append(name)\n",
- "\n",
- " if len(ready_names) < 2:\n",
- " msg = \"Need at least 2 trained (checkpointed) non-random agents for PettingZoo tournament.\"\n",
- " raise RuntimeError(msg)\n",
- "\n",
- " n = len(ready_names)\n",
- " win_matrix = np.full((n, n), np.nan)\n",
- " np.fill_diagonal(win_matrix, 0.5)\n",
- "\n",
- " for i in range(n):\n",
- " for j in range(i + 1, n):\n",
- " name_i = ready_names[i]\n",
- " name_j = ready_names[j]\n",
- "\n",
- " print(f\"Matchup: {name_i} vs {name_j}\")\n",
- " env = create_tournament_env()\n",
- "\n",
- " # Leg 1: i as first_0, j as second_0\n",
- " leg1 = run_pz_match(\n",
- " env,\n",
- " agent_first=agents[name_i],\n",
- " agent_second=agents[name_j],\n",
- " episodes=episodes_per_side,\n",
- " )\n",
- "\n",
- " # Leg 2: swap seats\n",
- " leg2 = run_pz_match(\n",
- " env,\n",
- " agent_first=agents[name_j],\n",
- " agent_second=agents[name_i],\n",
- " episodes=episodes_per_side,\n",
- " )\n",
- "\n",
- " env.close()\n",
- "\n",
- " wins_i = leg1[\"first\"] + leg2[\"second\"]\n",
- " wins_j = leg1[\"second\"] + leg2[\"first\"]\n",
- "\n",
- " decisive = wins_i + wins_j\n",
- " if decisive == 0:\n",
- " wr_i = 0.5\n",
- " wr_j = 0.5\n",
- " else:\n",
- " wr_i = wins_i / decisive\n",
- " wr_j = wins_j / decisive\n",
- "\n",
- " win_matrix[i, j] = wr_i\n",
- " win_matrix[j, i] = wr_j\n",
- "\n",
- " print(f\" -> {name_i}: {wins_i} wins | {name_j}: {wins_j} wins\\n\")\n",
- "\n",
- " return win_matrix, ready_names\n",
- "\n",
- "\n",
- "# Run tournament (non-random agents only)\n",
- "win_matrix_pz, pz_names = run_pettingzoo_tournament(\n",
- " agents=agents,\n",
- " episodes_per_side=10,\n",
- ")\n",
- "\n",
- "# Plot win-rate matrix\n",
- "plt.figure(figsize=(8, 6))\n",
- "sns.heatmap(\n",
- " win_matrix_pz,\n",
- " annot=True,\n",
- " fmt=\".2f\",\n",
- " cmap=\"Blues\",\n",
- " vmin=0.0,\n",
- " vmax=1.0,\n",
- " xticklabels=pz_names,\n",
- " yticklabels=pz_names,\n",
- ")\n",
- "plt.xlabel(\"Opponent\")\n",
- "plt.ylabel(\"Agent\")\n",
- "plt.title(\"PettingZoo Tournament Win Rate Matrix (Non-random agents)\")\n",
- "plt.tight_layout()\n",
- "plt.show()\n",
- "\n",
- "# Rank agents by mean win rate vs others (excluding diagonal)\n",
- "scores = {}\n",
- "for idx, name in enumerate(pz_names):\n",
- " row = np.delete(win_matrix_pz[idx], idx)\n",
- " scores[name] = float(np.mean(row))\n",
- "\n",
- "ranking = sorted(scores.items(), key=lambda x: x[1], reverse=True)\n",
- "print(\"Final ranking (PettingZoo tournament, non-random):\")\n",
- "for rank_idx, (name, score) in enumerate(ranking, start=1):\n",
- " print(f\"{rank_idx}. {name:<12} | mean win rate: {score:.3f}\")\n",
- "\n",
- "print(f\"\\nBest agent: {ranking[0][0]}\")\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3f8b300d",
- "metadata": {},
- "source": [
- "## PettingZoo Tournament (Agents vs Agents)\n",
- "\n",
- "This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n",
- "\n",
- "- Checkpoints are loaded from `checkpoints/` (`.pkl`)\n",
- "- `Random` is **excluded** from ranking\n",
- "- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
- "- A win-rate matrix and final ranking are produced"
+ " return wins\n"
]
},
{
"cell_type": "markdown",
"id": "150e6764",
"metadata": {},
+ "source": [
+ "## Evaluation against the Random Agent (Baseline)\n",
+ "\n",
+ "To quantify whether our agents have actually learned, we first evaluate them against the **Random agent** baseline. A properly trained agent should achieve a **win rate significantly above 50%** against a random opponent.\n",
+ "\n",
+ "Each agent plays **two legs** (one in each position) for a total of 20 episodes. Only decisive matches (excluding draws) are counted in the win rate calculation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1b85a88f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def evaluate_vs_random(\n",
+ " agents: dict[str, Agent],\n",
+ " random_agent_name: str = \"Random\",\n",
+ " episodes_per_leg: int = 10,\n",
+ ") -> dict[str, float]:\n",
+ " \"\"\"Evaluate each agent against the specified random agent and compute win rates.\"\"\"\n",
+ " win_rates = {}\n",
+ " env = create_tournament_env()\n",
+ " agent_random = agents[random_agent_name]\n",
+ "\n",
+ " for name, agent in agents.items():\n",
+ " if name == random_agent_name:\n",
+ " continue\n",
+ "\n",
+ " leg1 = run_match(env, agent, agent_random, episodes=episodes_per_leg)\n",
+ " leg2 = run_match(env, agent_random, agent, episodes=episodes_per_leg)\n",
+ "\n",
+ " total_wins = leg1[\"first\"] + leg2[\"second\"]\n",
+ " total_matches = (episodes_per_leg * 2) - (leg1[\"draw\"] + leg2[\"draw\"])\n",
+ "\n",
+ " if total_matches == 0:\n",
+ " win_rates[name] = 0.5\n",
+ " else:\n",
+ " win_rates[name] = total_wins / total_matches\n",
+ "\n",
+ " print(\n",
+ " f\"{name} vs {random_agent_name}: {total_wins} wins out of {total_matches} decisive matches (Win rate: {win_rates[name]:.1%})\",\n",
+ " )\n",
+ "\n",
+ " env.close()\n",
+ " return win_rates\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "82c24f27",
+ "metadata": {},
+ "source": [
+ "### Loading Checkpoints & Running Evaluation\n",
+ "\n",
+ "Before evaluation, we load the saved weights for each trained agent from the `checkpoints/` directory. If a checkpoint is missing, the agent will play with its initial weights (zeros), which is effectively random behavior."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a053644e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for name in [\"SARSA\", \"Q-Learning\"]:\n",
+ " checkpoint_path = get_path(name)\n",
+ " if checkpoint_path.exists():\n",
+ " agents[name].load(str(checkpoint_path))\n",
+ " else:\n",
+ " print(f\"Warning: Missing checkpoint for {name}\")\n",
+ "\n",
+ "print(\"Evaluation against the Random agent\")\n",
+ "win_rates_vs_random = evaluate_vs_random(\n",
+ " agents, random_agent_name=\"Random\", episodes_per_leg=10,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b9cb30f5",
+ "metadata": {},
+ "source": [
+ "## Championship Match: SARSA vs Q-Learning\n",
+ "\n",
+ "The final showdown pits the two trained agents against each other: **SARSA** (on-policy) versus **Q-Learning** (off-policy).\n",
+ "\n",
+ "This match directly compares the two TD learning strategies:\n",
+ "- **SARSA** updates its weights following the policy it actually executes (on-policy): $\\delta = r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
+ "- **Q-Learning** learns the optimal policy independently of exploration (off-policy): $\\delta = r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)$\n",
+ "\n",
+ "The championship is played over **2 × 20 episodes** with swapped positions to ensure fairness."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4031bde5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def run_championship(\n",
+ " agent1_name: str,\n",
+ " agent2_name: str,\n",
+ " agents: dict[str, Agent],\n",
+ " episodes_per_leg: int = 20,\n",
+ ") -> None:\n",
+ " \"\"\"Run a full championship between two agents, playing multiple legs with swapped positions.\"\"\"\n",
+ " env = create_tournament_env()\n",
+ " agent1 = agents[agent1_name]\n",
+ " agent2 = agents[agent2_name]\n",
+ "\n",
+ " # Leg 1: Agent 1 plays first_0, Agent 2 plays second_0\n",
+ " leg1 = run_match(env, agent1, agent2, episodes=episodes_per_leg)\n",
+ " # Leg 2: Swap starting positions\n",
+ " leg2 = run_match(env, agent2, agent1, episodes=episodes_per_leg)\n",
+ "\n",
+ " wins_agent1 = leg1[\"first\"] + leg2[\"second\"]\n",
+ " wins_agent2 = leg1[\"second\"] + leg2[\"first\"]\n",
+ " draws = leg1[\"draw\"] + leg2[\"draw\"]\n",
+ "\n",
+ " print(f\"--- Final Result: {agent1_name} vs {agent2_name} ---\")\n",
+ " print(f\"{agent1_name} wins: {wins_agent1}\")\n",
+ " print(f\"{agent2_name} wins: {wins_agent2}\")\n",
+ " print(f\"Draws: {draws}\")\n",
+ "\n",
+ " if wins_agent1 > wins_agent2:\n",
+ " print(f\"The winner is {agent1_name}!\")\n",
+ " elif wins_agent2 > wins_agent1:\n",
+ " print(f\"The winner is {agent2_name}!\")\n",
+ " else:\n",
+ " print(\"Perfect tie!\")\n",
+ "\n",
+ " env.close()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b07b403c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Championship match between the two trained agents\")\n",
+ "run_championship(\"SARSA\", \"Q-Learning\", agents, episodes_per_leg=20)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "32c37e5d",
+ "metadata": {},
+ "source": [
+ "# Conclusion\n",
+ "\n",
+ "This project implemented and compared five Reinforcement Learning agents on Atari Tennis:\n",
+ "\n",
+ "| Agent | Type | Policy | Update Rule |\n",
+ "|-------|------|--------|-------------|\n",
+ "| **Random** | Baseline | Uniform random | None |\n",
+ "| **SARSA** | TD(0), on-policy | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (r + \\gamma \\hat{q}(s', a') - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
+ "| **Q-Learning** | TD(0), off-policy | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (r + \\gamma \\max_{a'} \\hat{q}(s', a') - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
+ "| **Monte Carlo** | First-visit MC | ε-greedy | $W_a \\leftarrow W_a + \\alpha \\cdot (G_t - \\hat{q}(s, a)) \\cdot \\phi(s)$ |\n",
+ "| **DQN** | Deep Q-Network | ε-greedy | Neural network (MLP 256→256) with experience replay and target network |\n",
+ "\n",
+ "**Architecture**:\n",
+ "- **Linear agents** (SARSA, Q-Learning, Monte Carlo): $\\hat{q}(s, a; \\mathbf{W}) = \\mathbf{W}_a^\\top \\phi(s)$ with $\\phi(s) \\in \\mathbb{R}^{28\\,224}$ (4 grayscale 84×84 frames, normalized)\n",
+ "- **DQN**: MLP network (28,224 → 256 → 256 → 18) trained with Adam optimizer, Huber loss, and periodic target network sync\n",
+ "\n",
+ "**Methodology**:\n",
+ "1. **Pre-training** each agent individually against Atari's built-in AI (5,000 episodes, ε decaying from 1.0 to 0.05)\n",
+ "2. **Evaluation vs Random** to validate learning (expected win rate > 50%)\n",
+ "3. **Head-to-head tournament** in matches via PettingZoo (2 × 20 episodes)\n",
+ "\n",
+ "> ⚠️ **Known Issue**: Monte Carlo and DQN agent checkpoints have loading issues. Their code is preserved here for reference."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15f80f84",
+ "metadata": {},
"source": []
}
],