From acd403a14e45d81bffcc694f746c3c9b48098d0c Mon Sep 17 00:00:00 2001 From: Arthur DANJOU Date: Fri, 6 Mar 2026 15:28:31 +0100 Subject: [PATCH] working --- .../Project_RL_DANJOU_VON-SIEMENS.ipynb | 585 +++++++++++++++--- 1 file changed, 500 insertions(+), 85 deletions(-) diff --git a/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb b/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb index 3c9dbd1..11332d9 100644 --- a/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb +++ b/M2/Reinforcement Learning/project/Project_RL_DANJOU_VON-SIEMENS.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 43, "id": "b50d7174", "metadata": {}, "outputs": [ @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 44, "id": "ff3486a4", "metadata": {}, "outputs": [], @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 45, "id": "be85c130", "metadata": {}, "outputs": [], @@ -186,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 46, "id": "ded9b1fb", "metadata": {}, "outputs": [], @@ -240,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 47, "id": "78bdc9d2", "metadata": {}, "outputs": [], @@ -273,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 48, "id": "c124ed9a", "metadata": {}, "outputs": [], @@ -418,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 49, "id": "f5b5b9ea", "metadata": {}, "outputs": [], @@ -540,7 +540,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 50, "id": "7a3aa454", "metadata": {}, "outputs": [], @@ -571,6 +571,7 @@ " gamma: float = 0.99,\n", " seed: int = 42,\n", " ) -> None:\n", + " \"\"\"Initialize Monte Carlo agent with linear weights and episode buffers.\"\"\"\n", " super().__init__(seed, n_actions)\n", " self.n_features = n_features\n", " self.alpha = alpha\n", @@ -584,6 +585,7 @@ " return self.W @ phi\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", + " \"\"\"Select action using ε-greedy policy over linear Q-values.\"\"\"\n", " phi = observation.flatten().astype(np.float32) / np.float32(255.0)\n", " q_vals = self._q_values(phi)\n", " return epsilon_greedy(q_vals, epsilon, self.rng)\n", @@ -597,6 +599,7 @@ " done: bool,\n", " next_action: int | None = None,\n", " ) -> None:\n", + " \"\"\"Accumulate episode transitions and update weights at episode end.\"\"\"\n", " _ = next_state, next_action\n", "\n", " self._obs_buf.append(state)\n", @@ -667,7 +670,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 51, "id": "9c777493", "metadata": {}, "outputs": [], @@ -676,6 +679,7 @@ " \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n", "\n", " def __init__(self, n_features: int, n_actions: int) -> None:\n", + " \"\"\"Initialize the Q-network architecture.\"\"\"\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(n_features, 256),\n", @@ -686,6 +690,7 @@ " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"Compute Q-values for input state tensor x.\"\"\"\n", " return self.net(x)\n", "\n", "\n", @@ -693,12 +698,15 @@ " \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n", "\n", " def __init__(self, capacity: int) -> None:\n", + " \"\"\"Initialize the replay buffer with a maximum capacity.\"\"\"\n", " self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n", "\n", " def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n", + " \"\"\"Add a transition to the replay buffer.\"\"\"\n", " self.buffer.append((state, action, reward, next_state, done))\n", "\n", " def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n", + " \"\"\"Sample a random minibatch of transitions from the buffer.\"\"\"\n", " indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n", " batch = [self.buffer[i] for i in indices]\n", " states = np.array([t[0] for t in batch])\n", @@ -709,6 +717,7 @@ " return states, actions, rewards, next_states, dones\n", "\n", " def __len__(self) -> int:\n", + " \"\"\"Return the current number of transitions stored in the buffer.\"\"\"\n", " return len(self.buffer)\n", "\n", "\n", @@ -876,7 +885,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 52, "id": "f9a973dd", "metadata": {}, "outputs": [], @@ -908,19 +917,25 @@ "Functions for training and evaluating agents in the single-agent Gymnasium environment:\n", "\n", "1. **`train_agent`** — Pre-trains an agent against the built-in AI for a given number of episodes with ε-greedy exploration\n", - "2. **`evaluate_agent`** — Evaluates a trained agent (no exploration, ε = 0) and returns performance metrics\n", - "3. **`plot_training_curves`** — Plots the training reward history (moving average) for all agents\n", - "4. **`plot_evaluation_comparison`** — Bar chart comparing final evaluation scores across agents\n", - "5. **`evaluate_tournament`** — Evaluates all agents and produces a summary comparison" + "2. **`plot_training_curves`** — Plots the training reward history (moving average) for all agents" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "06b91580", "metadata": {}, "outputs": [], "source": [ + "AGENT_COLORS = {\n", + " \"Random\": \"#95a5a6\",\n", + " \"SARSA\": \"#3498db\",\n", + " \"Q-Learning\": \"#e74c3c\",\n", + " \"MonteCarlo\": \"#2ecc71\",\n", + " \"DQN\": \"#9b59b6\",\n", + "}\n", + "\n", + "\n", "def train_agent(\n", " env: gym.Env,\n", " agent: Agent,\n", @@ -1017,22 +1032,29 @@ " window: Moving average window size.\n", "\n", " \"\"\"\n", - " plt.figure(figsize=(12, 6))\n", + " fig, ax = plt.subplots(figsize=(13, 6))\n", "\n", " for name, rewards in training_histories.items():\n", + " color = AGENT_COLORS.get(name, \"#7f8c8d\")\n", " if len(rewards) >= window:\n", " ma = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n", - " plt.plot(np.arange(window - 1, len(rewards)), ma, label=name)\n", + " x = np.arange(window - 1, len(rewards))\n", + " ax.plot(x, ma, label=name, color=color, linewidth=2)\n", + " ax.fill_between(x, ma - np.std(rewards) * 0.3, ma + np.std(rewards) * 0.3,\n", + " color=color, alpha=0.1)\n", " else:\n", - " plt.plot(rewards, label=f\"{name} (raw)\")\n", + " ax.plot(rewards, label=f\"{name} (raw)\", color=color, linewidth=1.5, alpha=0.7)\n", "\n", - " plt.xlabel(\"Episodes\")\n", - " plt.ylabel(f\"Average Reward (Window={window})\")\n", - " plt.title(\"Training Curves (vs built-in AI)\")\n", - " plt.legend()\n", - " plt.grid(visible=True)\n", - " plt.tight_layout()\n", - " plt.savefig(path)\n", + " ax.set_xlabel(\"Episodes\", fontsize=12)\n", + " ax.set_ylabel(f\"Average Reward (window = {window})\", fontsize=12)\n", + " ax.set_title(\"Training Curves (vs Built-in AI)\", fontsize=14, fontweight=\"bold\")\n", + " ax.legend(fontsize=10, framealpha=0.9)\n", + " ax.grid(visible=True, alpha=0.3, linestyle=\"--\")\n", + " ax.axhline(y=0, color=\"gray\", linewidth=0.8, linestyle=\":\", alpha=0.5)\n", + " ax.spines[\"top\"].set_visible(False)\n", + " ax.spines[\"right\"].set_visible(False)\n", + " fig.tight_layout()\n", + " plt.savefig(path, dpi=150, bbox_inches=\"tight\")\n", " plt.show()\n" ] }, @@ -1061,7 +1083,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 54, "id": "6f6ba8df", "metadata": {}, "outputs": [ @@ -1105,7 +1127,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 55, "id": "4d449701", "metadata": {}, "outputs": [ @@ -1117,28 +1139,84 @@ "Checkpoint path: checkpoints/q_learning.pkl\n", "\n", "============================================================\n", - "Training: Q-Learning (5000 episodes)\n", + "Training: Q-Learning (3000 episodes)\n", + "============================================================\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected agent: Q-Learning\n", + "Checkpoint path: checkpoints/q_learning.pkl\n", + "\n", + "============================================================\n", + "Training: Q-Learning (3000 episodes)\n", "============================================================\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "543ff46900f84f6fa37ccc6989bbe7f2", + "model_id": "0c3a441131c648749a7d938732265143", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training Q-Learning: 0%| | 0/5000 [00:00 \u001b[39m\u001b[32m28\u001b[39m training_histories[AGENT_TO_TRAIN] = \u001b[43mtrain_agent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 29\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m=\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 30\u001b[39m \u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m=\u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 31\u001b[39m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[43mAGENT_TO_TRAIN\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 32\u001b[39m \u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mTRAINING_EPISODES\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 33\u001b[39m \u001b[43m \u001b[49m\u001b[43mepsilon_start\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 34\u001b[39m \u001b[43m \u001b[49m\u001b[43mepsilon_end\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 35\u001b[39m \u001b[43m \u001b[49m\u001b[43mepsilon_decay\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.999\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 36\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 38\u001b[39m avg_last_100 = np.mean(training_histories[AGENT_TO_TRAIN][-\u001b[32m100\u001b[39m:])\n\u001b[32m 39\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m-> \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mAGENT_TO_TRAIN\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m avg reward (last 100 eps): \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mavg_last_100\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[53]\u001b[39m\u001b[32m, line 45\u001b[39m, in \u001b[36mtrain_agent\u001b[39m\u001b[34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[39m\n\u001b[32m 42\u001b[39m action = agent.get_action(obs, epsilon=epsilon)\n\u001b[32m 44\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m next_obs, reward, terminated, truncated, _info = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 46\u001b[39m next_obs = np.asarray(next_obs)\n\u001b[32m 47\u001b[39m done = terminated \u001b[38;5;129;01mor\u001b[39;00m truncated\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/stateful_observation.py:425\u001b[39m, in \u001b[36mFrameStackObservation.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 414\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 415\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 416\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 417\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Steps through the environment, appending the observation to the frame buffer.\u001b[39;00m\n\u001b[32m 418\u001b[39m \n\u001b[32m 419\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 423\u001b[39m \u001b[33;03m Stacked observations, reward, terminated, truncated, and info from the environment\u001b[39;00m\n\u001b[32m 424\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m425\u001b[39m obs, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 426\u001b[39m \u001b[38;5;28mself\u001b[39m.obs_queue.append(obs)\n\u001b[32m 428\u001b[39m updated_obs = deepcopy(\n\u001b[32m 429\u001b[39m concatenate(\u001b[38;5;28mself\u001b[39m.env.observation_space, \u001b[38;5;28mself\u001b[39m.obs_queue, \u001b[38;5;28mself\u001b[39m.stacked_obs)\n\u001b[32m 430\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:560\u001b[39m, in \u001b[36mObservationWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 556\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 557\u001b[39m \u001b[38;5;28mself\u001b[39m, action: ActType\n\u001b[32m 558\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 559\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m560\u001b[39m observation, reward, terminated, truncated, info = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 561\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.observation(observation), reward, terminated, truncated, info\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:393\u001b[39m, in \u001b[36mOrderEnforcing.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ResetNeeded(\u001b[33m\"\u001b[39m\u001b[33mCannot call env.step() before calling env.reset()\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/core.py:327\u001b[39m, in \u001b[36mWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 323\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\n\u001b[32m 324\u001b[39m \u001b[38;5;28mself\u001b[39m, action: WrapperActType\n\u001b[32m 325\u001b[39m ) -> \u001b[38;5;28mtuple\u001b[39m[WrapperObsType, SupportsFloat, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[32m 326\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/gymnasium/wrappers/common.py:285\u001b[39m, in \u001b[36mPassiveEnvChecker.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 283\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env_step_passive_checker(\u001b[38;5;28mself\u001b[39m.env, action)\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/ale_py/env.py:305\u001b[39m, in \u001b[36mAtariEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 303\u001b[39m reward = \u001b[32m0.0\u001b[39m\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(frameskip):\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m reward += \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43male\u001b[49m\u001b[43m.\u001b[49m\u001b[43mact\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrength\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 307\u001b[39m is_terminal = \u001b[38;5;28mself\u001b[39m.ale.game_over(with_truncation=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 308\u001b[39m is_truncated = \u001b[38;5;28mself\u001b[39m.ale.game_truncated()\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] } ], "source": [ - "AGENT_TO_TRAIN = \"Q-Learning\" # TODO: change to: \"Q-Learning\", \"Monte Carlo\", \"Random\"\n", - "TRAINING_EPISODES = 5000\n", + "AGENT_TO_TRAIN = \"Q-Learning\"\n", + "TRAINING_EPISODES = 3000\n", "FORCE_RETRAIN = True\n", "\n", "if AGENT_TO_TRAIN not in agents:\n", @@ -1193,13 +1271,44 @@ ")\n" ] }, + { + "cell_type": "markdown", + "id": "56959ccd", + "metadata": {}, + "source": [ + "# Multi-Agent Evaluation & Tournament" + ] + }, + { + "cell_type": "markdown", + "id": "0698f673", + "metadata": {}, + "source": [ + "### Loading Checkpoints & Running Evaluation\n", + "\n", + "Before evaluation, we load the saved weights for each trained agent from the `checkpoints/` directory. If a checkpoint is missing, the agent will play with its initial weights (zeros), which is effectively random behavior." + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "5315eb02", + "metadata": {}, + "outputs": [], + "source": [ + "for name in [\"SARSA\", \"Q-Learning\"]:\n", + " checkpoint_path = get_path(name)\n", + " if checkpoint_path.exists():\n", + " agents[name].load(str(checkpoint_path))\n", + " else:\n", + " print(f\"Warning: Missing checkpoint for {name}\")\n" + ] + }, { "cell_type": "markdown", "id": "3047c4b0", "metadata": {}, "source": [ - "# Multi-Agent Evaluation & Tournament\n", - "\n", "After individual pre-training against Atari's built-in AI, we move on to **head-to-head evaluation** between our agents.\n", "\n", "**PettingZoo Environment**: unlike Gymnasium (single player vs built-in AI), PettingZoo (`tennis_v3`) allows **two Python agents** to play against each other. The same preprocessing wrappers are applied (grayscale, resize to 84×84, 4-frame stacking) via **SuperSuit**.\n", @@ -1209,17 +1318,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 57, "id": "d04d37e0", "metadata": {}, "outputs": [], "source": [ - "def create_tournament_env():\n", + "def create_tournament_env() -> gym.Env:\n", " \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n", - " env = tennis_v3.env(obs_type=\"rgb_image\")\n", - " env = ss.color_reduction_v0(env, mode=\"full\")\n", + " env = tennis_v3.env(obs_type=\"grayscale_image\")\n", " env = ss.resize_v1(env, x_size=84, y_size=84)\n", - " return ss.frame_stack_v1(env, 4)\n", + " env = ss.frame_stack_v1(env, 4)\n", + " return ss.observation_lambda_v0(env, lambda obs, _: np.transpose(obs, (2, 0, 1)))\n", "\n", "\n", "def run_match(\n", @@ -1266,6 +1375,129 @@ " return wins\n" ] }, + { + "cell_type": "markdown", + "id": "8fd8215a", + "metadata": {}, + "source": [ + "### Observation Space Alignment Check: Gymnasium vs PettingZoo" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "fd235ac7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Gymnasium: (4, 84, 84) uint8 0 214\n", + "PettingZoo: (4, 84, 84) uint8 0 214\n" + ] + } + ], + "source": [ + "env_gym = create_env()\n", + "obs_gym, _ = env_gym.reset()\n", + "print(\n", + " \"Gymnasium:\",\n", + " np.asarray(obs_gym).shape,\n", + " np.asarray(obs_gym).dtype,\n", + " np.asarray(obs_gym).min(),\n", + " np.asarray(obs_gym).max(),\n", + ")\n", + "\n", + "env_pz = create_tournament_env()\n", + "env_pz.reset()\n", + "for agent_id in env_pz.agent_iter():\n", + " obs_pz, *_ = env_pz.last()\n", + " print(\n", + " \"PettingZoo:\",\n", + " np.asarray(obs_pz).shape,\n", + " np.asarray(obs_pz).dtype,\n", + " np.asarray(obs_pz).min(),\n", + " np.asarray(obs_pz).max(),\n", + " )\n", + " break\n" + ] + }, + { + "cell_type": "markdown", + "id": "5cfcfdb0", + "metadata": {}, + "source": [ + "## Evaluate against Atari pre-trained agents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bee640bb", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_vs_atari(\n", + " env: gym.Env,\n", + " agents: dict[str, Agent],\n", + " episodes: int = 20,\n", + " max_steps: int = 5000,\n", + ") -> dict[str, float]:\n", + " \"\"\"Evaluate each agent against the built-in Atari AI opponent over multiple episodes.\"\"\"\n", + " results = {}\n", + " for name, agent in agents.items():\n", + " total_rewards = []\n", + " for _ in range(episodes):\n", + " obs, _info = env.reset()\n", + " obs = np.asarray(obs)\n", + " ep_reward = 0.0\n", + "\n", + " for _ in range(max_steps):\n", + " action = agent.get_action(obs, epsilon=0.0)\n", + " next_obs, reward, terminated, truncated, _info = env.step(action)\n", + " ep_reward += float(reward)\n", + "\n", + " if terminated or truncated:\n", + " break\n", + "\n", + " obs = np.asarray(next_obs)\n", + "\n", + " total_rewards.append(ep_reward)\n", + "\n", + " avg_reward = float(np.mean(total_rewards))\n", + " results[name] = avg_reward\n", + " print(\n", + " f\"{name} vs Atari AI: Score moyen = {avg_reward:.2f} sur {episodes} épisodes\",\n", + " )\n", + "\n", + " return results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94a804cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation against built-in Atari AI opponent:\n", + "Random vs Atari AI: Score moyen = -23.90 sur 10 épisodes\n", + "SARSA vs Atari AI: Score moyen = 0.00 sur 10 épisodes\n", + "Q-Learning vs Atari AI: Score moyen = 0.00 sur 10 épisodes\n", + "MonteCarlo vs Atari AI: Score moyen = -23.80 sur 10 épisodes\n", + "DQN vs Atari AI: Score moyen = -24.00 sur 10 épisodes\n" + ] + } + ], + "source": [ + "print(\"Evaluation against built-in Atari AI opponent:\")\n", + "atari_results = evaluate_vs_atari(env, agents, episodes=10)\n" + ] + }, { "cell_type": "markdown", "id": "150e6764", @@ -1315,17 +1547,60 @@ " )\n", "\n", " env.close()\n", - " return win_rates\n" - ] - }, - { - "cell_type": "markdown", - "id": "82c24f27", - "metadata": {}, - "source": [ - "### Loading Checkpoints & Running Evaluation\n", + " return win_rates\n", "\n", - "Before evaluation, we load the saved weights for each trained agent from the `checkpoints/` directory. If a checkpoint is missing, the agent will play with its initial weights (zeros), which is effectively random behavior." + "def plot_evaluation_results(\n", + " atari_results: dict[str, float],\n", + " random_win_rates: dict[str, float],\n", + " save_path: str = \"plots/evaluation_results.png\",\n", + ") -> None:\n", + " \"\"\"Plot evaluation results as two side-by-side subplots.\"\"\"\n", + " agent_names = list(atari_results.keys())\n", + " colors = [AGENT_COLORS.get(n, \"#7f8c8d\") for n in agent_names]\n", + "\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n", + "\n", + " # --- Left panel: Average reward vs Atari AI ---\n", + " atari_scores = [atari_results[n] for n in agent_names]\n", + " x = np.arange(len(agent_names))\n", + " bars1 = ax1.bar(x, atari_scores, color=colors, edgecolor=\"white\", width=0.6)\n", + " ax1.set_xticks(x)\n", + " ax1.set_xticklabels(agent_names, rotation=30, ha=\"right\", fontsize=10)\n", + " ax1.set_ylabel(\"Average Reward\", fontsize=11)\n", + " ax1.set_title(\"Performance vs Atari Built-in AI\", fontsize=13, fontweight=\"bold\")\n", + " ax1.axhline(y=0, color=\"gray\", linestyle=\"--\", linewidth=0.8, alpha=0.5)\n", + " ax1.grid(axis=\"y\", alpha=0.3, linestyle=\"--\")\n", + " ax1.spines[\"top\"].set_visible(False)\n", + " ax1.spines[\"right\"].set_visible(False)\n", + " for bar, score in zip(bars1, atari_scores):\n", + " va = \"bottom\" if score >= 0 else \"top\"\n", + " ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),\n", + " f\"{score:+.1f}\", ha=\"center\", va=va, fontsize=9, fontweight=\"bold\")\n", + "\n", + " # --- Right panel: Win rate vs Random ---\n", + " trained = [n for n in agent_names if n in random_win_rates]\n", + " trained_colors = [AGENT_COLORS.get(n, \"#7f8c8d\") for n in trained]\n", + " rates_pct = [random_win_rates[n] * 100 for n in trained]\n", + " x2 = np.arange(len(trained))\n", + " bars2 = ax2.bar(x2, rates_pct, color=trained_colors, edgecolor=\"white\", width=0.6)\n", + " ax2.set_xticks(x2)\n", + " ax2.set_xticklabels(trained, rotation=30, ha=\"right\", fontsize=10)\n", + " ax2.set_ylabel(\"Win Rate (%)\", fontsize=11)\n", + " ax2.set_title(\"Win Rate vs Random Baseline\", fontsize=13, fontweight=\"bold\")\n", + " ax2.set_ylim(0, 110)\n", + " ax2.axhline(y=50, color=\"gray\", linestyle=\"--\", linewidth=0.8, alpha=0.5, label=\"50 % baseline\")\n", + " ax2.legend(fontsize=8, loc=\"upper right\")\n", + " ax2.grid(axis=\"y\", alpha=0.3, linestyle=\"--\")\n", + " ax2.spines[\"top\"].set_visible(False)\n", + " ax2.spines[\"right\"].set_visible(False)\n", + " for bar, wr in zip(bars2, rates_pct):\n", + " ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.5,\n", + " f\"{wr:.0f} %\", ha=\"center\", va=\"bottom\", fontsize=9, fontweight=\"bold\")\n", + "\n", + " fig.suptitle(\"Agent Evaluation Summary\", fontsize=15, fontweight=\"bold\", y=1.02)\n", + " fig.tight_layout()\n", + " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n", + " plt.show()\n" ] }, { @@ -1333,21 +1608,73 @@ "execution_count": null, "id": "a053644e", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation against the Random agent\n", + "SARSA vs Random: 0 wins out of 20 decisive matches (Win rate: 0.0%)\n", + "Q-Learning vs Random: 0 wins out of 20 decisive matches (Win rate: 0.0%)\n", + "MonteCarlo vs Random: 8 wins out of 18 decisive matches (Win rate: 44.4%)\n", + "DQN vs Random: 6 wins out of 12 decisive matches (Win rate: 50.0%)\n" + ] + } + ], "source": [ - "for name in [\"SARSA\", \"Q-Learning\"]:\n", - " checkpoint_path = get_path(name)\n", - " if checkpoint_path.exists():\n", - " agents[name].load(str(checkpoint_path))\n", - " else:\n", - " print(f\"Warning: Missing checkpoint for {name}\")\n", - "\n", "print(\"Evaluation against the Random agent\")\n", "win_rates_vs_random = evaluate_vs_random(\n", " agents, random_agent_name=\"Random\", episodes_per_leg=10,\n", ")\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fe97c3d", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "'float' object is not subscriptable", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[62]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mplot_evaluation_results\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2\u001b[39m \u001b[43m \u001b[49m\u001b[43matari_results\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwin_rates_vs_random\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mplots/evaluation_results.png\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 3\u001b[39m \u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[61]\u001b[39m\u001b[32m, line 66\u001b[39m, in \u001b[36mplot_evaluation_results\u001b[39m\u001b[34m(atari_results, random_results, save_path)\u001b[39m\n\u001b[32m 63\u001b[39m fig, (ax1, ax2) = plt.subplots(\u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m, figsize=(\u001b[32m14\u001b[39m, \u001b[32m5\u001b[39m))\n\u001b[32m 65\u001b[39m \u001b[38;5;66;03m# --- Left: Average reward vs Atari AI with error bars ---\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m66\u001b[39m means = [\u001b[43matari_results\u001b[49m\u001b[43m[\u001b[49m\u001b[43mn\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmean\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m n \u001b[38;5;129;01min\u001b[39;00m agent_names]\n\u001b[32m 67\u001b[39m stds = [atari_results[n][\u001b[33m\"\u001b[39m\u001b[33mstd\u001b[39m\u001b[33m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m n \u001b[38;5;129;01min\u001b[39;00m agent_names]\n\u001b[32m 68\u001b[39m x = np.arange(\u001b[38;5;28mlen\u001b[39m(agent_names))\n", + "\u001b[31mTypeError\u001b[39m: 'float' object is not subscriptable" + ] + }, + { + "ename": "TypeError", + "evalue": "'float' object is not subscriptable", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[62]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mplot_evaluation_results\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2\u001b[39m \u001b[43m \u001b[49m\u001b[43matari_results\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwin_rates_vs_random\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mplots/evaluation_results.png\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 3\u001b[39m \u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[61]\u001b[39m\u001b[32m, line 66\u001b[39m, in \u001b[36mplot_evaluation_results\u001b[39m\u001b[34m(atari_results, random_results, save_path)\u001b[39m\n\u001b[32m 63\u001b[39m fig, (ax1, ax2) = plt.subplots(\u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m, figsize=(\u001b[32m14\u001b[39m, \u001b[32m5\u001b[39m))\n\u001b[32m 65\u001b[39m \u001b[38;5;66;03m# --- Left: Average reward vs Atari AI with error bars ---\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m66\u001b[39m means = [\u001b[43matari_results\u001b[49m\u001b[43m[\u001b[49m\u001b[43mn\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmean\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m n \u001b[38;5;129;01min\u001b[39;00m agent_names]\n\u001b[32m 67\u001b[39m stds = [atari_results[n][\u001b[33m\"\u001b[39m\u001b[33mstd\u001b[39m\u001b[33m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m n \u001b[38;5;129;01min\u001b[39;00m agent_names]\n\u001b[32m 68\u001b[39m x = np.arange(\u001b[38;5;28mlen\u001b[39m(agent_names))\n", + "\u001b[31mTypeError\u001b[39m: 'float' object is not subscriptable" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABHsAAAGyCAYAAAB0jsg1AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIQZJREFUeJzt3X2MFtXZB+CzgICmgloKCEWpWkWLgoJsQY2xoZJosPzRlKoBSvyo1RoLaQVEwW+srxqSukpErf5RC2rEGCFYpRJjpSGCJNoKRlGhRhao5aOooDBvZprdsrhYF3efnb33upIRZnZmn7PPkZ17f3vmnKosy7IEAAAAQAgdWrsBAAAAADQfYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAew57XnrppTR69OjUp0+fVFVVlZ5++un/ec3SpUvTaaedlrp06ZKOO+649MgjjxxoewEA2hS1EwBQ+rBnx44dadCgQammpuYrnf/uu++m888/P51zzjlp1apV6Ve/+lW69NJL03PPPXcg7QUAaFPUTgBApVVlWZYd8MVVVWnBggVpzJgx+z1nypQpaeHChemNN96oP/bTn/40bdmyJS1evPhAXxoAoM1ROwEAldCppV9g2bJlaeTIkQ2OjRo1qhjhsz87d+4stjp79uxJH330UfrmN79ZFEkAQDnlv0Pavn178bh3hw6mBqxU7ZRTPwFA25S1QP3U4mHPhg0bUq9evRocy/e3bduWPvnkk3TwwQd/4ZpZs2alm266qaWbBgC0kPXr16dvf/vb3t8K1U459RMAtG3rm7F+avGw50BMmzYtTZ48uX5/69at6aijjiq+8G7durVq2wCA/csDiX79+qVDDz3U21Rh6icAaJu2tUD91OJhT+/evVNtbW2DY/l+Htrs7zdT+apd+bav/BphDwCUn8euK1s75dRPANC2VTXjtDUt/jD98OHD05IlSxoce/7554vjAAConQCA5tXksOff//53sYR6vtUtrZ7/fd26dfVDiMePH19//hVXXJHWrl2brr322rR69ep03333pccffzxNmjSpOb8OAIBSUjsBAKUPe1599dV06qmnFlsun1sn//uMGTOK/Q8//LA++Ml95zvfKZZez0fzDBo0KN19993pwQcfLFaVAACITu0EAFRaVZav8dUGJivq3r17MVGzOXsAoLzcs8tDXwBA+71nt/icPQAAAABUjrAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAoL2HPTU1Nal///6pa9euqbq6Oi1fvvxLz589e3Y64YQT0sEHH5z69euXJk2alD799NMDbTMAQJujfgIAShv2zJ8/P02ePDnNnDkzrVy5Mg0aNCiNGjUqbdy4sdHzH3vssTR16tTi/DfffDM99NBDxee47rrrmqP9AAClp34CAEod9txzzz3psssuSxMnTkwnnXRSmjNnTjrkkEPSww8/3Oj5r7zySjrjjDPSRRddVIwGOvfcc9OFF174P0cDAQBEoX4CAEob9uzatSutWLEijRw58r+foEOHYn/ZsmWNXjNixIjimrpwZ+3atWnRokXpvPPO2+/r7Ny5M23btq3BBgDQFqmfAIBK69SUkzdv3px2796devXq1eB4vr969epGr8lH9OTXnXnmmSnLsvT555+nK6644ksf45o1a1a66aabmtI0AIBSUj8BAOFW41q6dGm6/fbb03333VfM8fPUU0+lhQsXpltuuWW/10ybNi1t3bq1flu/fn1LNxMAoDTUTwBAxUb29OjRI3Xs2DHV1tY2OJ7v9+7du9FrbrjhhjRu3Lh06aWXFvsnn3xy2rFjR7r88svT9OnTi8fA9tWlS5diAwBo69RPAECpR/Z07tw5DRkyJC1ZsqT+2J49e4r94cOHN3rNxx9//IVAJw+McvljXQAAkamfAIBSj+zJ5cuuT5gwIQ0dOjQNGzYszZ49uxipk6/OlRs/fnzq27dvMe9ObvTo0cUKFKeeemqqrq5Ob7/9djHaJz9eF/oAAESmfgIASh32jB07Nm3atCnNmDEjbdiwIQ0ePDgtXry4ftLmdevWNRjJc/3116eqqqrizw8++CB961vfKoKe2267rXm/EgCAklI/AQCVVJW1gWep8qXXu3fvXkzW3K1bt9ZuDgCwH+7Z5aEvAKD93rNbfDUuAAAAACpH2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAAAQiLAHAAAAIBBhDwAAAEAgwh4AAACAQIQ9AAAAAIEIewAAAAACEfYAAAAABCLsAQAAAAhE2AMAAADQ3sOempqa1L9//9S1a9dUXV2dli9f/qXnb9myJV111VXpyCOPTF26dEnHH398WrRo0YG2GQCgzVE/AQCV0qmpF8yfPz9Nnjw5zZkzpwh6Zs+enUaNGpXWrFmTevbs+YXzd+3alX74wx8WH3vyySdT37590/vvv58OO+yw5voaAABKTf0EAFRSVZZlWVMuyAOe008/Pd17773F/p49e1K/fv3S1VdfnaZOnfqF8/NQ6P/+7//S6tWr00EHHXRAjdy2bVvq3r172rp1a+rWrdsBfQ4AoOW5ZzdO/QQAVLJ+atJjXPkonRUrVqSRI0f+9xN06FDsL1u2rNFrnnnmmTR8+PDiMa5evXqlgQMHpttvvz3t3r17v6+zc+fO4ovdewMAaIvUTwBApTUp7Nm8eXMR0uShzd7y/Q0bNjR6zdq1a4vHt/Lr8nl6brjhhnT33XenW2+9db+vM2vWrCLVqtvykUMAAG2R+gkACLcaV/6YVz5fzwMPPJCGDBmSxo4dm6ZPn1483rU/06ZNK4Yv1W3r169v6WYCAJSG+gkAqNgEzT169EgdO3ZMtbW1DY7n+7179270mnwFrnyunvy6OieeeGIxEigf1ty5c+cvXJOv2JVvAABtnfoJACj1yJ48mMlH5yxZsqTBb57y/XxensacccYZ6e233y7Oq/PWW28VIVBjQQ8AQCTqJwCg9I9x5cuuz507Nz366KPpzTffTL/4xS/Sjh070sSJE4uPjx8/vngMq07+8Y8++ihdc801RcizcOHCYoLmfMJmAID2QP0EAJT2Ma5cPufOpk2b0owZM4pHsQYPHpwWL15cP2nzunXrihW66uSTKz/33HNp0qRJ6ZRTTkl9+/Ytgp8pU6Y071cCAFBS6icAoJKqsizLUjtccx4AaH7u2eWhLwCg/d6zW3w1LgAAAAAqR9gDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAAEIiwBwAAACAQYQ8AAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAhH2AAAAAAQi7AEAAAAIRNgDAAAA0N7DnpqamtS/f//UtWvXVF1dnZYvX/6Vrps3b16qqqpKY8aMOZCXBQBos9RPAEBpw5758+enyZMnp5kzZ6aVK1emQYMGpVGjRqWNGzd+6XXvvfde+vWvf53OOuusr9NeAIA2R/0EAJQ67LnnnnvSZZddliZOnJhOOumkNGfOnHTIIYekhx9+eL/X7N69O1188cXppptuSsccc8zXbTMAQJuifgIAShv27Nq1K61YsSKNHDnyv5+gQ4dif9myZfu97uabb049e/ZMl1xyyVd6nZ07d6Zt27Y12AAA2iL1EwBQ6rBn8+bNxSidXr16NTie72/YsKHRa15++eX00EMPpblz537l15k1a1bq3r17/davX7+mNBMAoDTUTwBAqNW4tm/fnsaNG1cEPT169PjK102bNi1t3bq1flu/fn1LNhMAoDTUTwDA19WpKSfngU3Hjh1TbW1tg+P5fu/evb9w/jvvvFNMzDx69Oj6Y3v27PnPC3fqlNasWZOOPfbYL1zXpUuXYgMAaOvUTwBAqUf2dO7cOQ0ZMiQtWbKkQXiT7w8fPvwL5w8YMCC9/vrradWqVfXbBRdckM4555zi7x7PAgCiUz8BAKUe2ZPLl12fMGFCGjp0aBo2bFiaPXt22rFjR7E6V278+PGpb9++xbw7Xbt2TQMHDmxw/WGHHVb8ue9xAICo1E8AQKnDnrFjx6ZNmzalGTNmFJMyDx48OC1evLh+0uZ169YVK3QBAKB+AgAqryrLsiyVXL70er4qVz5Zc7du3Vq7OQDAfrhnl4e+AID2e882BAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQHsPe2pqalL//v1T165dU3V1dVq+fPl+z507d24666yz0uGHH15sI0eO/NLzAQAiUj8BAKUNe+bPn58mT56cZs6cmVauXJkGDRqURo0alTZu3Njo+UuXLk0XXnhhevHFF9OyZctSv3790rnnnps++OCD5mg/AEDpqZ8AgEqqyrIsa8oF+Uie008/Pd17773F/p49e4oA5+qrr05Tp079n9fv3r27GOGTXz9+/Piv9Jrbtm1L3bt3T1u3bk3dunVrSnMBgApyz26c+gkAqGT91KSRPbt27UorVqwoHsWq/wQdOhT7+aidr+Ljjz9On332WTriiCP2e87OnTuLL3bvDQCgLVI/AQCV1qSwZ/PmzcXInF69ejU4nu9v2LDhK32OKVOmpD59+jQIjPY1a9asItWq2/KRQwAAbZH6CQAIvRrXHXfckebNm5cWLFhQTO68P9OmTSuGL9Vt69evr2QzAQBKQ/0EADRVp6ac3KNHj9SxY8dUW1vb4Hi+37t37y+99q677iqKlRdeeCGdcsopX3puly5dig0AoK1TPwEApR7Z07lz5zRkyJC0ZMmS+mP5BM35/vDhw/d73Z133pluueWWtHjx4jR06NCv12IAgDZE/QQAlHpkTy5fdn3ChAlFaDNs2LA0e/bstGPHjjRx4sTi4/kKW3379i3m3cn99re/TTNmzEiPPfZY6t+/f/3cPt/4xjeKDQAgOvUTAFDqsGfs2LFp06ZNRYCTBzeDBw8uRuzUTdq8bt26YoWuOvfff3+xCsWPf/zjBp9n5syZ6cYbb2yOrwEAoNTUTwBAJVVlWZaldrjmPADQ/Nyzy0NfAED7vWdXdDUuAAAAAFqWsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAACgvYc9NTU1qX///qlr166puro6LV++/EvPf+KJJ9KAAQOK808++eS0aNGiA20vAECbpH4CAEob9syfPz9Nnjw5zZw5M61cuTINGjQojRo1Km3cuLHR81955ZV04YUXpksuuSS99tpracyYMcX2xhtvNEf7AQBKT/0EAFRSVZZlWVMuyEfynH766enee+8t9vfs2ZP69euXrr766jR16tQvnD927Ni0Y8eO9Oyzz9Yf+/73v58GDx6c5syZ85Vec9u2bal79+5p69atqVu3bk1pLgBQQe7ZjVM/AQCVrJ86NeXkXbt2pRUrVqRp06bVH+vQoUMaOXJkWrZsWaPX5MfzkUB7y0cCPf300/t9nZ07dxZbnfwLrnsDAIDyqrtXN/F3SaGpnwCAStdPTQp7Nm/enHbv3p169erV4Hi+v3r16kav2bBhQ6Pn58f3Z9asWemmm276wvF8BBEAUH7//Oc/i99QoX4CACpfPzUp7KmUfOTQ3qOBtmzZko4++ui0bt06hWMrp4154LZ+/XqP07UyfVEe+qIc9EN55KNxjzrqqHTEEUe0dlPaHfVTOfn+VB76ohz0Q3noi9j1U5PCnh49eqSOHTum2traBsfz/d69ezd6TX68KefnunTpUmz7yhMuc/a0vrwP9EM56Ivy0BfloB/KI3/Mm/9QP5Hz/ak89EU56Ify0Bcx66cmfabOnTunIUOGpCVLltQfyydozveHDx/e6DX58b3Pzz3//PP7PR8AIBL1EwBQaU1+jCt/vGrChAlp6NChadiwYWn27NnFalsTJ04sPj5+/PjUt2/fYt6d3DXXXJPOPvvsdPfdd6fzzz8/zZs3L7366qvpgQceaP6vBgCghNRPAECpw558KfVNmzalGTNmFJMs50uoL168uH4S5nxenb2HHo0YMSI99thj6frrr0/XXXdd+u53v1usxDVw4MCv/Jr5I10zZ85s9NEuKkc/lIe+KA99UQ76oTz0RePUT+2XfxPloS/KQT+Uh76I3RdVmbVRAQAAAMIweyIAAABAIMIeAAAAgECEPQAAAACBCHsAAAAAAilN2FNTU5P69++funbtmqqrq9Py5cu/9PwnnngiDRgwoDj/5JNPTosWLapYWyNrSj/MnTs3nXXWWenwww8vtpEjR/7PfqNl+mJv8+bNS1VVVWnMmDHe7lbqiy1btqSrrroqHXnkkcWM+scff7zvUa3QD7Nnz04nnHBCOvjgg1O/fv3SpEmT0qefftocTWm3XnrppTR69OjUp0+f4vtMvrrm/7J06dJ02mmnFf8WjjvuuPTII49UpK3tgdqpPNRP5aF+Kge1U3mon9px/ZSVwLx587LOnTtnDz/8cPa3v/0tu+yyy7LDDjssq62tbfT8v/zlL1nHjh2zO++8M/v73/+eXX/99dlBBx2Uvf766xVveyRN7YeLLrooq6mpyV577bXszTffzH72s59l3bt3z/7xj39UvO3tvS/qvPvuu1nfvn2zs846K/vRj35UsfZG1tS+2LlzZzZ06NDsvPPOy15++eWiT5YuXZqtWrWq4m1vz/3whz/8IevSpUvxZ94Hzz33XHbkkUdmkyZNqnjbI1m0aFE2ffr07KmnnsryEmLBggVfev7atWuzQw45JJs8eXJxv/7d735X3L8XL15csTZHpXYqD/VTeaifykHtVB7qp/ZdP5Ui7Bk2bFh21VVX1e/v3r0769OnTzZr1qxGz//JT36SnX/++Q2OVVdXZz//+c9bvK2RNbUf9vX5559nhx56aPboo4+2YCvbhwPpi/z9HzFiRPbggw9mEyZMEPa0Ul/cf//92THHHJPt2rWruZrAAfRDfu4PfvCDBsfyG+YZZ5zh/WwmX6VYufbaa7Pvfe97DY6NHTs2GzVqlH74mtRO5aF+Kg/1UzmoncpD/dS+66dWf4xr165dacWKFcUjQHU6dOhQ7C9btqzRa/Lje5+fGzVq1H7Pp2X6YV8ff/xx+uyzz9IRRxzhLW+Fvrj55ptTz5490yWXXOL9b8W+eOaZZ9Lw4cOLx7h69eqVBg4cmG6//fa0e/du/VLBfhgxYkRxTd2jXmvXri0epTvvvPP0QwW5X7cMtVN5qJ/KQ/1UDmqn8lA/tV3NVT91Sq1s8+bNxQ9B+Q9Fe8v3V69e3eg1GzZsaPT8/DiV64d9TZkypXgOcd//MWn5vnj55ZfTQw89lFatWuXtbuW+yEOFP//5z+niiy8uwoW33347XXnllUUQOnPmTP1ToX646KKLiuvOPPPMfARr+vzzz9MVV1yRrrvuOn1QQfu7X2/bti198sknxXxKNJ3aqTzUT+WhfioHtVN5qJ/aruaqn1p9ZA8x3HHHHcXEwAsWLCgmT6Vytm/fnsaNG1dMmN2jRw9vfSvbs2dPMcLqgQceSEOGDEljx45N06dPT3PmzGntprUr+aR2+Yiq++67L61cuTI99dRTaeHChemWW25p7aYB1FM/tR71U3moncpD/RRLq4/syX847dixY6qtrW1wPN/v3bt3o9fkx5tyPi3TD3Xuuuuuolh54YUX0imnnOLtrnBfvPPOO+m9994rZnjf+6aZ69SpU1qzZk069thj9UsF+iKXr8B10EEHFdfVOfHEE4uEPh9O27lzZ31RgX644YYbihD00ksvLfbzVRt37NiRLr/88iJ8yx8Do+Xt737drVs3o3q+BrVTeaifykP9VA5qp/JQP7VdzVU/tXq1m//gk//2e8mSJQ1+UM3383kvGpMf3/v83PPPP7/f82mZfsjdeeedxW/KFy9enIYOHeqtboW+GDBgQHr99deLR7jqtgsuuCCdc845xd/zJaepTF/kzjjjjOLRrbrALffWW28VIZCgp3L9kM8htm+gUxfA/WduPCrB/bplqJ3KQ/1UHuqnclA7lYf6qe1qtvopK8mScPkSuY888kixtNjll19eLKm7YcOG4uPjxo3Lpk6d2mDp9U6dOmV33XVXseT3zJkzLb3eCv1wxx13FEshP/nkk9mHH35Yv23fvr05mtOuNbUv9mU1rtbri3Xr1hWr0v3yl7/M1qxZkz377LNZz549s1tvvbUZW9X+NLUf8vtC3g9//OMfi+Ur//SnP2XHHntssZojBy7//v7aa68VW15C3HPPPcXf33///eLjeR/kfbHv0qG/+c1vivt1TU2NpdebidqpPNRP5aF+Kge1U3mon9p3/VSKsCeXrx1/1FFHFeFBvkTcX//61/qPnX322cUPr3t7/PHHs+OPP744P1+WbOHCha3Q6nia0g9HH3108T/rvlv+QxaV7Yt9CXtaty9eeeWVrLq6uggn8mXYb7vttuzzzz9v5la1P03ph88++yy78cYbi4Cna9euWb9+/bIrr7wy+9e//tVKrY/hxRdfbPT7ft17n/+Z98W+1wwePLjot/zfw+9///tWan08aqfyUD+Vh/qpHNRO5aF+ar/1U1X+n+YddAQAAABAa2n1OXsAAAAAaD7CHgAAAIBAhD0AAAAAgQh7AAAAAAIR9gAAAAAEIuwBAAAACETYAwAAABCIsAcAAAAgEGEPAAAAQCDCHgAAAIBAhD0AAAAAgQh7AAAAAFIc/w/JLyUiKOJTlAAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_evaluation_results(\n", + " atari_results, win_rates_vs_random, save_path=\"plots/evaluation_results.png\",\n", + ")\n" + ] + }, { "cell_type": "markdown", "id": "b9cb30f5", @@ -1371,39 +1698,78 @@ "metadata": {}, "outputs": [], "source": [ - "def run_championship(\n", - " agent1_name: str,\n", - " agent2_name: str,\n", + "def run_all_matchups(\n", " agents: dict[str, Agent],\n", " episodes_per_leg: int = 20,\n", - ") -> None:\n", - " \"\"\"Run a full championship between two agents, playing multiple legs with swapped positions.\"\"\"\n", + ") -> np.ndarray:\n", + " \"\"\"Run a full round-robin tournament and return an n*n win matrix.\n", + "\n", + " matrix[i][j] = wins of agent_i (as first_0) vs agent_j (as second_0).\n", + " Diagonal entries are NaN (no self-play).\n", + " \"\"\"\n", " env = create_tournament_env()\n", - " agent1 = agents[agent1_name]\n", - " agent2 = agents[agent2_name]\n", + " names = list(agents.keys())\n", + " n = len(names)\n", + " matrix = np.full((n, n), np.nan)\n", "\n", - " # Leg 1: Agent 1 plays first_0, Agent 2 plays second_0\n", - " leg1 = run_match(env, agent1, agent2, episodes=episodes_per_leg)\n", - " # Leg 2: Swap starting positions\n", - " leg2 = run_match(env, agent2, agent1, episodes=episodes_per_leg)\n", + " for i in range(n):\n", + " for j in range(n):\n", + " if i == j:\n", + " continue\n", + " result = run_match(\n", + " env, agents[names[i]], agents[names[j]], episodes=episodes_per_leg,\n", + " )\n", + " matrix[i, j] = result[\"first\"]\n", "\n", - " wins_agent1 = leg1[\"first\"] + leg2[\"second\"]\n", - " wins_agent2 = leg1[\"second\"] + leg2[\"first\"]\n", - " draws = leg1[\"draw\"] + leg2[\"draw\"]\n", + " env.close()\n", + " return matrix\n", "\n", - " print(f\"--- Final Result: {agent1_name} vs {agent2_name} ---\")\n", - " print(f\"{agent1_name} wins: {wins_agent1}\")\n", - " print(f\"{agent2_name} wins: {wins_agent2}\")\n", - " print(f\"Draws: {draws}\")\n", "\n", - " if wins_agent1 > wins_agent2:\n", - " print(f\"The winner is {agent1_name}!\")\n", - " elif wins_agent2 > wins_agent1:\n", - " print(f\"The winner is {agent2_name}!\")\n", - " else:\n", - " print(\"Perfect tie!\")\n", + "def plot_matchup_matrix(\n", + " matrix: np.ndarray,\n", + " agent_names: list[str],\n", + " episodes_per_leg: int = 20,\n", + " save_path: str = \"plots/championship_matrix.png\",\n", + ") -> None:\n", + " \"\"\"Plot the n*n matchup matrix as an annotated heatmap.\"\"\"\n", + " n = len(agent_names)\n", + " pct_matrix = matrix / episodes_per_leg * 100\n", "\n", - " env.close()\n" + " fig, ax = plt.subplots(figsize=(8, 7))\n", + " masked = np.ma.array(pct_matrix, mask=np.isnan(pct_matrix))\n", + " cmap = plt.cm.RdYlGn.copy()\n", + " cmap.set_bad(color=\"#f0f0f0\")\n", + "\n", + " im = ax.imshow(masked, cmap=cmap, vmin=0, vmax=100, aspect=\"equal\")\n", + "\n", + " ax.set_xticks(np.arange(n))\n", + " ax.set_yticks(np.arange(n))\n", + " ax.set_xticklabels(agent_names, rotation=45, ha=\"right\", fontsize=10)\n", + " ax.set_yticklabels(agent_names, fontsize=10)\n", + " ax.set_xlabel(\"Opponent (second_0)\", fontsize=11)\n", + " ax.set_ylabel(\"Agent (first_0)\", fontsize=11)\n", + " ax.set_title(\"Championship Matchup Matrix\", fontsize=14, fontweight=\"bold\", pad=12)\n", + "\n", + " for i in range(n):\n", + " for j in range(n):\n", + " if i == j:\n", + " ax.text(j, i, \"—\", ha=\"center\", va=\"center\", color=\"#aaa\", fontsize=14)\n", + " else:\n", + " wins = int(matrix[i, j])\n", + " pct = pct_matrix[i, j]\n", + " txt_color = \"white\" if pct > 70 or pct < 30 else \"black\"\n", + " ax.text(j, i, f\"{wins}\\n({pct:.0f} %)\",\n", + " ha=\"center\", va=\"center\", color=txt_color,\n", + " fontsize=9, fontweight=\"bold\", linespacing=1.4)\n", + "\n", + " cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n", + " cbar.set_label(f\"Win Rate (%) — {episodes_per_leg} eps per matchup\", fontsize=10)\n", + " cbar.set_ticks([0, 25, 50, 75, 100])\n", + " cbar.set_ticklabels([\"0 %\", \"25 %\", \"50 %\", \"75 %\", \"100 %\"])\n", + "\n", + " plt.tight_layout()\n", + " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n", + " plt.show()\n" ] }, { @@ -1411,10 +1777,59 @@ "execution_count": null, "id": "b07b403c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/lambda_wrappers/observation_lambda.py:68\u001b[39m, in \u001b[36maec_observation_lambda._modify_observation\u001b[39m\u001b[34m(self, agent, observation)\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m68\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mchange_observation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mold_obs_space\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n", + "\u001b[31mTypeError\u001b[39m: basic_obs_wrapper..change_obs() takes 2 positional arguments but 3 were given", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[41]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m matchup_matrix = \u001b[43mrun_all_matchups\u001b[49m\u001b[43m(\u001b[49m\u001b[43magents\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes_per_leg\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 19\u001b[39m, in \u001b[36mrun_all_matchups\u001b[39m\u001b[34m(agents, episodes_per_leg)\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i == j:\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m result = \u001b[43mrun_match\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 20\u001b[39m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magents\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnames\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43magents\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnames\u001b[49m\u001b[43m[\u001b[49m\u001b[43mj\u001b[49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisodes\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepisodes_per_leg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 22\u001b[39m matrix[i, j] = result[\u001b[33m\"\u001b[39m\u001b[33mfirst\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 24\u001b[39m env.close()\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[28]\u001b[39m\u001b[32m, line 37\u001b[39m, in \u001b[36mrun_match\u001b[39m\u001b[34m(env, agent_first, agent_second, episodes, max_steps)\u001b[39m\n\u001b[32m 34\u001b[39m current_agent = agent_first \u001b[38;5;28;01mif\u001b[39;00m agent_id == \u001b[33m\"\u001b[39m\u001b[33mfirst_0\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m agent_second\n\u001b[32m 35\u001b[39m action = current_agent.get_action(np.asarray(obs), epsilon=\u001b[32m0.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m37\u001b[39m \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m step_idx + \u001b[32m1\u001b[39m >= max_steps:\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/base_aec_wrapper.py:46\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 43\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m.terminations[agent] \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m.truncations[agent]):\n\u001b[32m 44\u001b[39m action = \u001b[38;5;28mself\u001b[39m._modify_action(agent, action)\n\u001b[32m---> \u001b[39m\u001b[32m46\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 48\u001b[39m \u001b[38;5;28mself\u001b[39m._update_step(\u001b[38;5;28mself\u001b[39m.agent_selection)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:70\u001b[39m, in \u001b[36mOrderEnforcingWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 69\u001b[39m \u001b[38;5;28mself\u001b[39m._has_updated = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:47\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 46\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, action: ActionType) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m47\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/generic_wrappers/utils/shared_wrapper_util.py:69\u001b[39m, in \u001b[36mshared_wrapper_aec.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 66\u001b[39m \u001b[38;5;28msuper\u001b[39m().step(action)\n\u001b[32m 67\u001b[39m \u001b[38;5;28mself\u001b[39m.add_modifiers(\u001b[38;5;28mself\u001b[39m.agents)\n\u001b[32m 68\u001b[39m \u001b[38;5;28mself\u001b[39m.modifiers[\u001b[38;5;28mself\u001b[39m.agent_selection].modify_obs(\n\u001b[32m---> \u001b[39m\u001b[32m69\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43magent_selection\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 70\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:75\u001b[39m, in \u001b[36mOrderEnforcingWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 73\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._has_reset:\n\u001b[32m 74\u001b[39m EnvLogger.error_observe_before_reset()\n\u001b[32m---> \u001b[39m\u001b[32m75\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:41\u001b[39m, in \u001b[36mBaseWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mobserve\u001b[39m(\u001b[38;5;28mself\u001b[39m, agent: AgentID) -> ObsType | \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m41\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobserve\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/base_aec_wrapper.py:38\u001b[39m, in \u001b[36mBaseWrapper.observe\u001b[39m\u001b[34m(self, agent)\u001b[39m\n\u001b[32m 34\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mobserve\u001b[39m(\u001b[38;5;28mself\u001b[39m, agent):\n\u001b[32m 35\u001b[39m obs = \u001b[38;5;28msuper\u001b[39m().observe(\n\u001b[32m 36\u001b[39m agent\n\u001b[32m 37\u001b[39m ) \u001b[38;5;66;03m# problem is in this line, the obs is sometimes a different size from the obs space\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m38\u001b[39m observation = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_modify_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[43magent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m observation\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/lambda_wrappers/observation_lambda.py:70\u001b[39m, in \u001b[36maec_observation_lambda._modify_observation\u001b[39m\u001b[34m(self, agent, observation)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.change_observation_fn(observation, old_obs_space, agent)\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mchange_observation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mold_obs_space\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/generic_wrappers/basic_wrappers.py:19\u001b[39m, in \u001b[36mbasic_obs_wrapper..change_obs\u001b[39m\u001b[34m(obs, obs_space)\u001b[39m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_obs\u001b[39m(obs, obs_space):\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodule\u001b[49m\u001b[43m.\u001b[49m\u001b[43mchange_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobs_space\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/supersuit/utils/basic_transforms/resize.py:21\u001b[39m, in \u001b[36mchange_observation\u001b[39m\u001b[34m(obs, obs_space, resize)\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_obs_space\u001b[39m(obs_space, param):\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m convert_box(\u001b[38;5;28;01mlambda\u001b[39;00m obs: change_observation(obs, obs_space, param), obs_space)\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mchange_observation\u001b[39m(obs, obs_space, resize):\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtinyscaler\u001b[39;00m\n\u001b[32m 24\u001b[39m xsize, ysize, linear_interp = resize\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], "source": [ - "print(\"Championship match between the two trained agents\")\n", - "run_championship(\"SARSA\", \"Q-Learning\", agents, episodes_per_leg=20)\n" + "matchup_matrix = run_all_matchups(agents, episodes_per_leg=20)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9668071a", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'matchup_matrix' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[42]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m plot_matchup_matrix(\u001b[43mmatchup_matrix\u001b[49m, episodes_per_leg=EPISODES_PER_LEG)\n", + "\u001b[31mNameError\u001b[39m: name 'matchup_matrix' is not defined" + ] + } + ], + "source": [ + "plot_matchup_matrix(matchup_matrix, list(agents.keys()), episodes_per_leg=20)\n" ] }, {