diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project.ipynb index 01e5188..d79d5d1 100644 --- a/M2/Reinforcement Learning/project/Project.ipynb +++ b/M2/Reinforcement Learning/project/Project.ipynb @@ -11,30 +11,20 @@ "\n", "1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n", "2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n", - "3. **DQN** — Deep Q-Network with PyTorch MLP, experience replay and target network (inspired by Lab 6A + classic DQN), GPU-accelerated via MPS\n", - "4. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n", + "3. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n", "\n", "Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "b50d7174", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PyTorch device: cuda\n" - ] - } - ], + "outputs": [], "source": [ "import itertools\n", "import pickle\n", - "from collections import deque\n", "from pathlib import Path\n", "\n", "import ale_py # noqa: F401 — registers ALE environments\n", @@ -46,23 +36,12 @@ "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import seaborn as sns\n", - "\n", - "import torch\n", - "from torch import nn, optim\n", - "\n", - "if torch.cuda.is_available():\n", - " DEVICE = torch.device(\"cuda\")\n", - "elif torch.backends.mps.is_available():\n", - " DEVICE = torch.device(\"mps\")\n", - "else:\n", - " DEVICE = torch.device(\"cpu\")\n", - "print(f\"PyTorch device: {DEVICE}\")\n" + "import seaborn as sns\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "ff3486a4", "metadata": {}, "outputs": [], @@ -71,11 
+50,10 @@ "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n", "\n", "\n", - "def _ckpt_path(name: str) -> Path:\n", - " \"\"\"Return the checkpoint path for an agent (DQN uses .pt, others use .pkl).\"\"\"\n", + "def get_path(name: str) -> Path:\n", + " \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n", " base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n", - " ext = \".pt\" if name == \"DQN\" else \".pkl\"\n", - " return CHECKPOINT_DIR / (base + ext)\n" + " return CHECKPOINT_DIR / (base + \".pkl\")\n" ] }, { @@ -99,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "be85c130", "metadata": {}, "outputs": [], @@ -183,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "ded9b1fb", "metadata": {}, "outputs": [], @@ -237,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "78bdc9d2", "metadata": {}, "outputs": [], @@ -270,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "c124ed9a", "metadata": {}, "outputs": [], @@ -415,7 +393,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "f5b5b9ea", "metadata": {}, "outputs": [], @@ -516,227 +494,6 @@ " self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n" ] }, - { - "cell_type": "markdown", - "id": "f644e2ef", - "metadata": {}, - "source": [ - "## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n", - "\n", - "This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon).\n", - "\n", - "**Network architecture** (same structure as before, now as `torch.nn.Module`):\n", - "$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n", - "\n", - "**Key techniques** (inspired by Lab 6A Dyna-Q + 
classic DQN):\n", - "- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n", - "- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n", - "- **Gradient clipping**: prevents exploding gradients in deep networks\n", - "- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec090a9a", - "metadata": {}, - "outputs": [], - "source": [ - "class QNetwork(nn.Module):\n", - " \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n", - "\n", - " def __init__(self, n_features: int, n_actions: int) -> None:\n", - " super().__init__()\n", - " self.net = nn.Sequential(\n", - " nn.Linear(n_features, 256),\n", - " nn.ReLU(),\n", - " nn.Linear(256, 256),\n", - " nn.ReLU(),\n", - " nn.Linear(256, n_actions),\n", - " )\n", - "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " return self.net(x)\n", - "\n", - "\n", - "class ReplayBuffer:\n", - " \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n", - "\n", - " def __init__(self, capacity: int) -> None:\n", - " self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n", - "\n", - " def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n", - " self.buffer.append((state, action, reward, next_state, done))\n", - "\n", - " def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n", - " indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n", - " batch = [self.buffer[i] for i in indices]\n", - " states = np.array([t[0] for t in batch])\n", - " actions = np.array([t[1] for t in batch])\n", - " rewards = np.array([t[2] for t in batch])\n", - " next_states = np.array([t[3] for t in batch])\n", - " dones = np.array([t[4] for t in batch], 
dtype=np.float32)\n", - " return states, actions, rewards, next_states, dones\n", - "\n", - " def __len__(self) -> int:\n", - " return len(self.buffer)\n", - "\n", - "\n", - "class DQNAgent(Agent):\n", - " \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n", - "\n", - " Inspired by:\n", - " - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n", - " - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n", - "\n", - " Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " n_features: int,\n", - " n_actions: int,\n", - " lr: float = 1e-4,\n", - " gamma: float = 0.99,\n", - " buffer_size: int = 50_000,\n", - " batch_size: int = 128,\n", - " target_update_freq: int = 1000,\n", - " seed: int = 42,\n", - " ) -> None:\n", - " \"\"\"Initialize DQN agent.\n", - "\n", - " Args:\n", - " n_features: Input feature dimension.\n", - " n_actions: Number of discrete actions.\n", - " lr: Learning rate for Adam optimizer.\n", - " gamma: Discount factor.\n", - " buffer_size: Maximum replay buffer capacity.\n", - " batch_size: Minibatch size for updates.\n", - " target_update_freq: Steps between target network syncs.\n", - " seed: RNG seed.\n", - "\n", - " \"\"\"\n", - " super().__init__(seed, n_actions)\n", - " self.n_features = n_features\n", - " self.lr = lr\n", - " self.gamma = gamma\n", - " self.batch_size = batch_size\n", - " self.target_update_freq = target_update_freq\n", - " self.update_step = 0\n", - "\n", - " # Q-network and target network on GPU\n", - " torch.manual_seed(seed)\n", - " self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n", - " self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n", - " self.target_net.load_state_dict(self.q_net.state_dict())\n", - " self.target_net.eval()\n", - "\n", - " self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n", - " self.loss_fn = nn.SmoothL1Loss() # Huber loss — more 
robust than MSE\n", - "\n", - " # Experience replay buffer\n", - " self.replay_buffer = ReplayBuffer(buffer_size)\n", - "\n", - " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", - " \"\"\"Select action using ε-greedy policy over Q-network outputs.\"\"\"\n", - " if self.rng.random() < epsilon:\n", - " return int(self.rng.integers(0, self.action_space))\n", - "\n", - " phi = normalize_obs(observation)\n", - " with torch.no_grad():\n", - " state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n", - " q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n", - " return epsilon_greedy(q_vals, 0.0, self.rng)\n", - "\n", - " def update(\n", - " self,\n", - " state: np.ndarray,\n", - " action: int,\n", - " reward: float,\n", - " next_state: np.ndarray,\n", - " done: bool,\n", - " next_action: int | None = None,\n", - " ) -> None:\n", - " \"\"\"Store transition and perform a minibatch DQN update.\n", - "\n", - " Steps:\n", - " 1. Add transition to replay buffer\n", - " 2. If buffer has enough samples, sample a minibatch\n", - " 3. Compute targets using target network (max_a' Q_target(s', a'))\n", - " 4. Compute Huber loss and backpropagate\n", - " 5. Clip gradients and update weights with Adam\n", - " 6. 
Periodically sync target network\n", - "\n", - " \"\"\"\n", - " _ = next_action # DQN is off-policy\n", - "\n", - " # Store transition\n", - " phi_s = normalize_obs(state)\n", - " phi_sp = normalize_obs(next_state)\n", - " self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n", - "\n", - " if len(self.replay_buffer) < self.batch_size:\n", - " return\n", - "\n", - " # Sample minibatch\n", - " states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n", - " self.batch_size, self.rng,\n", - " )\n", - "\n", - " # Convert to tensors on device\n", - " states_t = torch.from_numpy(states_b).float().to(DEVICE)\n", - " actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n", - " rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n", - " next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n", - " dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n", - "\n", - " # Current Q-values for taken actions\n", - " q_values = self.q_net(states_t)\n", - " q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n", - "\n", - " # Target Q-values (off-policy: max over actions in next state)\n", - " with torch.no_grad():\n", - " q_next = self.target_net(next_states_t).max(dim=1).values\n", - " targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n", - "\n", - " # Compute loss and update\n", - " loss = self.loss_fn(q_curr, targets)\n", - " self.optimizer.zero_grad()\n", - " loss.backward()\n", - " nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n", - " self.optimizer.step()\n", - "\n", - " # Sync target network periodically\n", - " self.update_step += 1\n", - " if self.update_step % self.target_update_freq == 0:\n", - " self.target_net.load_state_dict(self.q_net.state_dict())\n", - "\n", - " def save(self, filename: str) -> None:\n", - " \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n", - " torch.save(\n", - " {\n", - " \"q_net\": self.q_net.state_dict(),\n", 
- " \"target_net\": self.target_net.state_dict(),\n", - " \"optimizer\": self.optimizer.state_dict(),\n", - " \"update_step\": self.update_step,\n", - " \"n_features\": self.n_features,\n", - " \"action_space\": self.action_space,\n", - " },\n", - " filename,\n", - " )\n", - "\n", - " def load(self, filename: str) -> None:\n", - " \"\"\"Load agent state from a torch checkpoint.\"\"\"\n", - " checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n", - " self.q_net.load_state_dict(checkpoint[\"q_net\"])\n", - " self.target_net.load_state_dict(checkpoint[\"target_net\"])\n", - " self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n", - " self.update_step = checkpoint[\"update_step\"]\n", - " self.q_net.to(DEVICE)\n", - " self.target_net.to(DEVICE)\n", - " self.target_net.eval()\n" - ] - }, { "cell_type": "markdown", "id": "b7b63455", @@ -756,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "3c9d74be", "metadata": {}, "outputs": [], @@ -887,7 +644,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "f9a973dd", "metadata": {}, "outputs": [], @@ -927,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "id": "06b91580", "metadata": {}, "outputs": [], @@ -1180,19 +937,18 @@ "- **Random** — random baseline (no training needed)\n", "- **SARSA** — linear approximation, semi-gradient TD(0)\n", "- **Q-Learning** — linear approximation, off-policy\n", - "- **DQN** — PyTorch MLP (2 hidden layers of 256, experience replay, target network, GPU-accelerated)\n", "- **Monte Carlo** — linear approximation, first-visit returns\n", "\n", "**Workflow**:\n", "1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n", - "2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n", + "2. Save its weights to `checkpoints/` (`.pkl`)\n", "3. Repeat later for another agent without retraining previous ones\n", "4. 
Load all saved checkpoints before the final evaluation" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "id": "6f6ba8df", "metadata": {}, "outputs": [ @@ -1222,21 +978,19 @@ "agent_random = RandomAgent(seed=42, action_space=int(n_actions))\n", "agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n", "agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n", - "agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions, lr=1e-4)\n", "agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n", "\n", "agents = {\n", " \"Random\": agent_random,\n", " \"SARSA\": agent_sarsa,\n", " \"Q-Learning\": agent_q,\n", - " \"DQN\": agent_dqn,\n", " \"Monte Carlo\": agent_mc,\n", "}\n" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "4d449701", "metadata": {}, "outputs": [ @@ -1244,45 +998,31 @@ "name": "stdout", "output_type": "stream", "text": [ - "Selected agent: DQN\n", - "Checkpoint path: checkpoints/dqn.pt\n", + "Selected agent: Monte Carlo\n", + "Checkpoint path: checkpoints/monte_carlo.pkl\n", "\n", "============================================================\n", - "Training: DQN (2500 episodes)\n", + "Training: Monte Carlo (2500 episodes)\n", "============================================================\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "47c9af1f459346dea335fb822a091b1b", + "model_id": "83fbf34f73ba481f906d78f0686d3c89", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training DQN: 0%| | 0/2500 [00:00\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{'='*60}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m training_histories[AGENT_TO_TRAIN] = 
train_agent(\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0magent\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/tmp/ipykernel_730/4063369955.py\u001b[0m in \u001b[0;36mtrain_agent\u001b[0;34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# Update agent with the transition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m agent.update(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, state, action, reward, next_state, done, next_action)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;31m# Sample minibatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 138\u001b[0;31m states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m )\n", - "\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, batch_size, 
rng)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mstates\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mrewards\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] } ], "source": [ - "AGENT_TO_TRAIN = \"DQN\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n", + "AGENT_TO_TRAIN = \"Monte Carlo\" # TODO: change to: \"SARSA\", \"Q-Learning\", \"Monte Carlo\", \"Random\"\n", "TRAINING_EPISODES = 2500\n", "FORCE_RETRAIN = False\n", "\n", @@ -1292,16 +1032,16 @@ "\n", "training_histories: dict[str, list[float]] = {}\n", "agent = agents[AGENT_TO_TRAIN]\n", - "ckpt_path = _ckpt_path(AGENT_TO_TRAIN)\n", + "checkpoint_path = get_path(AGENT_TO_TRAIN)\n", "\n", "print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n", - "print(f\"Checkpoint path: {ckpt_path}\")\n", + "print(f\"Checkpoint path: {checkpoint_path}\")\n", "\n", "if AGENT_TO_TRAIN == \"Random\":\n", " print(\"Random is a baseline and is not trained.\")\n", " training_histories[AGENT_TO_TRAIN] = []\n", - "elif ckpt_path.exists() and not FORCE_RETRAIN:\n", - " agent.load(str(ckpt_path))\n", + "elif checkpoint_path.exists() and not FORCE_RETRAIN:\n", + " agent.load(str(checkpoint_path))\n", " print(\"Checkpoint found -> weights loaded, training skipped.\")\n", " training_histories[AGENT_TO_TRAIN] = []\n", "else:\n", @@ -1322,7 +1062,7 @@ " avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n", " print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n", "\n", - " agent.save(str(ckpt_path))\n", + " agent.save(str(checkpoint_path))\n", " print(\"Checkpoint saved.\")\n" ] }, @@ -1349,43 +1089,6 @@ ")\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "0fc0c643", - "metadata": {}, - "outputs": 
[ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded checkpoints: ['SARSA', 'Q-Learning']\n", - "Missing checkpoints (untrained/not saved yet): ['DQN', 'Monte Carlo']\n" - ] - } - ], - "source": [ - "# Load all available checkpoints before final evaluation\n", - "loaded_agents: list[str] = []\n", - "missing_agents: list[str] = []\n", - "\n", - "for name, agent in agents.items():\n", - " if name == \"Random\":\n", - " continue\n", - "\n", - " ckpt_path = _ckpt_path(name)\n", - "\n", - " if ckpt_path.exists():\n", - " agent.load(str(ckpt_path))\n", - " loaded_agents.append(name)\n", - " else:\n", - " missing_agents.append(name)\n", - "\n", - "print(f\"Loaded checkpoints: {loaded_agents}\")\n", - "if missing_agents:\n", - " print(f\"Missing checkpoints (untrained/not saved yet): {missing_agents}\")\n" - ] - }, { "cell_type": "markdown", "id": "0e5f2c49", @@ -1402,7 +1105,53 @@ "execution_count": null, "id": "70f5d5cd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "checkpoints/sarsa.pkl\n", + "checkpoints/q_learning.pkl\n", + "checkpoints/monte_carlo.pkl\n", + "Agents evaluated: ['Random', 'SARSA', 'Q-Learning']\n", + "Skipped (no checkpoint yet): ['Monte Carlo']\n", + "[Evaluation 1/3] Random\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "46b0cdfc1baa4ef2a8dcc20c6924346d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Evaluating Random: 0%| | 0/20 [00:00 \u001b[39m\u001b[32m24\u001b[39m results = \u001b[43mevaluate_tournament\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_agents\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes_per_agent\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 25\u001b[39m plot_evaluation_comparison(results)\n\u001b[32m 27\u001b[39m \u001b[38;5;66;03m# Print 
summary table\u001b[39;00m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 226\u001b[39m, in \u001b[36mevaluate_tournament\u001b[39m\u001b[34m(env, agents, episodes_per_agent)\u001b[39m\n\u001b[32m 224\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m idx, (name, agent) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(agents.items(), start=\u001b[32m1\u001b[39m):\n\u001b[32m 225\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m[Evaluation \u001b[39m\u001b[38;5;132;01m{\u001b[39;00midx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_agents\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m] \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m226\u001b[39m results[name] = \u001b[43mevaluate_agent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 227\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepisodes_per_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 228\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 229\u001b[39m mean_r = results[name][\u001b[33m\"\u001b[39m\u001b[33mmean_reward\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 230\u001b[39m wr = results[name][\u001b[33m\"\u001b[39m\u001b[33mwin_rate\u001b[39m\u001b[33m\"\u001b[39m]\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 117\u001b[39m, in \u001b[36mevaluate_agent\u001b[39m\u001b[34m(env, agent, name, episodes, max_steps)\u001b[39m\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m 
116\u001b[39m action = agent.get_action(np.asarray(obs), epsilon=\u001b[32m0.0\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m117\u001b[39m obs, reward, terminated, truncated, _info = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 118\u001b[39m reward = \u001b[38;5;28mfloat\u001b[39m(reward)\n\u001b[32m 119\u001b[39m total_reward += reward\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/stateful_observation.py:425\u001b[39m, in \u001b[36mFrameStackObservation.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 414\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 415\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 416\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 417\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Steps through the environment, appending the observation to the frame buffer.\u001b[39;00m\n\u001b[32m 418\u001b[39m \n\u001b[32m 419\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 423\u001b[39m \u001b[33;03m Stacked observations, reward, terminated, truncated, and info from the environment\u001b[39;00m\n\u001b[32m 424\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m425\u001b[39m obs, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 426\u001b[39m \u001b[38;5;28mself\u001b[39m.obs_queue.append(obs)\n\u001b[32m 428\u001b[39m updated_obs = deepcopy(\n\u001b[32m 
429\u001b[39m concatenate(\u001b[38;5;28mself\u001b[39m.env.observation_space, \u001b[38;5;28mself\u001b[39m.obs_queue, \u001b[38;5;28mself\u001b[39m.stacked_obs)\n\u001b[32m 430\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:560\u001b[39m, in \u001b[36mObservationWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 556\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 557\u001b[39m \u001b[38;5;28mself\u001b[39m, action: ActType\n\u001b[32m 558\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 559\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m560\u001b[39m observation, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 561\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.observation(observation), reward, terminated, truncated, info\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:393\u001b[39m, in \u001b[36mOrderEnforcing.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ResetNeeded(\u001b[33m\"\u001b[39m\u001b[33mCannot call env.step() before calling 
env.reset()\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:327\u001b[39m, in \u001b[36mWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 323\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 324\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 325\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 326\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:285\u001b[39m, in \u001b[36mPassiveEnvChecker.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 283\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env_step_passive_checker(\u001b[38;5;28mself\u001b[39m.env, action)\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/ale_py/env.py:305\u001b[39m, in \u001b[36mAtariEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 303\u001b[39m reward = \u001b[32m0.0\u001b[39m\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(frameskip):\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m reward += \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43male\u001b[49m\u001b[43m.\u001b[49m\u001b[43mact\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrength\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 307\u001b[39m is_terminal = \u001b[38;5;28mself\u001b[39m.ale.game_over(with_truncation=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 308\u001b[39m is_truncated = \u001b[38;5;28mself\u001b[39m.ale.game_truncated()\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], "source": [ "# Build evaluation set: Random + agents with existing checkpoints\n", "eval_agents: dict[str, Agent] = {\"Random\": agents[\"Random\"]}\n", @@ -1412,10 +1161,10 @@ " if name == \"Random\":\n", " continue\n", "\n", - " ckpt_path = _ckpt_path(name)\n", + " checkpoint_path = get_path(name)\n", "\n", - " if ckpt_path.exists():\n", - " agent.load(str(ckpt_path))\n", + " if checkpoint_path.exists():\n", + " agent.load(str(checkpoint_path))\n", " eval_agents[name] = agent\n", " else:\n", " missing_agents.append(name)\n", @@ -1497,7 +1246,6 @@ "\n", "def run_pettingzoo_tournament(\n", " agents: dict[str, Agent],\n", - " checkpoint_dir: Path,\n", " episodes_per_side: int = 10,\n", ") -> tuple[np.ndarray, list[str]]:\n", " \"\"\"Round-robin tournament excluding Random, with 
seat-swap fairness.\"\"\"\n", @@ -1507,9 +1255,9 @@ " # Keep only agents that have a checkpoint\n", " ready_names: list[str] = []\n", " for name in candidate_names:\n", - " ckpt_path = _ckpt_path(name)\n", - " if ckpt_path.exists():\n", - " agents[name].load(str(ckpt_path))\n", + " checkpoint_path = get_path(name)\n", + " if checkpoint_path.exists():\n", + " agents[name].load(str(checkpoint_path))\n", " ready_names.append(name)\n", "\n", " if len(ready_names) < 2:\n", @@ -1568,7 +1316,6 @@ "# Run tournament (non-random agents only)\n", "win_matrix_pz, pz_names = run_pettingzoo_tournament(\n", " agents=agents,\n", - " checkpoint_dir=CHECKPOINT_DIR,\n", " episodes_per_side=10,\n", ")\n", "\n", @@ -1613,16 +1360,22 @@ "\n", "This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n", "\n", - "- Checkpoints are loaded from `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n", + "- Checkpoints are loaded from `checkpoints/` (`.pkl`)\n", "- `Random` is **excluded** from ranking\n", "- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n", "- A win-rate matrix and final ranking are produced" ] + }, + { + "cell_type": "markdown", + "id": "150e6764", + "metadata": {}, + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "studies (3.13.9)", "language": "python", "name": "python3" }, @@ -1636,7 +1389,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.13.9" } }, "nbformat": 4,