Compare commits

...

2 Commits

Author SHA1 Message Date
ae74e6f585 Implement feature X to enhance user experience and optimize performance 2026-03-03 11:21:57 +01:00
7f2f820650 Update dependencies in pyproject.toml for improved compatibility and features
- Bump catboost from 1.2.8 to 1.2.10
- Update google-api-python-client from 2.190.0 to 2.191.0
- Upgrade langchain from 1.2.0 to 1.2.10
- Update langchain-core from 1.2.16 to 1.2.17
- Upgrade langchain-huggingface from 1.2.0 to 1.2.1
- Bump marimo from 0.19.11 to 0.20.2
- Update matplotlib from 3.10.1 to 3.10.8
- Upgrade numpy from 2.2.5 to 2.4.2
- Update opencv-python from 4.11.0.86 to 4.13.0.92
- Bump pandas from 2.2.3 to 3.0.1
- Update plotly from 6.3.0 to 6.6.0
- Upgrade polars from 1.37.0 to 1.38.1
- Bump rasterio from 1.4.4 to 1.5.0
- Update scikit-learn from 1.6.1 to 1.8.0
- Upgrade scipy from 1.15.2 to 1.17.1
- Bump shap from 0.49.1 to 0.50.0
- Adjust isort section order for better readability
2026-03-03 10:59:12 +01:00
2 changed files with 136 additions and 415 deletions

View File

@@ -5,41 +5,23 @@
"id": "fef45687",
"metadata": {},
"source": [
"# RL Project: Atari Tennis Tournament (NumPy Agents)\n",
"# RL Project: Atari Tennis Tournament\n",
"\n",
"This notebook implements four Reinforcement Learning algorithms **without PyTorch** to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
"This notebook implements three Reinforcement Learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
"\n",
"1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
"2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
"3. **DQN** — Deep Q-Network with pure numpy MLP, experience replay and target network (inspired by Lab 6A + classic DQN)\n",
"4. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
"3. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
"\n",
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
]
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 16,
"id": "b50d7174",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pygame'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[39]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msupersuit\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mss\u001b[39;00m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mwrappers\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m FrameStackObservation, ResizeObservation\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tennis_v3\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtqdm\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mauto\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/__init__.py:5\u001b[39m, in \u001b[36m__getattr__\u001b[39m\u001b[34m(env_name)\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__getattr__\u001b[39m(env_name):\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdeprecated_handler\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m__path__\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/deprecated_module.py:70\u001b[39m, in \u001b[36mdeprecated_handler\u001b[39m\u001b[34m(env_name, module_path, module_name)\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;66;03m# This executes the module and will raise any exceptions\u001b[39;00m\n\u001b[32m 68\u001b[39m \u001b[38;5;66;03m# that would typically be raised by just `import blah`\u001b[39;00m\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m spec.loader\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[43mspec\u001b[49m\u001b[43m.\u001b[49m\u001b[43mloader\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexec_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 71\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m module\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/tennis_v3.py:1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtennis\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtennis\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m env, parallel_env, raw_env\n\u001b[32m 3\u001b[39m __all__ = [\u001b[33m\"\u001b[39m\u001b[33menv\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mparallel_env\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mraw_env\u001b[39m\u001b[33m\"\u001b[39m]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/tennis/tennis.py:77\u001b[39m\n\u001b[32m 74\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m 75\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mglob\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m glob\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase_atari_env\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[32m 78\u001b[39m BaseAtariEnv,\n\u001b[32m 79\u001b[39m base_env_wrapper_fn,\n\u001b[32m 80\u001b[39m parallel_wrapper_fn,\n\u001b[32m 81\u001b[39m )\n\u001b[32m 84\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mraw_env\u001b[39m(**kwargs):\n\u001b[32m 85\u001b[39m name = os.path.basename(\u001b[34m__file__\u001b[39m).split(\u001b[33m\"\u001b[39m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m)[\u001b[32m0\u001b[39m]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/base_atari_env.py:6\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmulti_agent_ale_py\u001b[39;00m\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpygame\u001b[39;00m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m spaces\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m EzPickle, seeding\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'pygame'"
]
}
],
"outputs": [],
"source": [
"import itertools\n",
"import pickle\n",
@@ -59,13 +41,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"id": "ff3486a4",
"metadata": {},
"outputs": [],
"source": [
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n"
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"\n",
"def get_path(name: str) -> Path:\n",
" \"\"\"Return the checkpoint path for an agent (.pkl).\"\"\"\n",
" base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n",
" return CHECKPOINT_DIR / (base + \".pkl\")\n"
]
},
{
@@ -89,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"id": "be85c130",
"metadata": {},
"outputs": [],
@@ -173,7 +161,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"id": "ded9b1fb",
"metadata": {},
"outputs": [],
@@ -227,7 +215,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"id": "78bdc9d2",
"metadata": {},
"outputs": [],
@@ -260,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"id": "c124ed9a",
"metadata": {},
"outputs": [],
@@ -405,7 +393,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"id": "f5b5b9ea",
"metadata": {},
"outputs": [],
@@ -506,283 +494,6 @@
" self.W[action] = np.nan_to_num(self.W[action], nan=0.0, posinf=1e6, neginf=-1e6)\n"
]
},
{
"cell_type": "markdown",
"id": "f644e2ef",
"metadata": {},
"source": [
"## DQN Agent — NumPy MLP with Experience Replay and Target Network\n",
"\n",
"This agent implements the Deep Q-Network (DQN) **entirely in numpy**, without PyTorch.\n",
"\n",
"**Network architecture** (identical to the original DQNAgent structure):\n",
"$$\\text{Input}(n\\_features) \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
"\n",
"**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
"- **Experience Replay**: like the `self.model` dict in Dyna-Q (Lab 6A) that stores past transitions, we store transitions in a circular buffer and sample minibatches for updates.\n",
"- **Target Network**: periodically synchronized copy of the network, stabilizes learning.\n",
"- **Manual backpropagation**: MSE loss gradient backpropagated layer by layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec090a9a",
"metadata": {},
"outputs": [],
"source": [
"def _relu(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"ReLU activation: max(0, x).\"\"\"\n",
" return np.maximum(0, x)\n",
"\n",
"\n",
"def _relu_grad(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"Compute the ReLU derivative: return 1 where x > 0, else 0.\"\"\"\n",
" return (x > 0).astype(np.float64)\n",
"\n",
"\n",
"def _init_layer(\n",
" fan_in: int, fan_out: int, rng: np.random.Generator,\n",
") -> tuple[np.ndarray, np.ndarray]:\n",
" \"\"\"Initialize a linear layer with He initialization.\n",
"\n",
" He init: W ~ N(0, sqrt(2/fan_in)) — standard for ReLU networks.\n",
"\n",
" Args:\n",
" fan_in: Number of input features.\n",
" fan_out: Number of output features.\n",
" rng: Random number generator.\n",
"\n",
" Returns:\n",
" Weight matrix W of shape (fan_in, fan_out) and bias vector b of shape (fan_out,).\n",
"\n",
" \"\"\"\n",
" scale = np.sqrt(2.0 / fan_in)\n",
" W = rng.normal(0, scale, size=(fan_in, fan_out)).astype(np.float64)\n",
" b = np.zeros(fan_out, dtype=np.float64)\n",
" return W, b\n",
"\n",
"\n",
"class DQNAgent(Agent):\n",
" \"\"\"Deep Q-Network agent with a 2-hidden-layer MLP implemented in pure numpy.\n",
"\n",
" Architecture: Input -> Linear(128) -> ReLU -> Linear(128) -> ReLU -> Linear(n_actions)\n",
" Matches the original PyTorch DQNAgent structure.\n",
"\n",
" Features:\n",
" - Experience replay buffer (like Dyna-Q model in Lab 6A: store past transitions, sample for updates)\n",
" - Target network (periodically synchronized copy for stable targets)\n",
" - Manual forward pass and backpropagation through the MLP\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" lr: float = 0.0001,\n",
" gamma: float = 0.99,\n",
" buffer_size: int = 10000,\n",
" batch_size: int = 64,\n",
" target_update_freq: int = 1000,\n",
" seed: int = 42,\n",
" ) -> None:\n",
" \"\"\"Initialize DQN agent.\n",
"\n",
" Args:\n",
" n_features: Input feature dimension.\n",
" n_actions: Number of discrete actions.\n",
" lr: Learning rate for gradient descent.\n",
" gamma: Discount factor.\n",
" buffer_size: Maximum size of the replay buffer.\n",
" batch_size: Minibatch size for updates.\n",
" target_update_freq: Steps between target network syncs.\n",
" seed: RNG seed.\n",
"\n",
" \"\"\"\n",
" super().__init__(seed, n_actions)\n",
" self.n_features = n_features\n",
" self.lr = lr\n",
" self.gamma = gamma\n",
" self.buffer_size = buffer_size\n",
" self.batch_size = batch_size\n",
" self.target_update_freq = target_update_freq\n",
" self.update_step = 0\n",
"\n",
" # Initialize network parameters: 3 layers\n",
" # Layer 1: n_features -> 128\n",
" # Layer 2: 128 -> 128\n",
" # Layer 3: 128 -> n_actions\n",
" self.params = self._init_params(self.rng)\n",
"\n",
" # Target network: deep copy of params (like Dyna-Q's model copy concept)\n",
" self.target_params = {k: v.copy() for k, v in self.params.items()}\n",
"\n",
" # Experience replay buffer: list of (s, a, r, s', done) tuples\n",
" # Analogous to Dyna-Q's self.model dict + self.observed_sa in Lab 6A,\n",
" # but storing raw transitions for off-policy sampling\n",
" self.replay_buffer: list[tuple[np.ndarray, int, float, np.ndarray, bool]] = []\n",
"\n",
" def _init_params(self, rng: np.random.Generator) -> dict[str, np.ndarray]:\n",
" \"\"\"Initialize all MLP parameters with He initialization.\"\"\"\n",
" W1, b1 = _init_layer(self.n_features, 128, rng)\n",
" W2, b2 = _init_layer(128, 128, rng)\n",
" W3, b3 = _init_layer(128, self.action_space, rng)\n",
" return {\"W1\": W1, \"b1\": b1, \"W2\": W2, \"b2\": b2, \"W3\": W3, \"b3\": b3}\n",
"\n",
" def _forward(\n",
" self, x: np.ndarray, params: dict[str, np.ndarray],\n",
" ) -> tuple[np.ndarray, dict[str, np.ndarray]]:\n",
" \"\"\"Forward pass through the MLP. Returns output and cached activations for backprop.\n",
"\n",
" Architecture: x -> Linear -> ReLU -> Linear -> ReLU -> Linear -> output\n",
"\n",
" Args:\n",
" x: Input array of shape (batch, n_features) or (n_features,).\n",
" params: Dictionary of network parameters.\n",
"\n",
" Returns:\n",
" output: Q-values of shape (batch, n_actions) or (n_actions,).\n",
" cache: Dictionary of intermediate values for backpropagation.\n",
"\n",
" \"\"\"\n",
" # Layer 1\n",
" z1 = x @ params[\"W1\"] + params[\"b1\"]\n",
" h1 = _relu(z1)\n",
" # Layer 2\n",
" z2 = h1 @ params[\"W2\"] + params[\"b2\"]\n",
" h2 = _relu(z2)\n",
" # Layer 3 (output, no activation)\n",
" out = h2 @ params[\"W3\"] + params[\"b3\"]\n",
" cache = {\"x\": x, \"z1\": z1, \"h1\": h1, \"z2\": z2, \"h2\": h2}\n",
" return out, cache\n",
"\n",
" def _backward(\n",
" self,\n",
" d_out: np.ndarray,\n",
" cache: dict[str, np.ndarray],\n",
" params: dict[str, np.ndarray],\n",
" ) -> dict[str, np.ndarray]:\n",
" \"\"\"Backward pass: compute gradients of loss w.r.t. all parameters.\n",
"\n",
" Args:\n",
" d_out: Gradient of loss w.r.t. output, shape (batch, n_actions).\n",
" cache: Cached activations from forward pass.\n",
" params: Current network parameters.\n",
"\n",
" Returns:\n",
" Dictionary of gradients for each parameter.\n",
"\n",
" \"\"\"\n",
" batch = d_out.shape[0] if d_out.ndim > 1 else 1\n",
" if d_out.ndim == 1:\n",
" d_out = d_out.reshape(1, -1)\n",
"\n",
" x = cache[\"x\"] if cache[\"x\"].ndim > 1 else cache[\"x\"].reshape(1, -1)\n",
" h1 = cache[\"h1\"] if cache[\"h1\"].ndim > 1 else cache[\"h1\"].reshape(1, -1)\n",
" h2 = cache[\"h2\"] if cache[\"h2\"].ndim > 1 else cache[\"h2\"].reshape(1, -1)\n",
" z1 = cache[\"z1\"] if cache[\"z1\"].ndim > 1 else cache[\"z1\"].reshape(1, -1)\n",
" z2 = cache[\"z2\"] if cache[\"z2\"].ndim > 1 else cache[\"z2\"].reshape(1, -1)\n",
"\n",
" # Layer 3 gradients\n",
" dW3 = h2.T @ d_out / batch\n",
" db3 = d_out.mean(axis=0)\n",
" dh2 = d_out @ params[\"W3\"].T\n",
"\n",
" # Layer 2 gradients (through ReLU)\n",
" dz2 = dh2 * _relu_grad(z2)\n",
" dW2 = h1.T @ dz2 / batch\n",
" db2 = dz2.mean(axis=0)\n",
" dh1 = dz2 @ params[\"W2\"].T\n",
"\n",
" # Layer 1 gradients (through ReLU)\n",
" dz1 = dh1 * _relu_grad(z1)\n",
" dW1 = x.T @ dz1 / batch\n",
" db1 = dz1.mean(axis=0)\n",
"\n",
" return {\"W1\": dW1, \"b1\": db1, \"W2\": dW2, \"b2\": db2, \"W3\": dW3, \"b3\": db3}\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" \"\"\"Select action using ε-greedy policy over MLP Q-values.\"\"\"\n",
" phi = normalize_obs(observation)\n",
" q_vals, _ = self._forward(phi, self.params)\n",
" if q_vals.ndim > 1:\n",
" q_vals = q_vals.squeeze(0)\n",
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Store transition and perform a minibatch DQN update.\n",
"\n",
" Steps:\n",
" 1. Add transition to replay buffer (like Dyna-Q model update in Lab 6A)\n",
" 2. If buffer has enough samples, sample a minibatch\n",
" 3. Compute targets using target network (Q-learning: max_a' Q_target(s', a'))\n",
" 4. Backpropagate MSE loss through the MLP\n",
" 5. Update weights with gradient descent\n",
" 6. Periodically sync target network\n",
"\n",
" \"\"\"\n",
" _ = next_action # DQN is off-policy\n",
"\n",
" # Store transition in replay buffer\n",
" # Analogous to Lab 6A: self.model[(s,a)] = (r, sp)\n",
" phi_s = normalize_obs(state)\n",
" phi_sp = normalize_obs(next_state)\n",
" self.replay_buffer.append((phi_s, action, reward, phi_sp, done))\n",
" if len(self.replay_buffer) > self.buffer_size:\n",
" self.replay_buffer.pop(0)\n",
"\n",
" # Don't update until we have enough samples\n",
" if len(self.replay_buffer) < self.batch_size:\n",
" return\n",
"\n",
" # Sample minibatch from replay buffer\n",
" # Like Dyna-Q planning: sample from observed transitions (Lab 6A)\n",
" idx = self.rng.choice(len(self.replay_buffer), size=self.batch_size, replace=False)\n",
" batch = [self.replay_buffer[i] for i in idx]\n",
"\n",
" states_b = np.array([t[0] for t in batch]) # (batch, n_features)\n",
" actions_b = np.array([t[1] for t in batch]) # (batch,)\n",
" rewards_b = np.array([t[2] for t in batch]) # (batch,)\n",
" next_states_b = np.array([t[3] for t in batch]) # (batch, n_features)\n",
" dones_b = np.array([t[4] for t in batch], dtype=np.float64)\n",
"\n",
" # Forward pass on current states\n",
" q_all, cache = self._forward(states_b, self.params)\n",
" # Gather Q-values for the taken actions\n",
" q_curr = q_all[np.arange(self.batch_size), actions_b] # (batch,)\n",
"\n",
" # Compute targets using target network (off-policy: max over actions)\n",
" # Follows Lab 5B Q-learning: td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
" q_next_all, _ = self._forward(next_states_b, self.target_params)\n",
" q_next_max = np.max(q_next_all, axis=1) # (batch,)\n",
" targets = rewards_b + (1.0 - dones_b) * self.gamma * q_next_max\n",
"\n",
" # MSE loss gradient: d_loss/d_q_curr = 2 * (q_curr - targets) / batch\n",
" # We only backprop through the action that was taken\n",
" d_out = np.zeros_like(q_all)\n",
" d_out[np.arange(self.batch_size), actions_b] = 2.0 * (q_curr - targets) / self.batch_size\n",
"\n",
" # Backward pass\n",
" grads = self._backward(d_out, cache, self.params)\n",
"\n",
" # Gradient descent update\n",
" for key in self.params:\n",
" self.params[key] -= self.lr * grads[key]\n",
"\n",
" # Sync target network periodically\n",
" self.update_step += 1\n",
" if self.update_step % self.target_update_freq == 0:\n",
" self.target_params = {k: v.copy() for k, v in self.params.items()}\n"
]
},
{
"cell_type": "markdown",
"id": "b7b63455",
@@ -802,7 +513,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"id": "3c9d74be",
"metadata": {},
"outputs": [],
@@ -933,7 +644,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"id": "f9a973dd",
"metadata": {},
"outputs": [],
@@ -973,7 +684,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 25,
"id": "06b91580",
"metadata": {},
"outputs": [],
@@ -1226,19 +937,18 @@
"- **Random** — random baseline (no training needed)\n",
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
"- **Q-Learning** — linear approximation, off-policy\n",
"- **DQN** — NumPy MLP (2 hidden layers of 128, experience replay, target network)\n",
"- **Monte Carlo** — linear approximation, first-visit returns\n",
"\n",
"**Workflow**:\n",
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
"2. Save its weights to `checkpoints/<agent>.pkl`\n",
"2. Save its weights to `checkpoints/` (`.pkl`)\n",
"3. Repeat later for another agent without retraining previous ones\n",
"4. Load all saved checkpoints before the final evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "6f6ba8df",
"metadata": {},
"outputs": [
@@ -1268,21 +978,19 @@
"agent_random = RandomAgent(seed=42, action_space=int(n_actions))\n",
"agent_sarsa = SarsaAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_q = QLearningAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions, lr=1e-4)\n",
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions, alpha=1e-5)\n",
"\n",
"agents = {\n",
" \"Random\": agent_random,\n",
" \"SARSA\": agent_sarsa,\n",
" \"Q-Learning\": agent_q,\n",
" \"DQN\": agent_dqn,\n",
" \"Monte Carlo\": agent_mc,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": null,
"id": "4d449701",
"metadata": {},
"outputs": [
@@ -1290,39 +998,31 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Selected agent: Q-Learning\n",
"Checkpoint path: checkpoints/q_learning.pkl\n",
"Selected agent: Monte Carlo\n",
"Checkpoint path: checkpoints/monte_carlo.pkl\n",
"\n",
"============================================================\n",
"Training: Q-Learning (2500 episodes)\n",
"Training: Monte Carlo (2500 episodes)\n",
"============================================================\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "02b31c90057641a09f7e6bb3f3c53f79",
"model_id": "83fbf34f73ba481f906d78f0686d3c89",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training Q-Learning: 0%| | 0/2500 [00:00<?, ?it/s]"
"Training Monte Carlo: 0%| | 0/2500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"-> Q-Learning avg reward (last 100 eps): -0.19\n",
"Checkpoint saved.\n"
]
}
],
"source": [
"AGENT_TO_TRAIN = \"Q-Learning\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n",
"AGENT_TO_TRAIN = \"Monte Carlo\" # TODO: change to: \"Q-Learning\", \"Monte Carlo\", \"Random\"\n",
"TRAINING_EPISODES = 2500\n",
"FORCE_RETRAIN = False\n",
"\n",
@@ -1332,17 +1032,16 @@
"\n",
"training_histories: dict[str, list[float]] = {}\n",
"agent = agents[AGENT_TO_TRAIN]\n",
"ckpt_name = AGENT_TO_TRAIN.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
"ckpt_path = CHECKPOINT_DIR / ckpt_name\n",
"checkpoint_path = get_path(AGENT_TO_TRAIN)\n",
"\n",
"print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n",
"print(f\"Checkpoint path: {ckpt_path}\")\n",
"print(f\"Checkpoint path: {checkpoint_path}\")\n",
"\n",
"if AGENT_TO_TRAIN == \"Random\":\n",
" print(\"Random is a baseline and is not trained.\")\n",
" training_histories[AGENT_TO_TRAIN] = []\n",
"elif ckpt_path.exists() and not FORCE_RETRAIN:\n",
" agent.load(str(ckpt_path))\n",
"elif checkpoint_path.exists() and not FORCE_RETRAIN:\n",
" agent.load(str(checkpoint_path))\n",
" print(\"Checkpoint found -> weights loaded, training skipped.\")\n",
" training_histories[AGENT_TO_TRAIN] = []\n",
"else:\n",
@@ -1363,13 +1062,13 @@
" avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-100:])\n",
" print(f\"-> {AGENT_TO_TRAIN} avg reward (last 100 eps): {avg_last_100:.2f}\")\n",
"\n",
" agent.save(str(ckpt_path))\n",
" agent.save(str(checkpoint_path))\n",
" print(\"Checkpoint saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": null,
"id": "a13a65df",
"metadata": {},
"outputs": [
@@ -1390,45 +1089,6 @@
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fc0c643",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded checkpoints: ['SARSA']\n",
"Missing checkpoints (untrained/not saved yet): ['Q-Learning', 'DQN', 'Monte Carlo']\n"
]
}
],
"source": [
"# Load all available checkpoints before final evaluation\n",
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
"loaded_agents: list[str] = []\n",
"missing_agents: list[str] = []\n",
"\n",
"for name, agent in agents.items():\n",
" if name == \"Random\":\n",
" continue\n",
"\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = CHECKPOINT_DIR / ckpt_name\n",
"\n",
" if ckpt_path.exists():\n",
" agent.load(str(ckpt_path))\n",
" loaded_agents.append(name)\n",
" else:\n",
" missing_agents.append(name)\n",
"\n",
"print(f\"Loaded checkpoints: {loaded_agents}\")\n",
"if missing_agents:\n",
" print(f\"Missing checkpoints (untrained/not saved yet): {missing_agents}\")\n"
]
},
{
"cell_type": "markdown",
"id": "0e5f2c49",
@@ -1445,7 +1105,53 @@
"execution_count": null,
"id": "70f5d5cd",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoints/sarsa.pkl\n",
"checkpoints/q_learning.pkl\n",
"checkpoints/monte_carlo.pkl\n",
"Agents evaluated: ['Random', 'SARSA', 'Q-Learning']\n",
"Skipped (no checkpoint yet): ['Monte Carlo']\n",
"[Evaluation 1/3] Random\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46b0cdfc1baa4ef2a8dcc20c6924346d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluating Random: 0%| | 0/20 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[15]\u001b[39m\u001b[32m, line 24\u001b[39m\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(eval_agents) < \u001b[32m2\u001b[39m:\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mTrain at least one non-random agent before final evaluation.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m24\u001b[39m results = \u001b[43mevaluate_tournament\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_agents\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes_per_agent\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 25\u001b[39m plot_evaluation_comparison(results)\n\u001b[32m 27\u001b[39m \u001b[38;5;66;03m# Print summary table\u001b[39;00m\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 226\u001b[39m, in \u001b[36mevaluate_tournament\u001b[39m\u001b[34m(env, agents, episodes_per_agent)\u001b[39m\n\u001b[32m 224\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m idx, (name, agent) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(agents.items(), start=\u001b[32m1\u001b[39m):\n\u001b[32m 225\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m[Evaluation \u001b[39m\u001b[38;5;132;01m{\u001b[39;00midx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_agents\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m] \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m226\u001b[39m results[name] = \u001b[43mevaluate_agent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 227\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepisodes_per_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 228\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 229\u001b[39m mean_r = results[name][\u001b[33m\"\u001b[39m\u001b[33mmean_reward\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 230\u001b[39m wr = results[name][\u001b[33m\"\u001b[39m\u001b[33mwin_rate\u001b[39m\u001b[33m\"\u001b[39m]\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 117\u001b[39m, in \u001b[36mevaluate_agent\u001b[39m\u001b[34m(env, agent, name, episodes, max_steps)\u001b[39m\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m 116\u001b[39m action = agent.get_action(np.asarray(obs), epsilon=\u001b[32m0.0\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m117\u001b[39m obs, reward, terminated, truncated, _info = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 118\u001b[39m reward = \u001b[38;5;28mfloat\u001b[39m(reward)\n\u001b[32m 119\u001b[39m total_reward += reward\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/stateful_observation.py:425\u001b[39m, in \u001b[36mFrameStackObservation.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 414\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 415\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 416\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 417\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Steps through the environment, appending the observation to the frame buffer.\u001b[39;00m\n\u001b[32m 418\u001b[39m \n\u001b[32m 419\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 423\u001b[39m \u001b[33;03m Stacked observations, reward, terminated, truncated, and info from the environment\u001b[39;00m\n\u001b[32m 424\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m425\u001b[39m obs, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 426\u001b[39m \u001b[38;5;28mself\u001b[39m.obs_queue.append(obs)\n\u001b[32m 428\u001b[39m updated_obs = deepcopy(\n\u001b[32m 429\u001b[39m concatenate(\u001b[38;5;28mself\u001b[39m.env.observation_space, \u001b[38;5;28mself\u001b[39m.obs_queue, \u001b[38;5;28mself\u001b[39m.stacked_obs)\n\u001b[32m 430\u001b[39m )\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:560\u001b[39m, in \u001b[36mObservationWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 556\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 557\u001b[39m \u001b[38;5;28mself\u001b[39m, action: ActType\n\u001b[32m 558\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 559\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m560\u001b[39m observation, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 561\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.observation(observation), reward, terminated, truncated, info\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:393\u001b[39m, in \u001b[36mOrderEnforcing.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ResetNeeded(\u001b[33m\"\u001b[39m\u001b[33mCannot call env.step() before calling env.reset()\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:327\u001b[39m, in \u001b[36mWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 323\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 324\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 325\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 326\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:285\u001b[39m, in \u001b[36mPassiveEnvChecker.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 283\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env_step_passive_checker(\u001b[38;5;28mself\u001b[39m.env, action)\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/ale_py/env.py:305\u001b[39m, in \u001b[36mAtariEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 303\u001b[39m reward = \u001b[32m0.0\u001b[39m\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(frameskip):\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m reward += \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43male\u001b[49m\u001b[43m.\u001b[49m\u001b[43mact\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrength\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 307\u001b[39m is_terminal = \u001b[38;5;28mself\u001b[39m.ale.game_over(with_truncation=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 308\u001b[39m is_truncated = \u001b[38;5;28mself\u001b[39m.ale.game_truncated()\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"# Build evaluation set: Random + agents with existing checkpoints\n",
"eval_agents: dict[str, Agent] = {\"Random\": agents[\"Random\"]}\n",
@@ -1455,11 +1161,10 @@
" if name == \"Random\":\n",
" continue\n",
"\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = CHECKPOINT_DIR / ckpt_name\n",
" checkpoint_path = get_path(name)\n",
"\n",
" if ckpt_path.exists():\n",
" agent.load(str(ckpt_path))\n",
" if checkpoint_path.exists():\n",
" agent.load(str(checkpoint_path))\n",
" eval_agents[name] = agent\n",
" else:\n",
" missing_agents.append(name)\n",
@@ -1541,7 +1246,6 @@
"\n",
"def run_pettingzoo_tournament(\n",
" agents: dict[str, Agent],\n",
" checkpoint_dir: Path,\n",
" episodes_per_side: int = 10,\n",
") -> tuple[np.ndarray, list[str]]:\n",
" \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n",
@@ -1551,10 +1255,9 @@
" # Keep only agents that have a checkpoint\n",
" ready_names: list[str] = []\n",
" for name in candidate_names:\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = checkpoint_dir / ckpt_name\n",
" if ckpt_path.exists():\n",
" agents[name].load(str(ckpt_path))\n",
" checkpoint_path = get_path(name)\n",
" if checkpoint_path.exists():\n",
" agents[name].load(str(checkpoint_path))\n",
" ready_names.append(name)\n",
"\n",
" if len(ready_names) < 2:\n",
@@ -1613,7 +1316,6 @@
"# Run tournament (non-random agents only)\n",
"win_matrix_pz, pz_names = run_pettingzoo_tournament(\n",
" agents=agents,\n",
" checkpoint_dir=CHECKPOINT_DIR,\n",
" episodes_per_side=10,\n",
")\n",
"\n",
@@ -1658,11 +1360,17 @@
"\n",
"This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n",
"\n",
"- Checkpoints are loaded from `checkpoints/*.pkl`\n",
"- Checkpoints are loaded from `checkpoints/` (`.pkl`)\n",
"- `Random` is **excluded** from ranking\n",
"- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
"- A win-rate matrix and final ranking are produced"
]
},
{
"cell_type": "markdown",
"id": "150e6764",
"metadata": {},
"source": []
}
],
"metadata": {

View File

@@ -7,48 +7,48 @@ requires-python = ">= 3.11"
dependencies = [
"accelerate>=1.12.0",
"ale-py>=0.11.2",
"catboost>=1.2.8",
"catboost>=1.2.10",
"datasets>=4.6.1",
"faiss-cpu>=1.13.2",
"folium>=0.20.0",
"geopandas>=1.1.2",
"google-api-python-client>=2.190.0",
"google-api-python-client>=2.191.0",
"google-auth-oauthlib>=1.3.0",
"google-generativeai>=0.8.6",
"gymnasium[toy-text]>=1.2.3",
"imblearn>=0.0",
"ipykernel>=7.2.0",
"ipywidgets>=8.1.8",
"langchain>=1.2.0",
"langchain>=1.2.10",
"langchain-community>=0.4.1",
"langchain-core>=1.2.16",
"langchain-huggingface>=1.2.0",
"langchain-core>=1.2.17",
"langchain-huggingface>=1.2.1",
"langchain-mistralai>=1.1.1",
"langchain-ollama>=1.0.1",
"langchain-openai>=1.1.10",
"langchain-text-splitters>=1.1.1",
"mapclassify>=2.10.0",
"marimo>=0.19.11",
"matplotlib>=3.10.1",
"marimo>=0.20.2",
"matplotlib>=3.10.8",
"nbformat>=5.10.4",
"numpy>=2.2.5",
"opencv-python>=4.11.0.86",
"numpy>=2.4.2",
"opencv-python>=4.13.0.92",
"openpyxl>=3.1.5",
"pandas>=2.2.3",
"pandas>=3.0.1",
"pandas-stubs>=3.0.0.260204",
"pettingzoo[atari]>=1.25.0",
"plotly>=6.3.0",
"polars>=1.37.0",
"plotly>=6.6.0",
"polars>=1.38.1",
"pyclustertend>=1.9.0",
"pypdf>=6.7.5",
"rasterio>=1.4.4",
"rasterio>=1.5.0",
"requests>=2.32.5",
"scikit-learn>=1.6.1",
"scipy>=1.15.2",
"scikit-learn>=1.8.0",
"scipy>=1.17.1",
"seaborn>=0.13.2",
"sentence-transformers>=5.2.3",
# "sequenzo>=0.1.20",
"shap>=0.49.1",
"shap>=0.50.0",
"spacy>=3.8.11",
"statsmodels>=0.14.6",
"supersuit>=3.10.0",
@@ -65,10 +65,7 @@ dependencies = [
]
[dependency-groups]
dev = [
"ipykernel>=7.2.0",
"uv>=0.10.7",
]
dev = ["ipykernel>=7.2.0", "uv>=0.10.7"]
[tool.ty.rules]
index-out-of-bounds = "ignore"
@@ -83,18 +80,18 @@ select = ["ALL"]
# Désactiver certaines règles
ignore = [
"E501", # line too long, géré par le formatter
"E402", # Imports in top of file
"T201", # Print
"E501", # line too long, géré par le formatter
"E402", # Imports in top of file
"T201", # Print
"N806",
"N803",
"N802",
"N816",
"PLR0913", # Too many arguments
"PLR2004", #
"FBT001", # Boolean positional argument used in function definition
"FBT001", # Boolean positional argument used in function definition
"EM101",
"TRY003", # Avoid specifying long messages outside the exception class
"TRY003", # Avoid specifying long messages outside the exception class
]
# Exclure certains fichiers ou répertoires
@@ -127,9 +124,25 @@ unfixable = []
[tool.ruff.lint.isort]
# Regrouper les imports par thématiques
section-order = ["future", "standard-library", "third-party", "data-science", "ml", "first-party", "local-folder"]
section-order = [
"future",
"standard-library",
"third-party",
"data-science",
"ml",
"first-party",
"local-folder",
]
[tool.ruff.lint.isort.sections]
# On sépare les outils de manipulation de données des frameworks de ML lourds
"data-science" = ["numpy", "pandas", "scipy", "matplotlib", "seaborn", "plotly"]
"ml" = ["tensorflow", "keras", "torch", "sklearn", "xgboost", "catboost", "shap"]
"ml" = [
"tensorflow",
"keras",
"torch",
"sklearn",
"xgboost",
"catboost",
"shap",
]