mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-03-16 03:11:46 +01:00
Implement feature X to enhance user experience and optimize performance
This commit is contained in:
@@ -11,30 +11,20 @@
|
||||
"\n",
|
||||
"1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
|
||||
"2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
|
||||
"3. **DQN** — Deep Q-Network with PyTorch MLP, experience replay and target network (inspired by Lab 6A + classic DQN), GPU-accelerated via MPS\n",
|
||||
"4. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
|
||||
"3. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
|
||||
"\n",
|
||||
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 16,
|
||||
"id": "b50d7174",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"PyTorch device: cuda\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import itertools\n",
|
||||
"import pickle\n",
|
||||
"from collections import deque\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import ale_py # noqa: F401 — registers ALE environments\n",
|
||||
@@ -46,23 +36,12 @@
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import seaborn as sns\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch import nn, optim\n",
|
||||
"\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" DEVICE = torch.device(\"cuda\")\n",
|
||||
"elif torch.backends.mps.is_available():\n",
|
||||
" DEVICE = torch.device(\"mps\")\n",
|
||||
"else:\n",
|
||||
" DEVICE = torch.device(\"cpu\")\n",
|
||||
"print(f\"PyTorch device: {DEVICE}\")\n"
|
||||
"import seaborn as sns\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 17,
|
||||
"id": "ff3486a4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -71,11 +50,10 @@
|
||||
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _ckpt_path(name: str) -> Path:\n",
|
||||
" \"\"\"Return the checkpoint path for an agent (DQN uses .pt, others use .pkl).\"\"\"\n",
|
||||
"def get_path(name: str) -> Path:\n",
|
||||
" \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n",
|
||||
" base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n",
|
||||
" ext = \".pt\" if name == \"DQN\" else \".pkl\"\n",
|
||||
" return CHECKPOINT_DIR / (base + ext)\n"
|
||||
" return CHECKPOINT_DIR / (base + \".pkl\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -99,7 +77,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 18,
|
||||
"id": "be85c130",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -183,7 +161,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 19,
|
||||
"id": "ded9b1fb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -237,7 +215,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 20,
|
||||
"id": "78bdc9d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -270,7 +248,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 21,
|
||||
"id": "c124ed9a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -415,7 +393,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 22,
|
||||
"id": "f5b5b9ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -516,227 +494,6 @@
|
||||
" self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f644e2ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n",
|
||||
"\n",
|
||||
"This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon).\n",
|
||||
"\n",
|
||||
"**Network architecture** (same structure as before, now as `torch.nn.Module`):\n",
|
||||
"$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
|
||||
"\n",
|
||||
"**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
|
||||
"- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n",
|
||||
"- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n",
|
||||
"- **Gradient clipping**: prevents exploding gradients in deep networks\n",
|
||||
"- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec090a9a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class QNetwork(nn.Module):\n",
|
||||
" \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, n_features: int, n_actions: int) -> None:\n",
|
||||
" super().__init__()\n",
|
||||
" self.net = nn.Sequential(\n",
|
||||
" nn.Linear(n_features, 256),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Linear(256, 256),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Linear(256, n_actions),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
|
||||
" return self.net(x)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ReplayBuffer:\n",
|
||||
" \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, capacity: int) -> None:\n",
|
||||
" self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n",
|
||||
"\n",
|
||||
" def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n",
|
||||
" self.buffer.append((state, action, reward, next_state, done))\n",
|
||||
"\n",
|
||||
" def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n",
|
||||
" indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n",
|
||||
" batch = [self.buffer[i] for i in indices]\n",
|
||||
" states = np.array([t[0] for t in batch])\n",
|
||||
" actions = np.array([t[1] for t in batch])\n",
|
||||
" rewards = np.array([t[2] for t in batch])\n",
|
||||
" next_states = np.array([t[3] for t in batch])\n",
|
||||
" dones = np.array([t[4] for t in batch], dtype=np.float32)\n",
|
||||
" return states, actions, rewards, next_states, dones\n",
|
||||
"\n",
|
||||
" def __len__(self) -> int:\n",
|
||||
" return len(self.buffer)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class DQNAgent(Agent):\n",
|
||||
" \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n",
|
||||
"\n",
|
||||
" Inspired by:\n",
|
||||
" - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n",
|
||||
" - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n",
|
||||
"\n",
|
||||
" Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" n_features: int,\n",
|
||||
" n_actions: int,\n",
|
||||
" lr: float = 1e-4,\n",
|
||||
" gamma: float = 0.99,\n",
|
||||
" buffer_size: int = 50_000,\n",
|
||||
" batch_size: int = 128,\n",
|
||||
" target_update_freq: int = 1000,\n",
|
||||
" seed: int = 42,\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Initialize DQN agent.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" n_features: Input feature dimension.\n",
|
||||
" n_actions: Number of discrete actions.\n",
|
||||
" lr: Learning rate for Adam optimizer.\n",
|
||||
" gamma: Discount factor.\n",
|
||||
" buffer_size: Maximum replay buffer capacity.\n",
|
||||
" batch_size: Minibatch size for updates.\n",
|
||||
" target_update_freq: Steps between target network syncs.\n",
|
||||
" seed: RNG seed.\n",
|
||||
"\n",
|
||||
" \"\"\"\n",
|
||||
" super().__init__(seed, n_actions)\n",
|
||||
" self.n_features = n_features\n",
|
||||
" self.lr = lr\n",
|
||||
" self.gamma = gamma\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.target_update_freq = target_update_freq\n",
|
||||
" self.update_step = 0\n",
|
||||
"\n",
|
||||
" # Q-network and target network on GPU\n",
|
||||
" torch.manual_seed(seed)\n",
|
||||
" self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
|
||||
" self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
|
||||
" self.target_net.load_state_dict(self.q_net.state_dict())\n",
|
||||
" self.target_net.eval()\n",
|
||||
"\n",
|
||||
" self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n",
|
||||
" self.loss_fn = nn.SmoothL1Loss() # Huber loss — more robust than MSE\n",
|
||||
"\n",
|
||||
" # Experience replay buffer\n",
|
||||
" self.replay_buffer = ReplayBuffer(buffer_size)\n",
|
||||
"\n",
|
||||
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
|
||||
" \"\"\"Select action using ε-greedy policy over Q-network outputs.\"\"\"\n",
|
||||
" if self.rng.random() < epsilon:\n",
|
||||
" return int(self.rng.integers(0, self.action_space))\n",
|
||||
"\n",
|
||||
" phi = normalize_obs(observation)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n",
|
||||
" q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n",
|
||||
" return epsilon_greedy(q_vals, 0.0, self.rng)\n",
|
||||
"\n",
|
||||
" def update(\n",
|
||||
" self,\n",
|
||||
" state: np.ndarray,\n",
|
||||
" action: int,\n",
|
||||
" reward: float,\n",
|
||||
" next_state: np.ndarray,\n",
|
||||
" done: bool,\n",
|
||||
" next_action: int | None = None,\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Store transition and perform a minibatch DQN update.\n",
|
||||
"\n",
|
||||
" Steps:\n",
|
||||
" 1. Add transition to replay buffer\n",
|
||||
" 2. If buffer has enough samples, sample a minibatch\n",
|
||||
" 3. Compute targets using target network (max_a' Q_target(s', a'))\n",
|
||||
" 4. Compute Huber loss and backpropagate\n",
|
||||
" 5. Clip gradients and update weights with Adam\n",
|
||||
" 6. Periodically sync target network\n",
|
||||
"\n",
|
||||
" \"\"\"\n",
|
||||
" _ = next_action # DQN is off-policy\n",
|
||||
"\n",
|
||||
" # Store transition\n",
|
||||
" phi_s = normalize_obs(state)\n",
|
||||
" phi_sp = normalize_obs(next_state)\n",
|
||||
" self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n",
|
||||
"\n",
|
||||
" if len(self.replay_buffer) < self.batch_size:\n",
|
||||
" return\n",
|
||||
"\n",
|
||||
" # Sample minibatch\n",
|
||||
" states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n",
|
||||
" self.batch_size, self.rng,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Convert to tensors on device\n",
|
||||
" states_t = torch.from_numpy(states_b).float().to(DEVICE)\n",
|
||||
" actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n",
|
||||
" rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n",
|
||||
" next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n",
|
||||
" dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n",
|
||||
"\n",
|
||||
" # Current Q-values for taken actions\n",
|
||||
" q_values = self.q_net(states_t)\n",
|
||||
" q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n",
|
||||
"\n",
|
||||
" # Target Q-values (off-policy: max over actions in next state)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" q_next = self.target_net(next_states_t).max(dim=1).values\n",
|
||||
" targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n",
|
||||
"\n",
|
||||
" # Compute loss and update\n",
|
||||
" loss = self.loss_fn(q_curr, targets)\n",
|
||||
" self.optimizer.zero_grad()\n",
|
||||
" loss.backward()\n",
|
||||
" nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n",
|
||||
" self.optimizer.step()\n",
|
||||
"\n",
|
||||
" # Sync target network periodically\n",
|
||||
" self.update_step += 1\n",
|
||||
" if self.update_step % self.target_update_freq == 0:\n",
|
||||
" self.target_net.load_state_dict(self.q_net.state_dict())\n",
|
||||
"\n",
|
||||
" def save(self, filename: str) -> None:\n",
|
||||
" \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n",
|
||||
" torch.save(\n",
|
||||
" {\n",
|
||||
" \"q_net\": self.q_net.state_dict(),\n",
|
||||
" \"target_net\": self.target_net.state_dict(),\n",
|
||||
" \"optimizer\": self.optimizer.state_dict(),\n",
|
||||
" \"update_step\": self.update_step,\n",
|
||||
" \"n_features\": self.n_features,\n",
|
||||
" \"action_space\": self.action_space,\n",
|
||||
" },\n",
|
||||
" filename,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def load(self, filename: str) -> None:\n",
|
||||
" \"\"\"Load agent state from a torch checkpoint.\"\"\"\n",
|
||||
" checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n",
|
||||
" self.q_net.load_state_dict(checkpoint[\"q_net\"])\n",
|
||||
" self.target_net.load_state_dict(checkpoint[\"target_net\"])\n",
|
||||
" self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n",
|
||||
" self.update_step = checkpoint[\"update_step\"]\n",
|
||||
" self.q_net.to(DEVICE)\n",
|
||||
" self.target_net.to(DEVICE)\n",
|
||||
" self.target_net.eval()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b7b63455",
|
||||
@@ -756,7 +513,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 23,
|
||||
"id": "3c9d74be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -887,7 +644,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 24,
|
||||
"id": "f9a973dd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -927,7 +684,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 25,
|
||||
"id": "06b91580",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1180,19 +937,18 @@
|
||||
"- **Random** — random baseline (no training needed)\n",
|
||||
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
|
||||
"- **Q-Learning** — linear approximation, off-policy\n",
|
||||
"- **DQN** — PyTorch MLP (2 hidden layers of 256, experience replay, target network, GPU-accelerated)\n",
|
||||
"- **Monte Carlo** — linear approximation, first-visit returns\n",
|
||||
"\n",
|
||||
"**Workflow**:\n",
|
||||
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
|
||||
"2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
|
||||
"2. Save its weights to `checkpoints/` (`.pkl`)\n",
|
||||
"3. Repeat later for another agent without retraining previous ones\n",
|
||||
"4. Load all saved checkpoints before the final evaluation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "6f6ba8df",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1222,21 +978,19 @@
|
||||
"agent_random = RandomAgent(seed=42, action_space=int(n_actions))\n",
|
||||
"agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
|
||||
"agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
|
||||
"agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions, lr=1e-4)\n",
|
||||
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
|
||||
"\n",
|
||||
"agents = {\n",
|
||||
" \"Random\": agent_random,\n",
|
||||
" \"SARSA\": agent_sarsa,\n",
|
||||
" \"Q-Learning\": agent_q,\n",
|
||||
" \"DQN\": agent_dqn,\n",
|
||||
" \"Monte Carlo\": agent_mc,\n",
|
||||
"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": null,
|
||||
"id": "4d449701",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1244,45 +998,31 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Selected agent: DQN\n",
|
||||
"Checkpoint path: checkpoints/dqn.pt\n",
|
||||
"Selected agent: Monte Carlo\n",
|
||||
"Checkpoint path: checkpoints/monte_carlo.pkl\n",
|
||||
"\n",
|
||||
"============================================================\n",
|
||||
"Training: DQN (2500 episodes)\n",
|
||||
"Training: Monte Carlo (2500 episodes)\n",
|
||||
"============================================================\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "47c9af1f459346dea335fb822a091b1b",
|
||||
"model_id": "83fbf34f73ba481f906d78f0686d3c89",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Training DQN: 0%| | 0/2500 [00:00<?, ?it/s]"
|
||||
"Training Monte Carlo: 0%| | 0/2500 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_730/423872786.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{'='*60}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m training_histories[AGENT_TO_TRAIN] = train_agent(\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0magent\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/tmp/ipykernel_730/4063369955.py\u001b[0m in \u001b[0;36mtrain_agent\u001b[0;34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# Update agent with the transition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m agent.update(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, state, action, reward, next_state, done, next_action)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;31m# Sample minibatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 138\u001b[0;31m states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m )\n",
|
||||
"\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, batch_size, rng)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mstates\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mrewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"AGENT_TO_TRAIN = \"DQN\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n",
|
||||
"AGENT_TO_TRAIN = \"Monte Carlo\" # TODO: change to: \"Q-Learning\", \"Monte Carlo\", \"Random\"\n",
|
||||
"TRAINING_EPISODES = 2500\n",
|
||||
"FORCE_RETRAIN = False\n",
|
||||
"\n",
|
||||
@@ -1292,16 +1032,16 @@
|
||||
"\n",
|
||||
"training_histories: dict[str, list[float]] = {}\n",
|
||||
"agent = agents[AGENT_TO_TRAIN]\n",
|
||||
"ckpt_path = _ckpt_path(AGENT_TO_TRAIN)\n",
|
||||
"checkpoint_path = get_path(AGENT_TO_TRAIN)\n",
|
||||
"\n",
|
||||
"print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n",
|
||||
"print(f\"Checkpoint path: {ckpt_path}\")\n",
|
||||
"print(f\"Checkpoint path: {checkpoint_path}\")\n",
|
||||
"\n",
|
||||
"if AGENT_TO_TRAIN == \"Random\":\n",
|
||||
" print(\"Random is a baseline and is not trained.\")\n",
|
||||
" training_histories[AGENT_TO_TRAIN] = []\n",
|
||||
"elif ckpt_path.exists() and not FORCE_RETRAIN:\n",
|
||||
" agent.load(str(ckpt_path))\n",
|
||||
"elif checkpoint_path.exists() and not FORCE_RETRAIN:\n",
|
||||
" agent.load(str(checkpoint_path))\n",
|
||||
" print(\"Checkpoint found -> weights loaded, training skipped.\")\n",
|
||||
" training_histories[AGENT_TO_TRAIN] = []\n",
|
||||
"else:\n",
|
||||
@@ -1322,7 +1062,7 @@
|
||||
" avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n",
|
||||
" print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n",
|
||||
"\n",
|
||||
" agent.save(str(ckpt_path))\n",
|
||||
" agent.save(str(checkpoint_path))\n",
|
||||
" print(\"Checkpoint saved.\")\n"
|
||||
]
|
||||
},
|
||||
@@ -1349,43 +1089,6 @@
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0fc0c643",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loaded checkpoints: ['SARSA', 'Q-Learning']\n",
|
||||
"Missing checkpoints (untrained/not saved yet): ['DQN', 'Monte Carlo']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load all available checkpoints before final evaluation\n",
|
||||
"loaded_agents: list[str] = []\n",
|
||||
"missing_agents: list[str] = []\n",
|
||||
"\n",
|
||||
"for name, agent in agents.items():\n",
|
||||
" if name == \"Random\":\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" ckpt_path = _ckpt_path(name)\n",
|
||||
"\n",
|
||||
" if ckpt_path.exists():\n",
|
||||
" agent.load(str(ckpt_path))\n",
|
||||
" loaded_agents.append(name)\n",
|
||||
" else:\n",
|
||||
" missing_agents.append(name)\n",
|
||||
"\n",
|
||||
"print(f\"Loaded checkpoints: {loaded_agents}\")\n",
|
||||
"if missing_agents:\n",
|
||||
" print(f\"Missing checkpoints (untrained/not saved yet): {missing_agents}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e5f2c49",
|
||||
@@ -1402,7 +1105,53 @@
|
||||
"execution_count": null,
|
||||
"id": "70f5d5cd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"checkpoints/sarsa.pkl\n",
|
||||
"checkpoints/q_learning.pkl\n",
|
||||
"checkpoints/monte_carlo.pkl\n",
|
||||
"Agents evaluated: ['Random', 'SARSA', 'Q-Learning']\n",
|
||||
"Skipped (no checkpoint yet): ['Monte Carlo']\n",
|
||||
"[Evaluation 1/3] Random\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "46b0cdfc1baa4ef2a8dcc20c6924346d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Evaluating Random: 0%| | 0/20 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[15]\u001b[39m\u001b[32m, line 24\u001b[39m\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(eval_agents) < \u001b[32m2\u001b[39m:\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mTrain at least one non-random agent before final evaluation.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m24\u001b[39m results = \u001b[43mevaluate_tournament\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_agents\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes_per_agent\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 25\u001b[39m plot_evaluation_comparison(results)\n\u001b[32m 27\u001b[39m \u001b[38;5;66;03m# Print summary table\u001b[39;00m\n",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 226\u001b[39m, in \u001b[36mevaluate_tournament\u001b[39m\u001b[34m(env, agents, episodes_per_agent)\u001b[39m\n\u001b[32m 224\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m idx, (name, agent) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(agents.items(), start=\u001b[32m1\u001b[39m):\n\u001b[32m 225\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m[Evaluation \u001b[39m\u001b[38;5;132;01m{\u001b[39;00midx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_agents\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m] \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m226\u001b[39m results[name] = \u001b[43mevaluate_agent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 227\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepisodes_per_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 228\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 229\u001b[39m mean_r = results[name][\u001b[33m\"\u001b[39m\u001b[33mmean_reward\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 230\u001b[39m wr = results[name][\u001b[33m\"\u001b[39m\u001b[33mwin_rate\u001b[39m\u001b[33m\"\u001b[39m]\n",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 117\u001b[39m, in \u001b[36mevaluate_agent\u001b[39m\u001b[34m(env, agent, name, episodes, max_steps)\u001b[39m\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m 116\u001b[39m action = agent.get_action(np.asarray(obs), epsilon=\u001b[32m0.0\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m117\u001b[39m obs, reward, terminated, truncated, _info = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 118\u001b[39m reward = \u001b[38;5;28mfloat\u001b[39m(reward)\n\u001b[32m 119\u001b[39m total_reward += reward\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/stateful_observation.py:425\u001b[39m, in \u001b[36mFrameStackObservation.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 414\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 415\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 416\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 417\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Steps through the environment, appending the observation to the frame buffer.\u001b[39;00m\n\u001b[32m 418\u001b[39m \n\u001b[32m 419\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 423\u001b[39m \u001b[33;03m Stacked observations, reward, terminated, truncated, and info from the environment\u001b[39;00m\n\u001b[32m 424\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m425\u001b[39m obs, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 426\u001b[39m \u001b[38;5;28mself\u001b[39m.obs_queue.append(obs)\n\u001b[32m 428\u001b[39m updated_obs = deepcopy(\n\u001b[32m 429\u001b[39m concatenate(\u001b[38;5;28mself\u001b[39m.env.observation_space, \u001b[38;5;28mself\u001b[39m.obs_queue, \u001b[38;5;28mself\u001b[39m.stacked_obs)\n\u001b[32m 430\u001b[39m )\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:560\u001b[39m, in \u001b[36mObservationWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 556\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 557\u001b[39m \u001b[38;5;28mself\u001b[39m, action: ActType\n\u001b[32m 558\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 559\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m560\u001b[39m observation, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 561\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.observation(observation), reward, terminated, truncated, info\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:393\u001b[39m, in \u001b[36mOrderEnforcing.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ResetNeeded(\u001b[33m\"\u001b[39m\u001b[33mCannot call env.step() before calling env.reset()\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:327\u001b[39m, in \u001b[36mWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 323\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 324\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 325\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 326\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:285\u001b[39m, in \u001b[36mPassiveEnvChecker.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 283\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env_step_passive_checker(\u001b[38;5;28mself\u001b[39m.env, action)\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/ale_py/env.py:305\u001b[39m, in \u001b[36mAtariEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 303\u001b[39m reward = \u001b[32m0.0\u001b[39m\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(frameskip):\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m reward += \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43male\u001b[49m\u001b[43m.\u001b[49m\u001b[43mact\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrength\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 307\u001b[39m is_terminal = \u001b[38;5;28mself\u001b[39m.ale.game_over(with_truncation=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 308\u001b[39m is_truncated = \u001b[38;5;28mself\u001b[39m.ale.game_truncated()\n",
|
||||
"\u001b[31mKeyboardInterrupt\u001b[39m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Build evaluation set: Random + agents with existing checkpoints\n",
|
||||
"eval_agents: dict[str, Agent] = {\"Random\": agents[\"Random\"]}\n",
|
||||
@@ -1412,10 +1161,10 @@
|
||||
" if name == \"Random\":\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" ckpt_path = _ckpt_path(name)\n",
|
||||
" checkpoint_path = get_path(name)\n",
|
||||
"\n",
|
||||
" if ckpt_path.exists():\n",
|
||||
" agent.load(str(ckpt_path))\n",
|
||||
" if checkpoint_path.exists():\n",
|
||||
" agent.load(str(checkpoint_path))\n",
|
||||
" eval_agents[name] = agent\n",
|
||||
" else:\n",
|
||||
" missing_agents.append(name)\n",
|
||||
@@ -1497,7 +1246,6 @@
|
||||
"\n",
|
||||
"def run_pettingzoo_tournament(\n",
|
||||
" agents: dict[str, Agent],\n",
|
||||
" checkpoint_dir: Path,\n",
|
||||
" episodes_per_side: int = 10,\n",
|
||||
") -> tuple[np.ndarray, list[str]]:\n",
|
||||
" \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n",
|
||||
@@ -1507,9 +1255,9 @@
|
||||
" # Keep only agents that have a checkpoint\n",
|
||||
" ready_names: list[str] = []\n",
|
||||
" for name in candidate_names:\n",
|
||||
" ckpt_path = _ckpt_path(name)\n",
|
||||
" if ckpt_path.exists():\n",
|
||||
" agents[name].load(str(ckpt_path))\n",
|
||||
" checkpoint_path = get_path(name)\n",
|
||||
" if checkpoint_path.exists():\n",
|
||||
" agents[name].load(str(checkpoint_path))\n",
|
||||
" ready_names.append(name)\n",
|
||||
"\n",
|
||||
" if len(ready_names) < 2:\n",
|
||||
@@ -1568,7 +1316,6 @@
|
||||
"# Run tournament (non-random agents only)\n",
|
||||
"win_matrix_pz, pz_names = run_pettingzoo_tournament(\n",
|
||||
" agents=agents,\n",
|
||||
" checkpoint_dir=CHECKPOINT_DIR,\n",
|
||||
" episodes_per_side=10,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
@@ -1613,16 +1360,22 @@
|
||||
"\n",
|
||||
"This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n",
|
||||
"\n",
|
||||
"- Checkpoints are loaded from `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
|
||||
"- Checkpoints are loaded from `checkpoints/` (`.pkl`)\n",
|
||||
"- `Random` is **excluded** from ranking\n",
|
||||
"- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
|
||||
"- A win-rate matrix and final ranking are produced"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "150e6764",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "studies (3.13.9)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -1636,7 +1389,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.12"
|
||||
"version": "3.13.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user