From ae1dacc42c1d46a435dedc1b6ea9ceba428bd487 Mon Sep 17 00:00:00 2001 From: Arthur DANJOU Date: Mon, 2 Mar 2026 12:21:25 +0100 Subject: [PATCH] Add PettingZoo tournament functionality for agent vs agent matches - Introduced functions to create a Tennis environment and run matches between agents. - Implemented a round-robin tournament format excluding random agents. - Added win-rate matrix visualization and final ranking of agents based on performance. - Updated imports to include necessary libraries for the new functionality. --- .../project/Project.ipynb | 204 ++++++++++++++++-- 1 file changed, 189 insertions(+), 15 deletions(-) diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project.ipynb index c225b6c..7f8cf02 100644 --- a/M2/Reinforcement Learning/project/Project.ipynb +++ b/M2/Reinforcement Learning/project/Project.ipynb @@ -19,17 +19,20 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "b50d7174", "metadata": {}, "outputs": [], "source": [ + "import itertools\n", "import pickle\n", "from pathlib import Path\n", "\n", "import ale_py # noqa: F401 — registers ALE environments\n", "import gymnasium as gym\n", + "import supersuit as ss\n", "from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n", + "from pettingzoo.atari import tennis_v3\n", "from tqdm.auto import tqdm\n", "\n", "import matplotlib.pyplot as plt\n", @@ -1215,7 +1218,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, "id": "6f6ba8df", "metadata": {}, "outputs": [ @@ -1263,18 +1266,6 @@ "id": "4d449701", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Selected agent: SARSA\n", - "Checkpoint path: checkpoints/sarsa.pkl\n", - "\n", - "============================================================\n", - "Training: SARSA (5000 episodes)\n", - "============================================================\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -1290,7 +1281,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2809e1c28b15458daaeba87066ad1c74", + "model_id": "c817ca287f56462e8222b3fe71fffdec", "version_major": 2, "version_minor": 0 }, @@ -1432,6 +1423,189 @@ "for name, res in results.items():\n", " print(f\"{name:<15} {res['mean_reward']:>12.2f} {res['std_reward']:>8.2f} {res['win_rate']:>9.1%}\")\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d04d37e0", + "metadata": {}, + "outputs": [], + "source": [ + "def create_tournament_env():\n", + " \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n", + " env = tennis_v3.env(obs_type=\"rgb_image\")\n", + " env = ss.color_reduction_v0(env, mode=\"full\")\n", + " env = ss.resize_v1(env, x_size=84, y_size=84)\n", + " return ss.frame_stack_v1(env, 4)\n", + "\n", + "\n", + "def run_pz_match(\n", + " env,\n", + " agent_first: Agent,\n", + " agent_second: Agent,\n", + " episodes: int = 10,\n", + " max_steps: int = 4000,\n", + ") -> dict[str, int]:\n", + " \"\"\"Run multiple PettingZoo episodes between two agents.\n", + "\n", + " Returns wins for global labels {'first': ..., 'second': ..., 'draw': ...}.\n", + " \"\"\"\n", + " wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n", + "\n", + " for _ep in range(episodes):\n", + " env.reset()\n", + " rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n", + "\n", + " for step_idx, agent_id in enumerate(env.agent_iter()):\n", + " obs, reward, termination, truncation, _info = env.last()\n", + " done = termination or truncation\n", + " rewards[agent_id] += float(reward)\n", + "\n", + " if done or step_idx >= max_steps:\n", + " action = None\n", + " else:\n", + " current_agent = agent_first if agent_id == \"first_0\" else agent_second\n", + " action = current_agent.get_action(np.asarray(obs), epsilon=0.0)\n", + "\n", + " env.step(action)\n", + "\n", + " if step_idx + 1 >= max_steps:\n", + " break\n", + "\n", + " if rewards[\"first_0\"] > rewards[\"second_0\"]:\n", + " wins[\"first\"] += 1\n", + " elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n", + " wins[\"second\"] += 1\n", + " else:\n", + " wins[\"draw\"] += 1\n", + "\n", + " return wins\n", + "\n", + "\n", + "def run_pettingzoo_tournament(\n", + " agents: dict[str, Agent],\n", + " checkpoint_dir: Path,\n", + " episodes_per_side: int = 10,\n", + ") -> tuple[np.ndarray, list[str]]:\n", + " \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n", + " _ = itertools # kept for notebook context consistency\n", + " candidate_names = [name for name in agents if name != \"Random\"]\n", + "\n", + " # Keep only agents that have a checkpoint\n", + " ready_names: list[str] = []\n", + " for name in candidate_names:\n", + " ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n", + " ckpt_path = checkpoint_dir / ckpt_name\n", + " if ckpt_path.exists():\n", + " agents[name].load(str(ckpt_path))\n", + " ready_names.append(name)\n", + "\n", + " if len(ready_names) < 2:\n", + " msg = \"Need at least 2 trained (checkpointed) non-random agents for PettingZoo tournament.\"\n", + " raise RuntimeError(msg)\n", + "\n", + " n = len(ready_names)\n", + " win_matrix = np.full((n, n), np.nan)\n", + " np.fill_diagonal(win_matrix, 0.5)\n", + "\n", + " for i in range(n):\n", + " for j in range(i + 1, n):\n", + " name_i = ready_names[i]\n", + " name_j = ready_names[j]\n", + "\n", + " print(f\"Matchup: {name_i} vs {name_j}\")\n", + " env = create_tournament_env()\n", + "\n", + " # Leg 1: i as first_0, j as second_0\n", + " leg1 = run_pz_match(\n", + " env,\n", + " agent_first=agents[name_i],\n", + " agent_second=agents[name_j],\n", + " episodes=episodes_per_side,\n", + " )\n", + "\n", + " # Leg 2: swap seats\n", + " leg2 = run_pz_match(\n", + " env,\n", + " agent_first=agents[name_j],\n", + " agent_second=agents[name_i],\n", + " episodes=episodes_per_side,\n", + " )\n", + "\n", + " env.close()\n", + "\n", + " wins_i = leg1[\"first\"] + leg2[\"second\"]\n", + " wins_j = leg1[\"second\"] + leg2[\"first\"]\n", + "\n", + " decisive = wins_i + wins_j\n", + " if decisive == 0:\n", + " wr_i = 0.5\n", + " wr_j = 0.5\n", + " else:\n", + " wr_i = wins_i / decisive\n", + " wr_j = wins_j / decisive\n", + "\n", + " win_matrix[i, j] = wr_i\n", + " win_matrix[j, i] = wr_j\n", + "\n", + " print(f\" -> {name_i}: {wins_i} wins | {name_j}: {wins_j} wins\\n\")\n", + "\n", + " return win_matrix, ready_names\n", + "\n", + "\n", + "# Run tournament (non-random agents only)\n", + "win_matrix_pz, pz_names = run_pettingzoo_tournament(\n", + " agents=agents,\n", + " checkpoint_dir=CHECKPOINT_DIR,\n", + " episodes_per_side=10,\n", + ")\n", + "\n", + "# Plot win-rate matrix\n", + "plt.figure(figsize=(8, 6))\n", + "sns.heatmap(\n", + " win_matrix_pz,\n", + " annot=True,\n", + " fmt=\".2f\",\n", + " cmap=\"Blues\",\n", + " vmin=0.0,\n", + " vmax=1.0,\n", + " xticklabels=pz_names,\n", + " yticklabels=pz_names,\n", + ")\n", + "plt.xlabel(\"Opponent\")\n", + "plt.ylabel(\"Agent\")\n", + "plt.title(\"PettingZoo Tournament Win Rate Matrix (Non-random agents)\")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Rank agents by mean win rate vs others (excluding diagonal)\n", + "scores = {}\n", + "for idx, name in enumerate(pz_names):\n", + " row = np.delete(win_matrix_pz[idx], idx)\n", + " scores[name] = float(np.mean(row))\n", + "\n", + "ranking = sorted(scores.items(), key=lambda x: x[1], reverse=True)\n", + "print(\"Final ranking (PettingZoo tournament, non-random):\")\n", + "for rank_idx, (name, score) in enumerate(ranking, start=1):\n", + " print(f\"{rank_idx}. {name:<12} | mean win rate: {score:.3f}\")\n", + "\n", + "print(f\"\\nBest agent: {ranking[0][0]}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "3f8b300d", + "metadata": {}, + "source": [ + "## PettingZoo Tournament (Agents vs Agents)\n", + "\n", + "This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n", + "\n", + "- Checkpoints are loaded from `checkpoints/*.pkl`\n", + "- `Random` is **excluded** from ranking\n", + "- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n", + "- A win-rate matrix and final ranking are produced" + ] } ], "metadata": {