WIP - Add RL Project

This commit is contained in:
2026-03-02 09:33:44 +01:00
parent fadb874240
commit d3d56fd6ab
3 changed files with 738 additions and 23 deletions

5
.gitignore vendored
View File

@@ -34,6 +34,7 @@ results_stage_2/
*.pth *.pth
*.bin *.bin
Classification and Regression
BDTOPO BDTOPO
Test_ISF.csv
Train_ISF.csv

View File

@@ -2,21 +2,41 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 1,
"id": "7e37429a", "id": "7e37429a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import torch\n", "import itertools\n",
"import numpy as np\n",
"import pickle\n", "import pickle\n",
"from collections.abc import Iterator\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"\n" "from typing import Protocol\n",
"\n",
"import supersuit as ss\n",
"from pettingzoo.atari import tennis_v3\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"\n",
"import torch\n",
"from torch import nn, optim\n"
]
},
{
"cell_type": "markdown",
"id": "5f888489",
"metadata": {},
"source": [
"# Agent Definitions\n",
"\n",
"## Base Agent Class"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"id": "85ff0eb4", "id": "85ff0eb4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -24,27 +44,34 @@
"class Agent:\n", "class Agent:\n",
" \"\"\"Base class for reinforcement learning agents.\"\"\"\n", " \"\"\"Base class for reinforcement learning agents.\"\"\"\n",
"\n", "\n",
" def __init__(self, action_space: int) -> None:\n", " def __init__(self, seed: int, action_space: int) -> None:\n",
" \"\"\"Initialize the agent.\"\"\"\n", " \"\"\"Initialize the agent with its action space.\"\"\"\n",
" self.action_space = action_space\n", " self.action_space = action_space\n",
" self.rng = np.random.default_rng(seed = seed)\n",
"\n", "\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0):\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> None:\n",
" \"\"\"Select an action based on the current observation.\"\"\"\n", " \"\"\"Select an action from the current observation.\"\"\"\n",
" raise NotImplementedError\n",
"\n", "\n",
" def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):\n", " def update(\n",
" \"\"\"Update the agent's knowledge based on the experience tuple.\"\"\"\n", " self,\n",
" pass\n", " state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Update agent parameters from one transition.\"\"\"\n",
"\n", "\n",
" def save(self, filename: str) -> None:\n", " def save(self, filename: str) -> None:\n",
" \"\"\"Save the agent's state to a file.\"\"\"\n", " \"\"\"Save the agent state to disk.\"\"\"\n",
" with Path(filename).open(\"wb\") as f:\n", " with Path(filename).open(\"wb\") as f:\n",
" pickle.dump(self.__dict__, f)\n", " pickle.dump(self.__dict__, f)\n",
"\n", "\n",
" def load(self, filename: str) -> None:\n", " def load(self, filename: str) -> None:\n",
" \"\"\"Load the agent's state from a file.\"\"\"\n", " \"\"\"Load the agent state from disk.\"\"\"\n",
" with Path(filename).open(\"rb\") as f:\n", " with Path(filename).open(\"rb\") as f:\n",
" self.__dict__.update(pickle.load(f)) # noqa: S301\n" " self.__dict__.update(pickle.load(f)) # noqa: S301\n"
] ]
}, },
{ {
@@ -57,16 +84,701 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "89633751", "id": "89633751",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"class RandomAgent(Agent):\n", "class RandomAgent(Agent):\n",
" \"\"\"A simple agent that selects actions randomly.\"\"\"\n", " \"\"\"A simple agent that selects actions randomly.\"\"\"\n",
" def get_action(self, observation, epsilon=0.0):\n", "\n",
" return self.action_space.sample()\n" " def get_action( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]\n",
" self,\n",
" observation: np.ndarray,\n",
" epsilon: float = 0.0,\n",
" ) -> int:\n",
" \"\"\"Select an action randomly, ignoring the observation and epsilon.\"\"\"\n",
" _ = observation\n",
" _ = epsilon\n",
" return int(self.rng.integers(0, self.action_space))\n"
] ]
},
{
"cell_type": "markdown",
"id": "4b24ce3a",
"metadata": {},
"source": [
"## Torch Agent"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7b1dba76",
"metadata": {},
"outputs": [],
"source": [
"class TorchAgent(Agent):\n",
" \"\"\"A reinforcement-learning agent using a PyTorch model to estimate action values.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" lr: float = 0.001,\n",
" gamma: float = 0.99,\n",
" seed: int = 42,\n",
" ) -> None:\n",
" \"\"\"Initialize TorchAgent.\n",
"\n",
" Args:\n",
" n_features: Number of input features (state dimensionality).\n",
" n_actions: Number of possible discrete actions.\n",
" lr: Learning rate (kept for subclasses).\n",
" gamma: Discount factor.\n",
" seed: RNG seed for reproducibility.\n",
"\n",
" \"\"\"\n",
" # Initialize base Agent with seed and action space size\n",
" super().__init__(seed, n_actions)\n",
"\n",
" # keep parameters to avoid unused-argument warnings and for subclasses\n",
" self.n_features = n_features\n",
" self.lr = lr\n",
"\n",
" self.gamma = gamma\n",
" self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" self.model: nn.Module | None = None\n",
" self.optimizer: optim.Optimizer | None = None\n",
" self.criterion = nn.MSELoss()\n",
"\n",
" def save(self, filename: str) -> None:\n",
" \"\"\"Save model parameters to disk.\"\"\"\n",
" if self.model is None:\n",
" message = \"Model is not initialized. Define self.model before calling save().\"\n",
" raise ValueError(message)\n",
" torch.save(self.model.state_dict(), filename)\n",
"\n",
" def load(self, filename: str) -> None:\n",
" \"\"\"Load model parameters from disk into existing model.\"\"\"\n",
" if self.model is None:\n",
" message = \"Model is not initialized. Define self.model before calling load().\"\n",
" raise ValueError(message)\n",
" state = torch.load(filename, map_location=self.device)\n",
" self.model.load_state_dict(state)\n",
" self.model.eval()\n",
"\n",
" def get_action( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]\n",
" self,\n",
" observation: np.ndarray,\n",
" epsilon: float = 0.0,\n",
" ) -> int:\n",
" \"\"\"Return an action for the given observation using epsilon-greedy policy.\"\"\"\n",
" if self.model is None:\n",
" message = \"Model is not initialized. Define self.model before calling get_action().\"\n",
" raise ValueError(message)\n",
"\n",
" if self.rng.random() < epsilon:\n",
" return int(self.rng.integers(0, self.action_space))\n",
"\n",
" state_t = (\n",
" torch.as_tensor(observation, dtype=torch.float32, device=self.device)\n",
" / 255.0\n",
" )\n",
" if state_t.ndim == 1:\n",
" state_t = state_t.unsqueeze(0)\n",
"\n",
" self.model.eval()\n",
" with torch.no_grad():\n",
" q_values = self.model(state_t).squeeze(0)\n",
"\n",
" return int(torch.argmax(q_values).item())\n"
]
},
{
"cell_type": "markdown",
"id": "07150351",
"metadata": {},
"source": [
"## Linear Agent"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "19ed1ba8",
"metadata": {},
"outputs": [],
"source": [
"class LinearAgent(TorchAgent):\n",
" \"\"\"An agent that uses a simple linear model to estimate action values.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" lr: float = 0.001,\n",
" gamma: float = 0.99,\n",
" method: str = \"q_learning\",\n",
" ) -> None:\n",
" \"\"\"Initialize LinearAgent with a linear model and specified update method.\"\"\"\n",
" super().__init__(n_features, n_actions, lr, gamma)\n",
" self.method = method\n",
" self.model = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(n_features, n_actions),\n",
" ).to(self.device)\n",
" self.optimizer = optim.Adam(self.model.parameters(), lr=lr)\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Update the agent's value estimates using either Q-learning or SARSA.\"\"\"\n",
" if self.model is None:\n",
" message = \"Model is not initialized.\"\n",
" raise ValueError(message)\n",
" if self.optimizer is None:\n",
" message = \"Optimizer is not initialized.\"\n",
" raise ValueError(message)\n",
"\n",
" model = self.model\n",
" optimizer = self.optimizer\n",
"\n",
" state_t = torch.as_tensor(state, device=self.device, dtype=torch.float32).unsqueeze(0) / 255.0\n",
" next_state_t = torch.as_tensor(next_state, device=self.device, dtype=torch.float32).unsqueeze(0) / 255.0\n",
" reward_t = torch.as_tensor(reward, device=self.device, dtype=torch.float32)\n",
" done_t = torch.as_tensor(float(done), device=self.device, dtype=torch.float32)\n",
"\n",
" q_curr = model(state_t).squeeze(0)[action]\n",
"\n",
" with torch.no_grad():\n",
" if done:\n",
" q_next = torch.tensor(0.0, device=self.device)\n",
" elif self.method == \"q_learning\":\n",
" q_next = model(next_state_t).squeeze(0).max()\n",
" elif self.method == \"sarsa\":\n",
" if next_action is None:\n",
" message = \"next_action is required for SARSA updates.\"\n",
" raise ValueError(message)\n",
" q_next = model(next_state_t).squeeze(0)[next_action]\n",
" else:\n",
" message = (\n",
" f\"Unknown method: {self.method}. Supported methods are \"\n",
" \"'q_learning' and 'sarsa'.\"\n",
" )\n",
" raise ValueError(message)\n",
"\n",
" target = reward_t + (1.0 - done_t) * self.gamma * q_next\n",
"\n",
" loss = self.criterion(q_curr, target)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n"
]
},
{
"cell_type": "markdown",
"id": "1b132c0e",
"metadata": {},
"source": [
"## DQN Agent"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d90b82d0",
"metadata": {},
"outputs": [],
"source": [
"class DQNAgent(TorchAgent):\n",
" \"\"\"An agent that uses a deep neural network to estimate action values, with support for experience replay and target networks.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" lr: float = 0.0001,\n",
" gamma: float = 0.99,\n",
" buffer_size: int = 10000,\n",
" batch_size: int = 64,\n",
" target_update_freq: int = 1000,\n",
" ) -> None:\n",
" \"\"\"Initialize DQNAgent with a deep neural network, experience replay buffer, and target network.\"\"\"\n",
" super().__init__(n_features, n_actions, lr, gamma)\n",
" self.buffer_size = buffer_size\n",
" self.batch_size = batch_size\n",
" self.target_update_freq = target_update_freq\n",
" self.update_step = 0\n",
"\n",
" self.model = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(n_features, 128),\n",
" nn.ReLU(),\n",
" nn.Linear(128, 128),\n",
" nn.ReLU(),\n",
" nn.Linear(128, n_actions),\n",
" ).to(self.device)\n",
"\n",
" self.target_model: nn.Module = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(n_features, 128),\n",
" nn.ReLU(),\n",
" nn.Linear(128, 128),\n",
" nn.ReLU(),\n",
" nn.Linear(128, n_actions),\n",
" ).to(self.device)\n",
" self.target_model.load_state_dict(self.model.state_dict())\n",
"\n",
" self.optimizer = optim.Adam(self.model.parameters(), lr=lr)\n",
" self.replay_buffer: list[tuple[np.ndarray, int, float, np.ndarray, bool]] = []\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Update the agent's value estimates using experience replay and target network.\"\"\"\n",
" _ = next_action # kept for API compatibility with other agents\n",
"\n",
" if self.model is None:\n",
" message = \"Model is not initialized.\"\n",
" raise ValueError(message)\n",
" if self.optimizer is None:\n",
" message = \"Optimizer is not initialized.\"\n",
" raise ValueError(message)\n",
"\n",
" model = self.model\n",
" optimizer = self.optimizer\n",
" target_model = self.target_model\n",
"\n",
" self.replay_buffer.append((state, action, reward, next_state, done))\n",
" if len(self.replay_buffer) > self.buffer_size:\n",
" self.replay_buffer.pop(0)\n",
"\n",
" if len(self.replay_buffer) < self.batch_size:\n",
" return\n",
"\n",
" idx = self.rng.choice(\n",
" len(self.replay_buffer), size=self.batch_size, replace=False,\n",
" )\n",
" batch = [self.replay_buffer[i] for i in idx]\n",
"\n",
" states_np, actions_np, rewards_np, next_states_np, dones_np = map(\n",
" np.array, zip(*batch, strict=True),\n",
" )\n",
"\n",
" states_t = torch.as_tensor(states_np, dtype=torch.float32, device=self.device) / 255.0\n",
" next_states_t = torch.as_tensor(next_states_np, dtype=torch.float32, device=self.device) / 255.0\n",
" actions_t = torch.as_tensor(actions_np, dtype=torch.long, device=self.device)\n",
" rewards_t = torch.as_tensor(rewards_np, dtype=torch.float32, device=self.device)\n",
" dones_t = torch.as_tensor(\n",
" dones_np.astype(float), dtype=torch.float32, device=self.device,\n",
" )\n",
"\n",
" q_curr = model(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)\n",
"\n",
" with torch.no_grad():\n",
" q_next = target_model(next_states_t).max(dim=1)[0]\n",
" target = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n",
"\n",
" loss = self.criterion(q_curr, target)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" self.update_step += 1\n",
" if self.update_step % self.target_update_freq == 0:\n",
" target_model.load_state_dict(model.state_dict())\n"
]
},
{
"cell_type": "markdown",
"id": "2014cc3e",
"metadata": {},
"source": [
"## Monte Carlo Agent"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f99050ce",
"metadata": {},
"outputs": [],
"source": [
"class MonteCarloAgent(LinearAgent):\n",
" \"\"\"An agent that uses Monte Carlo methods to estimate action values.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" lr: float = 0.001,\n",
" gamma: float = 0.99,\n",
" ) -> None:\n",
" \"\"\"Initialize MonteCarloAgent.\"\"\"\n",
" super().__init__(n_features, n_actions, lr, gamma, method=\"monte_carlo\")\n",
" self.episode_buffer: list[tuple[np.ndarray, int, float]] = []\n",
"\n",
" def update(\n",
" self,\n",
" state: np.ndarray,\n",
" action: int,\n",
" reward: float,\n",
" next_state: np.ndarray,\n",
" done: bool,\n",
" next_action: int | None = None,\n",
" ) -> None:\n",
" \"\"\"Update value estimates with Monte Carlo returns at episode end.\"\"\"\n",
" _ = next_state\n",
" _ = next_action\n",
"\n",
" if self.model is None:\n",
" msg = \"Model is not initialized.\"\n",
" raise ValueError(msg)\n",
" if self.optimizer is None:\n",
" msg = \"Optimizer is not initialized.\"\n",
" raise ValueError(msg)\n",
"\n",
" model = self.model\n",
" optimizer = self.optimizer\n",
"\n",
" self.episode_buffer.append((state, action, reward))\n",
"\n",
" if not done:\n",
" return\n",
"\n",
" returns = 0.0\n",
" for s, a, r in reversed(self.episode_buffer):\n",
" returns = self.gamma * returns + r\n",
"\n",
" state_t = torch.as_tensor(s, dtype=torch.float32, device=self.device)\n",
" if state_t.ndim == 1:\n",
" state_t = state_t.unsqueeze(0)\n",
"\n",
" q_val = model(state_t).squeeze(0)[a]\n",
" target = torch.as_tensor(returns, dtype=torch.float32, device=self.device)\n",
"\n",
" loss = self.criterion(q_val, target)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" self.episode_buffer = []\n"
]
},
{
"cell_type": "markdown",
"id": "bfa23713",
"metadata": {},
"source": [
"# Tennis Environment"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "96c06fe3",
"metadata": {},
"outputs": [],
"source": [
"def create_env(obs_type: str = \"ram\"):\n",
" \"\"\"Create the Tennis environment with optional preprocessing for image observations.\"\"\"\n",
" env = tennis_v3.env(obs_type=obs_type)\n",
" if obs_type == \"rgb_image\":\n",
" env = ss.color_reduction_v0(env, mode=\"full\")\n",
" env = ss.resize_v1(env, x_size=84, y_size=84)\n",
" return ss.frame_stack_v1(env, 4)\n"
]
},
{
"cell_type": "markdown",
"id": "fe326344",
"metadata": {},
"source": [
"## Tournament"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7b1d6ebf",
"metadata": {},
"outputs": [],
"source": [
"class AECEnvProtocol(Protocol):\n",
" \"\"\"Protocol for PettingZoo AEC environments used in tournament matches.\"\"\"\n",
"\n",
" possible_agents: list[str]\n",
"\n",
" def reset(self) -> None:\n",
" \"\"\"Reset the environment to its initial state.\"\"\"\n",
" ...\n",
"\n",
" def agent_iter(self) -> Iterator[str]:\n",
" \"\"\"Iterate over agent identifiers in turn order.\"\"\"\n",
" ...\n",
"\n",
" def last(self) -> tuple[np.ndarray, float, bool, bool, dict[str, object]]:\n",
" \"\"\"Return the last transition tuple for the current agent.\"\"\"\n",
" ...\n",
"\n",
" def step(self, action: int | None) -> None:\n",
" \"\"\"Apply an action for the current agent.\"\"\"\n",
" ...\n",
"\n",
"\n",
"def run_tournament_match(\n",
" env: AECEnvProtocol,\n",
" agent1: Agent,\n",
" agent2: Agent,\n",
" *,\n",
" episodes: int = 100,\n",
" train_mode: bool = True,\n",
" epsilon_start: float = 0.2,\n",
" epsilon_end: float = 0.05,\n",
" epsilon_decay: float = 0.99,\n",
" max_steps: int = 2000,\n",
") -> tuple[dict[str, int], dict[str, list[float]]]:\n",
" \"\"\"Run a tournament match between two agents in a PettingZoo AEC environment.\"\"\"\n",
" agents_dict: dict[str, Agent] = {\"first_0\": agent1, \"second_0\": agent2}\n",
" wins = {\"first_0\": 0, \"second_0\": 0}\n",
" rewards_history = {\"first_0\": [], \"second_0\": []}\n",
"\n",
" current_epsilon = epsilon_start if train_mode else 0.0\n",
"\n",
" for _ in range(episodes):\n",
" env.reset()\n",
"\n",
" previous_states: dict[str, np.ndarray | None] = dict.fromkeys(\n",
" env.possible_agents,\n",
" None,\n",
" )\n",
" previous_actions: dict[str, int | None] = dict.fromkeys(\n",
" env.possible_agents,\n",
" None,\n",
" )\n",
" episode_rewards = {\"first_0\": 0.0, \"second_0\": 0.0}\n",
"\n",
" for steps, agent_id in enumerate(env.agent_iter()):\n",
" observation, reward, termination, truncation, _info = env.last()\n",
" done = termination or truncation\n",
" obs_array = np.asarray(observation)\n",
"\n",
" episode_rewards[agent_id] += reward\n",
"\n",
" action: int | None\n",
" if done or steps >= max_steps:\n",
" action = None\n",
" done = True\n",
" if reward > 0:\n",
" wins[agent_id] += 1\n",
" else:\n",
" action = agents_dict[agent_id].get_action(\n",
" obs_array,\n",
" epsilon=current_epsilon,\n",
" )\n",
"\n",
" if train_mode:\n",
" prev_state = previous_states[agent_id]\n",
" prev_action = previous_actions[agent_id]\n",
" if prev_state is not None and prev_action is not None:\n",
" agents_dict[agent_id].update(\n",
" state=prev_state,\n",
" action=prev_action,\n",
" reward=reward,\n",
" next_state=obs_array,\n",
" done=done,\n",
" next_action=action,\n",
" )\n",
"\n",
" if not done:\n",
" previous_states[agent_id] = obs_array\n",
" previous_actions[agent_id] = action\n",
"\n",
" env.step(action)\n",
"\n",
" if steps + 1 >= max_steps:\n",
" break\n",
"\n",
" rewards_history[\"first_0\"].append(episode_rewards[\"first_0\"])\n",
" rewards_history[\"second_0\"].append(episode_rewards[\"second_0\"])\n",
"\n",
" if train_mode:\n",
" current_epsilon = max(epsilon_end, current_epsilon * epsilon_decay)\n",
"\n",
" return wins, rewards_history\n",
"\n",
"\n",
"def plot_learning_curves(rewards_history: dict[str, list[float]], window: int = 50) -> None:\n",
" \"\"\"Plot learning curves for each agent using a moving average of rewards.\"\"\"\n",
" plt.figure(figsize=(10, 5))\n",
"\n",
" for agent_id, rewards in rewards_history.items():\n",
" if len(rewards) >= window:\n",
" ma = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n",
" plt.plot(np.arange(window - 1, len(rewards)), ma, label=f\"Agent {agent_id}\")\n",
" else:\n",
" plt.plot(rewards, label=f\"Agent {agent_id} (Raw)\")\n",
"\n",
" plt.xlabel(\"Episodes\")\n",
" plt.ylabel(f\"Moving Average Reward (Window={window})\")\n",
" plt.legend()\n",
" plt.grid(visible=True)\n",
" plt.show()\n",
"\n",
"\n",
"def plot_win_rate_matrix(results_matrix: np.ndarray, agent_names: list[str]) -> None:\n",
" \"\"\"Plot a heatmap of win rates between agents.\"\"\"\n",
" plt.figure(figsize=(8, 6))\n",
" sns.heatmap(\n",
" results_matrix,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"Blues\",\n",
" xticklabels=agent_names,\n",
" yticklabels=agent_names,\n",
" )\n",
" plt.xlabel(\"Opponent\")\n",
" plt.ylabel(\"Agent\")\n",
" plt.title(\"Win Rate Matrix\")\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "81dc1a05",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_tournament(\n",
" env: AECEnvProtocol,\n",
" agents: dict[str, Agent],\n",
" episodes_per_match: int = 10,\n",
") -> tuple[np.ndarray, list[str]]:\n",
" \"\"\"Evaluate a round-robin tournament between agents, returning a win rate matrix and agent names.\"\"\"\n",
" agent_names = list(agents.keys())\n",
" n_agents = len(agent_names)\n",
" win_matrix = np.zeros((n_agents, n_agents))\n",
"\n",
" matchups = list(itertools.permutations(enumerate(agent_names), 2))\n",
" total_matchups = len(matchups)\n",
"\n",
" for idx, ((idx1, name1), (idx2, name2)) in enumerate(matchups, start=1):\n",
" print(f\"[Evaluation {idx}/{total_matchups}] Match : {name1} vs {name2}\")\n",
"\n",
" agent1 = agents[name1]\n",
" agent2 = agents[name2]\n",
"\n",
" wins, _ = run_tournament_match(\n",
" env=env,\n",
" agent1=agent1,\n",
" agent2=agent2,\n",
" episodes=episodes_per_match,\n",
" train_mode=False,\n",
" epsilon_start=0.0,\n",
" )\n",
"\n",
" total_games = wins[\"first_0\"] + wins[\"second_0\"]\n",
" if total_games > 0:\n",
" win_matrix[idx1, idx2] = wins[\"first_0\"] / total_games\n",
"\n",
" print(f\"-> Résultat : {name1} {wins['first_0']} - {wins['second_0']} {name2}\\n\")\n",
"\n",
" return win_matrix, agent_names\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1dc5f3f5",
"metadata": {},
"outputs": [],
"source": [
"env = create_env(obs_type=\"ram\")\n",
"env.reset()\n",
"\n",
"obs_space = env.observation_space(\"first_0\")\n",
"n_actions = env.action_space(\"first_0\").n\n",
"n_features = int(np.prod(obs_space.shape))\n",
"\n",
"agent_random = RandomAgent(seed=42, action_space=n_actions)\n",
"agent_q = LinearAgent(n_features=n_features, n_actions=n_actions, method=\"q_learning\")\n",
"agent_sarsa = LinearAgent(n_features=n_features, n_actions=n_actions, method=\"sarsa\")\n",
"agent_dqn = DQNAgent(n_features=n_features, n_actions=n_actions)\n",
"agent_mc = MonteCarloAgent(n_features=n_features, n_actions=n_actions)\n",
"\n",
"agents_dict = {\n",
" \"Random\": agent_random,\n",
" \"Q-Learning\": agent_q,\n",
" \"SARSA\": agent_sarsa,\n",
" \"DQN\": agent_dqn,\n",
" \"MC\": agent_mc,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afbdf033",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1/20] Début entraînement : Random vs Q-Learning\n"
]
}
],
"source": [
"matchups = list(itertools.permutations(agents_dict.keys(), 2))\n",
"total_matchups = len(matchups)\n",
"\n",
"for idx, (name1, name2) in enumerate(matchups, start=1):\n",
" print(f\"[{idx}/{total_matchups}] Début entraînement : {name1} vs {name2}\")\n",
"\n",
" wins, rewards_history = run_tournament_match(\n",
" env=env,\n",
" agent1=agents_dict[name1],\n",
" agent2=agents_dict[name2],\n",
" episodes=1000,\n",
" train_mode=True,\n",
" epsilon_start=0.2,\n",
" )\n",
"\n",
" plot_learning_curves(rewards_history, window=10)\n",
"\n",
" print(f\"[{idx}/{total_matchups}] Fin entraînement : {name1} vs {name2}\")\n",
" if idx < total_matchups:\n",
" print(\"→ Passage au match suivant...\\n\")\n",
" else:\n",
" print(\"→ Tous les matchs d'entraînement sont terminés.\\n\")\n",
"\n",
"win_matrix, agent_names = evaluate_tournament(env, agents_dict, episodes_per_match=20)\n",
"plot_win_rate_matrix(win_matrix, agent_names)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9555e5b1",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@@ -89,6 +89,8 @@ ignore = [
"N803", "N803",
"N802", "N802",
"N816", "N816",
"PLR0913", # Too many arguments
"FBT001", # Boolean positional argument used in function definition
] ]
# Exclure certains fichiers ou répertoires # Exclure certains fichiers ou répertoires