mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-03-16 03:11:46 +01:00
Add PettingZoo tournament functionality for agent vs agent matches
- Introduced functions to create a Tennis environment and run matches between agents.
- Implemented a round-robin tournament format excluding random agents.
- Added win-rate matrix visualization and final ranking of agents based on performance.
- Updated imports to include necessary libraries for the new functionality.
This commit is contained in:
@@ -19,17 +19,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": null,
|
||||
"id": "b50d7174",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import itertools\n",
|
||||
"import pickle\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import ale_py # noqa: F401 — registers ALE environments\n",
|
||||
"import gymnasium as gym\n",
|
||||
"import supersuit as ss\n",
|
||||
"from gymnasium.wrappers import FrameStackObservation, ResizeObservation\n",
|
||||
"from pettingzoo.atari import tennis_v3\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
@@ -1215,7 +1218,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 30,
|
||||
"id": "6f6ba8df",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1263,18 +1266,6 @@
|
||||
"id": "4d449701",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Selected agent: SARSA\n",
|
||||
"Checkpoint path: checkpoints/sarsa.pkl\n",
|
||||
"\n",
|
||||
"============================================================\n",
|
||||
"Training: SARSA (5000 episodes)\n",
|
||||
"============================================================\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
@@ -1290,7 +1281,7 @@
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2809e1c28b15458daaeba87066ad1c74",
|
||||
"model_id": "c817ca287f56462e8222b3fe71fffdec",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@@ -1432,6 +1423,189 @@
|
||||
"for name, res in results.items():\n",
|
||||
" print(f\"{name:<15} {res['mean_reward']:>12.2f} {res['std_reward']:>8.2f} {res['win_rate']:>9.1%}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d04d37e0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_tournament_env():\n",
|
||||
" \"\"\"Create PettingZoo Tennis env with preprocessing compatible with our agents.\"\"\"\n",
|
||||
" env = tennis_v3.env(obs_type=\"rgb_image\")\n",
|
||||
" env = ss.color_reduction_v0(env, mode=\"full\")\n",
|
||||
" env = ss.resize_v1(env, x_size=84, y_size=84)\n",
|
||||
" return ss.frame_stack_v1(env, 4)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def run_pz_match(\n",
|
||||
" env,\n",
|
||||
" agent_first: Agent,\n",
|
||||
" agent_second: Agent,\n",
|
||||
" episodes: int = 10,\n",
|
||||
" max_steps: int = 4000,\n",
|
||||
") -> dict[str, int]:\n",
|
||||
" \"\"\"Run multiple PettingZoo episodes between two agents.\n",
|
||||
"\n",
|
||||
" Returns wins for global labels {'first': ..., 'second': ..., 'draw': ...}.\n",
|
||||
" \"\"\"\n",
|
||||
" wins = {\"first\": 0, \"second\": 0, \"draw\": 0}\n",
|
||||
"\n",
|
||||
" for _ep in range(episodes):\n",
|
||||
" env.reset()\n",
|
||||
" rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n",
|
||||
"\n",
|
||||
" for step_idx, agent_id in enumerate(env.agent_iter()):\n",
|
||||
" obs, reward, termination, truncation, _info = env.last()\n",
|
||||
" done = termination or truncation\n",
|
||||
" rewards[agent_id] += float(reward)\n",
|
||||
"\n",
|
||||
" if done or step_idx >= max_steps:\n",
|
||||
" action = None\n",
|
||||
" else:\n",
|
||||
" current_agent = agent_first if agent_id == \"first_0\" else agent_second\n",
|
||||
" action = current_agent.get_action(np.asarray(obs), epsilon=0.0)\n",
|
||||
"\n",
|
||||
" env.step(action)\n",
|
||||
"\n",
|
||||
" if step_idx + 1 >= max_steps:\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" if rewards[\"first_0\"] > rewards[\"second_0\"]:\n",
|
||||
" wins[\"first\"] += 1\n",
|
||||
" elif rewards[\"second_0\"] > rewards[\"first_0\"]:\n",
|
||||
" wins[\"second\"] += 1\n",
|
||||
" else:\n",
|
||||
" wins[\"draw\"] += 1\n",
|
||||
"\n",
|
||||
" return wins\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def run_pettingzoo_tournament(\n",
|
||||
" agents: dict[str, Agent],\n",
|
||||
" checkpoint_dir: Path,\n",
|
||||
" episodes_per_side: int = 10,\n",
|
||||
") -> tuple[np.ndarray, list[str]]:\n",
|
||||
" \"\"\"Round-robin tournament excluding Random, with seat-swap fairness.\"\"\"\n",
|
||||
" _ = itertools # kept for notebook context consistency\n",
|
||||
" candidate_names = [name for name in agents if name != \"Random\"]\n",
|
||||
"\n",
|
||||
" # Keep only agents that have a checkpoint\n",
|
||||
" ready_names: list[str] = []\n",
|
||||
" for name in candidate_names:\n",
|
||||
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
|
||||
" ckpt_path = checkpoint_dir / ckpt_name\n",
|
||||
" if ckpt_path.exists():\n",
|
||||
" agents[name].load(str(ckpt_path))\n",
|
||||
" ready_names.append(name)\n",
|
||||
"\n",
|
||||
" if len(ready_names) < 2:\n",
|
||||
" msg = \"Need at least 2 trained (checkpointed) non-random agents for PettingZoo tournament.\"\n",
|
||||
" raise RuntimeError(msg)\n",
|
||||
"\n",
|
||||
" n = len(ready_names)\n",
|
||||
" win_matrix = np.full((n, n), np.nan)\n",
|
||||
" np.fill_diagonal(win_matrix, 0.5)\n",
|
||||
"\n",
|
||||
" for i in range(n):\n",
|
||||
" for j in range(i + 1, n):\n",
|
||||
" name_i = ready_names[i]\n",
|
||||
" name_j = ready_names[j]\n",
|
||||
"\n",
|
||||
" print(f\"Matchup: {name_i} vs {name_j}\")\n",
|
||||
" env = create_tournament_env()\n",
|
||||
"\n",
|
||||
" # Leg 1: i as first_0, j as second_0\n",
|
||||
" leg1 = run_pz_match(\n",
|
||||
" env,\n",
|
||||
" agent_first=agents[name_i],\n",
|
||||
" agent_second=agents[name_j],\n",
|
||||
" episodes=episodes_per_side,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Leg 2: swap seats\n",
|
||||
" leg2 = run_pz_match(\n",
|
||||
" env,\n",
|
||||
" agent_first=agents[name_j],\n",
|
||||
" agent_second=agents[name_i],\n",
|
||||
" episodes=episodes_per_side,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" env.close()\n",
|
||||
"\n",
|
||||
" wins_i = leg1[\"first\"] + leg2[\"second\"]\n",
|
||||
" wins_j = leg1[\"second\"] + leg2[\"first\"]\n",
|
||||
"\n",
|
||||
" decisive = wins_i + wins_j\n",
|
||||
" if decisive == 0:\n",
|
||||
" wr_i = 0.5\n",
|
||||
" wr_j = 0.5\n",
|
||||
" else:\n",
|
||||
" wr_i = wins_i / decisive\n",
|
||||
" wr_j = wins_j / decisive\n",
|
||||
"\n",
|
||||
" win_matrix[i, j] = wr_i\n",
|
||||
" win_matrix[j, i] = wr_j\n",
|
||||
"\n",
|
||||
" print(f\" -> {name_i}: {wins_i} wins | {name_j}: {wins_j} wins\\n\")\n",
|
||||
"\n",
|
||||
" return win_matrix, ready_names\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run tournament (non-random agents only)\n",
|
||||
"win_matrix_pz, pz_names = run_pettingzoo_tournament(\n",
|
||||
" agents=agents,\n",
|
||||
" checkpoint_dir=CHECKPOINT_DIR,\n",
|
||||
" episodes_per_side=10,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Plot win-rate matrix\n",
|
||||
"plt.figure(figsize=(8, 6))\n",
|
||||
"sns.heatmap(\n",
|
||||
" win_matrix_pz,\n",
|
||||
" annot=True,\n",
|
||||
" fmt=\".2f\",\n",
|
||||
" cmap=\"Blues\",\n",
|
||||
" vmin=0.0,\n",
|
||||
" vmax=1.0,\n",
|
||||
" xticklabels=pz_names,\n",
|
||||
" yticklabels=pz_names,\n",
|
||||
")\n",
|
||||
"plt.xlabel(\"Opponent\")\n",
|
||||
"plt.ylabel(\"Agent\")\n",
|
||||
"plt.title(\"PettingZoo Tournament Win Rate Matrix (Non-random agents)\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"# Rank agents by mean win rate vs others (excluding diagonal)\n",
|
||||
"scores = {}\n",
|
||||
"for idx, name in enumerate(pz_names):\n",
|
||||
" row = np.delete(win_matrix_pz[idx], idx)\n",
|
||||
" scores[name] = float(np.mean(row))\n",
|
||||
"\n",
|
||||
"ranking = sorted(scores.items(), key=lambda x: x[1], reverse=True)\n",
|
||||
"print(\"Final ranking (PettingZoo tournament, non-random):\")\n",
|
||||
"for rank_idx, (name, score) in enumerate(ranking, start=1):\n",
|
||||
" print(f\"{rank_idx}. {name:<12} | mean win rate: {score:.3f}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nBest agent: {ranking[0][0]}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3f8b300d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## PettingZoo Tournament (Agents vs Agents)\n",
|
||||
"\n",
|
||||
"This tournament uses the two-player PettingZoo Atari Tennis environment (`from pettingzoo.atari import tennis_v3`) to make trained agents play against each other directly.\n",
|
||||
"\n",
|
||||
"- Checkpoints are loaded from `checkpoints/*.pkl`; agents without a checkpoint file are skipped, and at least two checkpointed agents are required\n",
|
||||
"- `Random` is **excluded** from ranking\n",
|
||||
"- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
|
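||||
"- A win-rate matrix (computed over decisive games only, so draws are ignored) and a final ranking by mean win rate against the other agents are produced"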
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user