Add PettingZoo tournament functionality for agent vs agent matches

- Introduced functions to create a Tennis environment and run matches between agents.
- Implemented a round-robin tournament format excluding random agents.
- Added win-rate matrix visualization and final ranking of agents based on performance.
- Updated imports to include necessary libraries for the new functionality.
This commit is contained in:
2026-03-02 12:21:25 +01:00
parent 61da605113
commit ae1dacc42c

View File

@@ -19,17 +19,20 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"id": "b50d7174",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import pickle\n",
"from pathlib import Path\n",
"\n",
"import ale_py # noqa: F401 — registers ALE environments\n",
"import gymnasium as gym\n",
"import supersuit as ss\n",
"from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n",
"from pettingzoo.atari import tennis_v3\n",
"from tqdm.auto import tqdm\n",
"\n",
"import matplotlib.pyplot as plt\n",
@@ -1215,7 +1218,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 30,
"id": "6f6ba8df",
"metadata": {},
"outputs": [
@@ -1263,18 +1266,6 @@
"id": "4d449701",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected agent: SARSA\n",
"Checkpoint path: checkpoints/sarsa.pkl\n",
"\n",
"============================================================\n",
"Training: SARSA (5000 episodes)\n",
"============================================================\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@@ -1290,7 +1281,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2809e1c28b15458daaeba87066ad1c74",
"model_id": "c817ca287f56462e8222b3fe71fffdec",
"version_major": 2,
"version_minor": 0
},
@@ -1432,6 +1423,189 @@
"for name, res in results.items():\n",
" print(f\"{name:<15} {res['mean_reward']:>12.2f} {res['std_reward']:>8.2f} {res['win_rate']:>9.1%}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d04d37e0",
"metadata": {},
"outputs": [],
"source": [
"def create_tournament_env():\n",
" \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n",
" env = tennis_v3.env(obs_type=\"rgb_image\")\n",
" env = ss.color_reduction_v0(env, mode=\"full\")\n",
" env = ss.resize_v1(env, x_size=84, y_size=84)\n",
" return ss.frame_stack_v1(env, 4)\n",
"\n",
"\n",
"def run_pz_match(\n",
" env,\n",
" agent_first: Agent,\n",
" agent_second: Agent,\n",
" episodes: int = 10,\n",
" max_steps: int = 4000,\n",
") -> dict[str, int]:\n",
" \"\"\"Run multiple PettingZoo episodes between two agents.\n",
"\n",
" Returns wins for global labels {'first': ..., 'second': ..., 'draw': ...}.\n",
" \"\"\"\n",
" wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n",
"\n",
" for _ep in range(episodes):\n",
" env.reset()\n",
" rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n",
"\n",
" for step_idx, agent_id in enumerate(env.agent_iter()):\n",
" obs, reward, termination, truncation, _info = env.last()\n",
" done = termination or truncation\n",
" rewards[agent_id] += float(reward)\n",
"\n",
" if done or step_idx >= max_steps:\n",
" action = None\n",
" else:\n",
" current_agent = agent_first if agent_id == \"first_0\" else agent_second\n",
" action = current_agent.get_action(np.asarray(obs), epsilon=0.0)\n",
"\n",
" env.step(action)\n",
"\n",
" if step_idx + 1 >= max_steps:\n",
" break\n",
"\n",
" if rewards[\"first_0\"] > rewards[\"second_0\"]:\n",
" wins[\"first\"] += 1\n",
" elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n",
" wins[\"second\"] += 1\n",
" else:\n",
" wins[\"draw\"] += 1\n",
"\n",
" return wins\n",
"\n",
"\n",
"def run_pettingzoo_tournament(\n",
" agents: dict[str, Agent],\n",
" checkpoint_dir: Path,\n",
" episodes_per_side: int = 10,\n",
") -> tuple[np.ndarray, list[str]]:\n",
" \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n",
" _ = itertools # kept for notebook context consistency\n",
" candidate_names = [name for name in agents if name != \"Random\"]\n",
"\n",
" # Keep only agents that have a checkpoint\n",
" ready_names: list[str] = []\n",
" for name in candidate_names:\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = checkpoint_dir / ckpt_name\n",
" if ckpt_path.exists():\n",
" agents[name].load(str(ckpt_path))\n",
" ready_names.append(name)\n",
"\n",
" if len(ready_names) < 2:\n",
" msg = \"Need at least 2 trained (checkpointed) non-random agents for PettingZoo tournament.\"\n",
" raise RuntimeError(msg)\n",
"\n",
" n = len(ready_names)\n",
" win_matrix = np.full((n, n), np.nan)\n",
" np.fill_diagonal(win_matrix, 0.5)\n",
"\n",
" for i in range(n):\n",
" for j in range(i + 1, n):\n",
" name_i = ready_names[i]\n",
" name_j = ready_names[j]\n",
"\n",
" print(f\"Matchup: {name_i} vs {name_j}\")\n",
" env = create_tournament_env()\n",
"\n",
" # Leg 1: i as first_0, j as second_0\n",
" leg1 = run_pz_match(\n",
" env,\n",
" agent_first=agents[name_i],\n",
" agent_second=agents[name_j],\n",
" episodes=episodes_per_side,\n",
" )\n",
"\n",
" # Leg 2: swap seats\n",
" leg2 = run_pz_match(\n",
" env,\n",
" agent_first=agents[name_j],\n",
" agent_second=agents[name_i],\n",
" episodes=episodes_per_side,\n",
" )\n",
"\n",
" env.close()\n",
"\n",
" wins_i = leg1[\"first\"] + leg2[\"second\"]\n",
" wins_j = leg1[\"second\"] + leg2[\"first\"]\n",
"\n",
" decisive = wins_i + wins_j\n",
" if decisive == 0:\n",
" wr_i = 0.5\n",
" wr_j = 0.5\n",
" else:\n",
" wr_i = wins_i / decisive\n",
" wr_j = wins_j / decisive\n",
"\n",
" win_matrix[i, j] = wr_i\n",
" win_matrix[j, i] = wr_j\n",
"\n",
" print(f\" -> {name_i}: {wins_i} wins | {name_j}: {wins_j} wins\\n\")\n",
"\n",
" return win_matrix, ready_names\n",
"\n",
"\n",
"# Run tournament (non-random agents only)\n",
"win_matrix_pz, pz_names = run_pettingzoo_tournament(\n",
" agents=agents,\n",
" checkpoint_dir=CHECKPOINT_DIR,\n",
" episodes_per_side=10,\n",
")\n",
"\n",
"# Plot win-rate matrix\n",
"plt.figure(figsize=(8, 6))\n",
"sns.heatmap(\n",
" win_matrix_pz,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"Blues\",\n",
" vmin=0.0,\n",
" vmax=1.0,\n",
" xticklabels=pz_names,\n",
" yticklabels=pz_names,\n",
")\n",
"plt.xlabel(\"Opponent\")\n",
"plt.ylabel(\"Agent\")\n",
"plt.title(\"PettingZoo Tournament Win Rate Matrix (Non-random agents)\")\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Rank agents by mean win rate vs others (excluding diagonal)\n",
"scores = {}\n",
"for idx, name in enumerate(pz_names):\n",
" row = np.delete(win_matrix_pz[idx], idx)\n",
" scores[name] = float(np.mean(row))\n",
"\n",
"ranking = sorted(scores.items(), key=lambda x: x[1], reverse=True)\n",
"print(\"Final ranking (PettingZoo tournament, non-random):\")\n",
"for rank_idx, (name, score) in enumerate(ranking, start=1):\n",
" print(f\"{rank_idx}. {name:<12} | mean win rate: {score:.3f}\")\n",
"\n",
"print(f\"\\nBest agent: {ranking[0][0]}\")\n"
]
},
{
"cell_type": "markdown",
"id": "3f8b300d",
"metadata": {},
"source": [
"## PettingZoo Tournament (Agents vs Agents)\n",
"\n",
"This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n",
"\n",
"- Checkpoints are loaded from `checkpoints/*.pkl`\n",
"- `Random` is **excluded** from ranking\n",
"- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
"- A win-rate matrix and final ranking are produced"
]
}
],
"metadata": {