Update dependencies in pyproject.toml for improved compatibility and features

- Bump catboost from 1.2.8 to 1.2.10
- Update google-api-python-client from 2.190.0 to 2.191.0
- Upgrade langchain from 1.2.0 to 1.2.10
- Update langchain-core from 1.2.16 to 1.2.17
- Upgrade langchain-huggingface from 1.2.0 to 1.2.1
- Bump marimo from 0.19.11 to 0.20.2
- Update matplotlib from 3.10.1 to 3.10.8
- Upgrade numpy from 2.2.5 to 2.4.2
- Update opencv-python from 4.11.0.86 to 4.13.0.92
- Bump pandas from 2.2.3 to 3.0.1
- Update plotly from 6.3.0 to 6.6.0
- Upgrade polars from 1.37.0 to 1.38.1
- Bump rasterio from 1.4.4 to 1.5.0
- Update scikit-learn from 1.6.1 to 1.8.0
- Upgrade scipy from 1.15.2 to 1.17.1
- Bump shap from 0.49.1 to 0.50.0
- Adjust isort section order for better readability
This commit is contained in:
2026-03-03 10:59:12 +01:00
parent 96edad169f
commit 7f2f820650
2 changed files with 222 additions and 254 deletions

View File

@@ -5,13 +5,13 @@
"id": "fef45687",
"metadata": {},
"source": [
"# RL Project: Atari Tennis Tournament (NumPy Agents)\n",
"# RL Project: Atari Tennis Tournament\n",
"\n",
"This notebook implements four Reinforcement Learning algorithms **without PyTorch** to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
"This notebook implements four Reinforcement Learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n",
"\n",
"1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n",
"2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n",
"3. **DQN** — Deep Q-Network with pure numpy MLP, experience replay and target network (inspired by Lab 6A + classic DQN)\n",
"3. **DQN** — Deep Q-Network with PyTorch MLP, experience replay and target network (inspired by Lab 6A + classic DQN), GPU-accelerated via MPS\n",
"4. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n",
"\n",
"Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament."
@@ -19,30 +19,22 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": null,
"id": "b50d7174",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pygame'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[39]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msupersuit\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mss\u001b[39;00m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mwrappers\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m FrameStackObservation, ResizeObservation\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tennis_v3\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtqdm\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mauto\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/__init__.py:5\u001b[39m, in \u001b[36m__getattr__\u001b[39m\u001b[34m(env_name)\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__getattr__\u001b[39m(env_name):\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdeprecated_handler\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m__path__\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/deprecated_module.py:70\u001b[39m, in \u001b[36mdeprecated_handler\u001b[39m\u001b[34m(env_name, module_path, module_name)\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;66;03m# This executes the module and will raise any exceptions\u001b[39;00m\n\u001b[32m 68\u001b[39m \u001b[38;5;66;03m# that would typically be raised by just `import blah`\u001b[39;00m\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m spec.loader\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[43mspec\u001b[49m\u001b[43m.\u001b[49m\u001b[43mloader\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexec_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 71\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m module\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/tennis_v3.py:1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtennis\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtennis\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m env, parallel_env, raw_env\n\u001b[32m 3\u001b[39m __all__ = [\u001b[33m\"\u001b[39m\u001b[33menv\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mparallel_env\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mraw_env\u001b[39m\u001b[33m\"\u001b[39m]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/tennis/tennis.py:77\u001b[39m\n\u001b[32m 74\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m 75\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mglob\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m glob\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase_atari_env\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[32m 78\u001b[39m BaseAtariEnv,\n\u001b[32m 79\u001b[39m base_env_wrapper_fn,\n\u001b[32m 80\u001b[39m parallel_wrapper_fn,\n\u001b[32m 81\u001b[39m )\n\u001b[32m 84\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mraw_env\u001b[39m(**kwargs):\n\u001b[32m 85\u001b[39m name = os.path.basename(\u001b[34m__file__\u001b[39m).split(\u001b[33m\"\u001b[39m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m)[\u001b[32m0\u001b[39m]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/base_atari_env.py:6\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmulti_agent_ale_py\u001b[39;00m\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpygame\u001b[39;00m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m spaces\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m EzPickle, seeding\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'pygame'"
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch device: cuda\n"
]
}
],
"source": [
"import itertools\n",
"import pickle\n",
"from collections import deque\n",
"from pathlib import Path\n",
"\n",
"import ale_py # noqa: F401 — registers ALE environments\n",
@@ -54,7 +46,18 @@
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n"
"import seaborn as sns\n",
"\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"if torch.cuda.is_available():\n",
" DEVICE = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" DEVICE = torch.device(\"mps\")\n",
"else:\n",
" DEVICE = torch.device(\"cpu\")\n",
"print(f\"PyTorch device: {DEVICE}\")\n"
]
},
{
@@ -65,7 +68,14 @@
"outputs": [],
"source": [
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n"
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"\n",
"def _ckpt_path(name: str) -> Path:\n",
" \"\"\"Return the checkpoint path for an agent (DQN uses .pt, others use .pkl).\"\"\"\n",
" base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n",
" ext = \".pt\" if name == \"DQN\" else \".pkl\"\n",
" return CHECKPOINT_DIR / (base + ext)\n"
]
},
{
@@ -511,17 +521,18 @@
"id": "f644e2ef",
"metadata": {},
"source": [
"## DQN Agent — NumPy MLP with Experience Replay and Target Network\n",
"## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n",
"\n",
"This agent implements the Deep Q-Network (DQN) **entirely in numpy**, without PyTorch.\n",
"This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon).\n",
"\n",
"**Network architecture** (identical to the original DQNAgent structure):\n",
"$$\\text{Input}(n\\_features) \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
"**Network architecture** (same structure as before, now as `torch.nn.Module`):\n",
"$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n",
"\n",
"**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n",
"- **Experience Replay**: like the `self.model` dict in Dyna-Q (Lab 6A) that stores past transitions, we store transitions in a circular buffer and sample minibatches for updates.\n",
"- **Target Network**: periodically synchronized copy of the network, stabilizes learning.\n",
"- **Manual backpropagation**: MSE loss gradient backpropagated layer by layer."
"- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n",
"- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n",
"- **Gradient clipping**: prevents exploding gradients in deep networks\n",
"- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes"
]
},
{
@@ -531,58 +542,64 @@
"metadata": {},
"outputs": [],
"source": [
"def _relu(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"ReLU activation: max(0, x).\"\"\"\n",
" return np.maximum(0, x)\n",
"class QNetwork(nn.Module):\n",
" \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n",
"\n",
" def __init__(self, n_features: int, n_actions: int) -> None:\n",
" super().__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Linear(n_features, 256),\n",
" nn.ReLU(),\n",
" nn.Linear(256, 256),\n",
" nn.ReLU(),\n",
" nn.Linear(256, n_actions),\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return self.net(x)\n",
"\n",
"\n",
"def _relu_grad(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"Compute the ReLU derivative: return 1 where x > 0, else 0.\"\"\"\n",
" return (x > 0).astype(np.float64)\n",
"class ReplayBuffer:\n",
" \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n",
"\n",
" def __init__(self, capacity: int) -> None:\n",
" self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n",
"\n",
"def _init_layer(\n",
" fan_in: int, fan_out: int, rng: np.random.Generator,\n",
") -> tuple[np.ndarray, np.ndarray]:\n",
" \"\"\"Initialize a linear layer with He initialization.\n",
" def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n",
" self.buffer.append((state, action, reward, next_state, done))\n",
"\n",
" He init: W ~ N(0, sqrt(2/fan_in)) — standard for ReLU networks.\n",
" def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n",
" indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n",
" batch = [self.buffer[i] for i in indices]\n",
" states = np.array([t[0] for t in batch])\n",
" actions = np.array([t[1] for t in batch])\n",
" rewards = np.array([t[2] for t in batch])\n",
" next_states = np.array([t[3] for t in batch])\n",
" dones = np.array([t[4] for t in batch], dtype=np.float32)\n",
" return states, actions, rewards, next_states, dones\n",
"\n",
" Args:\n",
" fan_in: Number of input features.\n",
" fan_out: Number of output features.\n",
" rng: Random number generator.\n",
"\n",
" Returns:\n",
" Weight matrix W of shape (fan_in, fan_out) and bias vector b of shape (fan_out,).\n",
"\n",
" \"\"\"\n",
" scale = np.sqrt(2.0 / fan_in)\n",
" W = rng.normal(0, scale, size=(fan_in, fan_out)).astype(np.float64)\n",
" b = np.zeros(fan_out, dtype=np.float64)\n",
" return W, b\n",
" def __len__(self) -> int:\n",
" return len(self.buffer)\n",
"\n",
"\n",
"class DQNAgent(Agent):\n",
" \"\"\"Deep Q-Network agent with a 2-hidden-layer MLP implemented in pure numpy.\n",
" \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n",
"\n",
" Architecture: Input -> Linear(128) -> ReLU -> Linear(128) -> ReLU -> Linear(n_actions)\n",
" Matches the original PyTorch DQNAgent structure.\n",
" Inspired by:\n",
" - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n",
" - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n",
"\n",
" Features:\n",
" - Experience replay buffer (like Dyna-Q model in Lab 6A: store past transitions, sample for updates)\n",
" - Target network (periodically synchronized copy for stable targets)\n",
" - Manual forward pass and backpropagation through the MLP\n",
" Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" n_features: int,\n",
" n_actions: int,\n",
" lr: float = 0.0001,\n",
" lr: float = 1e-4,\n",
" gamma: float = 0.99,\n",
" buffer_size: int = 10000,\n",
" batch_size: int = 64,\n",
" buffer_size: int = 50_000,\n",
" batch_size: int = 128,\n",
" target_update_freq: int = 1000,\n",
" seed: int = 42,\n",
" ) -> None:\n",
@@ -591,9 +608,9 @@
" Args:\n",
" n_features: Input feature dimension.\n",
" n_actions: Number of discrete actions.\n",
" lr: Learning rate for gradient descent.\n",
" lr: Learning rate for Adam optimizer.\n",
" gamma: Discount factor.\n",
" buffer_size: Maximum size of the replay buffer.\n",
" buffer_size: Maximum replay buffer capacity.\n",
" batch_size: Minibatch size for updates.\n",
" target_update_freq: Steps between target network syncs.\n",
" seed: RNG seed.\n",
@@ -603,111 +620,33 @@
" self.n_features = n_features\n",
" self.lr = lr\n",
" self.gamma = gamma\n",
" self.buffer_size = buffer_size\n",
" self.batch_size = batch_size\n",
" self.target_update_freq = target_update_freq\n",
" self.update_step = 0\n",
"\n",
" # Initialize network parameters: 3 layers\n",
" # Layer 1: n_features -> 128\n",
" # Layer 2: 128 -> 128\n",
" # Layer 3: 128 -> n_actions\n",
" self.params = self._init_params(self.rng)\n",
" # Q-network and target network on GPU\n",
" torch.manual_seed(seed)\n",
" self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
" self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n",
" self.target_net.load_state_dict(self.q_net.state_dict())\n",
" self.target_net.eval()\n",
"\n",
" # Target network: deep copy of params (like Dyna-Q's model copy concept)\n",
" self.target_params = {k: v.copy() for k, v in self.params.items()}\n",
" self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n",
" self.loss_fn = nn.SmoothL1Loss() # Huber loss — more robust than MSE\n",
"\n",
" # Experience replay buffer: list of (s, a, r, s', done) tuples\n",
" # Analogous to Dyna-Q's self.model dict + self.observed_sa in Lab 6A,\n",
" # but storing raw transitions for off-policy sampling\n",
" self.replay_buffer: list[tuple[np.ndarray, int, float, np.ndarray, bool]] = []\n",
"\n",
" def _init_params(self, rng: np.random.Generator) -> dict[str, np.ndarray]:\n",
" \"\"\"Initialize all MLP parameters with He initialization.\"\"\"\n",
" W1, b1 = _init_layer(self.n_features, 128, rng)\n",
" W2, b2 = _init_layer(128, 128, rng)\n",
" W3, b3 = _init_layer(128, self.action_space, rng)\n",
" return {\"W1\": W1, \"b1\": b1, \"W2\": W2, \"b2\": b2, \"W3\": W3, \"b3\": b3}\n",
"\n",
" def _forward(\n",
" self, x: np.ndarray, params: dict[str, np.ndarray],\n",
" ) -> tuple[np.ndarray, dict[str, np.ndarray]]:\n",
" \"\"\"Forward pass through the MLP. Returns output and cached activations for backprop.\n",
"\n",
" Architecture: x -> Linear -> ReLU -> Linear -> ReLU -> Linear -> output\n",
"\n",
" Args:\n",
" x: Input array of shape (batch, n_features) or (n_features,).\n",
" params: Dictionary of network parameters.\n",
"\n",
" Returns:\n",
" output: Q-values of shape (batch, n_actions) or (n_actions,).\n",
" cache: Dictionary of intermediate values for backpropagation.\n",
"\n",
" \"\"\"\n",
" # Layer 1\n",
" z1 = x @ params[\"W1\"] + params[\"b1\"]\n",
" h1 = _relu(z1)\n",
" # Layer 2\n",
" z2 = h1 @ params[\"W2\"] + params[\"b2\"]\n",
" h2 = _relu(z2)\n",
" # Layer 3 (output, no activation)\n",
" out = h2 @ params[\"W3\"] + params[\"b3\"]\n",
" cache = {\"x\": x, \"z1\": z1, \"h1\": h1, \"z2\": z2, \"h2\": h2}\n",
" return out, cache\n",
"\n",
" def _backward(\n",
" self,\n",
" d_out: np.ndarray,\n",
" cache: dict[str, np.ndarray],\n",
" params: dict[str, np.ndarray],\n",
" ) -> dict[str, np.ndarray]:\n",
" \"\"\"Backward pass: compute gradients of loss w.r.t. all parameters.\n",
"\n",
" Args:\n",
" d_out: Gradient of loss w.r.t. output, shape (batch, n_actions).\n",
" cache: Cached activations from forward pass.\n",
" params: Current network parameters.\n",
"\n",
" Returns:\n",
" Dictionary of gradients for each parameter.\n",
"\n",
" \"\"\"\n",
" batch = d_out.shape[0] if d_out.ndim > 1 else 1\n",
" if d_out.ndim == 1:\n",
" d_out = d_out.reshape(1, -1)\n",
"\n",
" x = cache[\"x\"] if cache[\"x\"].ndim > 1 else cache[\"x\"].reshape(1, -1)\n",
" h1 = cache[\"h1\"] if cache[\"h1\"].ndim > 1 else cache[\"h1\"].reshape(1, -1)\n",
" h2 = cache[\"h2\"] if cache[\"h2\"].ndim > 1 else cache[\"h2\"].reshape(1, -1)\n",
" z1 = cache[\"z1\"] if cache[\"z1\"].ndim > 1 else cache[\"z1\"].reshape(1, -1)\n",
" z2 = cache[\"z2\"] if cache[\"z2\"].ndim > 1 else cache[\"z2\"].reshape(1, -1)\n",
"\n",
" # Layer 3 gradients\n",
" dW3 = h2.T @ d_out / batch\n",
" db3 = d_out.mean(axis=0)\n",
" dh2 = d_out @ params[\"W3\"].T\n",
"\n",
" # Layer 2 gradients (through ReLU)\n",
" dz2 = dh2 * _relu_grad(z2)\n",
" dW2 = h1.T @ dz2 / batch\n",
" db2 = dz2.mean(axis=0)\n",
" dh1 = dz2 @ params[\"W2\"].T\n",
"\n",
" # Layer 1 gradients (through ReLU)\n",
" dz1 = dh1 * _relu_grad(z1)\n",
" dW1 = x.T @ dz1 / batch\n",
" db1 = dz1.mean(axis=0)\n",
"\n",
" return {\"W1\": dW1, \"b1\": db1, \"W2\": dW2, \"b2\": db2, \"W3\": dW3, \"b3\": db3}\n",
" # Experience replay buffer\n",
" self.replay_buffer = ReplayBuffer(buffer_size)\n",
"\n",
" def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n",
" \"\"\"Select action using ε-greedy policy over MLP Q-values.\"\"\"\n",
" \"\"\"Select action using ε-greedy policy over Q-network outputs.\"\"\"\n",
" if self.rng.random() < epsilon:\n",
" return int(self.rng.integers(0, self.action_space))\n",
"\n",
" phi = normalize_obs(observation)\n",
" q_vals, _ = self._forward(phi, self.params)\n",
" if q_vals.ndim > 1:\n",
" q_vals = q_vals.squeeze(0)\n",
" return epsilon_greedy(q_vals, epsilon, self.rng)\n",
" with torch.no_grad():\n",
" state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n",
" q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n",
" return epsilon_greedy(q_vals, 0.0, self.rng)\n",
"\n",
" def update(\n",
" self,\n",
@@ -721,66 +660,81 @@
" \"\"\"Store transition and perform a minibatch DQN update.\n",
"\n",
" Steps:\n",
" 1. Add transition to replay buffer (like Dyna-Q model update in Lab 6A)\n",
" 1. Add transition to replay buffer\n",
" 2. If buffer has enough samples, sample a minibatch\n",
" 3. Compute targets using target network (Q-learning: max_a' Q_target(s', a'))\n",
" 4. Backpropagate MSE loss through the MLP\n",
" 5. Update weights with gradient descent\n",
" 3. Compute targets using target network (max_a' Q_target(s', a'))\n",
" 4. Compute Huber loss and backpropagate\n",
" 5. Clip gradients and update weights with Adam\n",
" 6. Periodically sync target network\n",
"\n",
" \"\"\"\n",
" _ = next_action # DQN is off-policy\n",
"\n",
" # Store transition in replay buffer\n",
" # Analogous to Lab 6A: self.model[(s,a)] = (r, sp)\n",
" # Store transition\n",
" phi_s = normalize_obs(state)\n",
" phi_sp = normalize_obs(next_state)\n",
" self.replay_buffer.append((phi_s, action, reward, phi_sp, done))\n",
" if len(self.replay_buffer) > self.buffer_size:\n",
" self.replay_buffer.pop(0)\n",
" self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n",
"\n",
" # Don't update until we have enough samples\n",
" if len(self.replay_buffer) < self.batch_size:\n",
" return\n",
"\n",
" # Sample minibatch from replay buffer\n",
" # Like Dyna-Q planning: sample from observed transitions (Lab 6A)\n",
" idx = self.rng.choice(len(self.replay_buffer), size=self.batch_size, replace=False)\n",
" batch = [self.replay_buffer[i] for i in idx]\n",
" # Sample minibatch\n",
" states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n",
" self.batch_size, self.rng,\n",
" )\n",
"\n",
" states_b = np.array([t[0] for t in batch]) # (batch, n_features)\n",
" actions_b = np.array([t[1] for t in batch]) # (batch,)\n",
" rewards_b = np.array([t[2] for t in batch]) # (batch,)\n",
" next_states_b = np.array([t[3] for t in batch]) # (batch, n_features)\n",
" dones_b = np.array([t[4] for t in batch], dtype=np.float64)\n",
" # Convert to tensors on device\n",
" states_t = torch.from_numpy(states_b).float().to(DEVICE)\n",
" actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n",
" rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n",
" next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n",
" dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n",
"\n",
" # Forward pass on current states\n",
" q_all, cache = self._forward(states_b, self.params)\n",
" # Gather Q-values for the taken actions\n",
" q_curr = q_all[np.arange(self.batch_size), actions_b] # (batch,)\n",
" # Current Q-values for taken actions\n",
" q_values = self.q_net(states_t)\n",
" q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n",
"\n",
" # Compute targets using target network (off-policy: max over actions)\n",
" # Follows Lab 5B Q-learning: td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n",
" q_next_all, _ = self._forward(next_states_b, self.target_params)\n",
" q_next_max = np.max(q_next_all, axis=1) # (batch,)\n",
" targets = rewards_b + (1.0 - dones_b) * self.gamma * q_next_max\n",
" # Target Q-values (off-policy: max over actions in next state)\n",
" with torch.no_grad():\n",
" q_next = self.target_net(next_states_t).max(dim=1).values\n",
" targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n",
"\n",
" # MSE loss gradient: d_loss/d_q_curr = 2 * (q_curr - targets) / batch\n",
" # We only backprop through the action that was taken\n",
" d_out = np.zeros_like(q_all)\n",
" d_out[np.arange(self.batch_size), actions_b] = 2.0 * (q_curr - targets) / self.batch_size\n",
"\n",
" # Backward pass\n",
" grads = self._backward(d_out, cache, self.params)\n",
"\n",
" # Gradient descent update\n",
" for key in self.params:\n",
" self.params[key] -= self.lr * grads[key]\n",
" # Compute loss and update\n",
" loss = self.loss_fn(q_curr, targets)\n",
" self.optimizer.zero_grad()\n",
" loss.backward()\n",
" nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n",
" self.optimizer.step()\n",
"\n",
" # Sync target network periodically\n",
" self.update_step += 1\n",
" if self.update_step % self.target_update_freq == 0:\n",
" self.target_params = {k: v.copy() for k, v in self.params.items()}\n"
" self.target_net.load_state_dict(self.q_net.state_dict())\n",
"\n",
" def save(self, filename: str) -> None:\n",
" \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n",
" torch.save(\n",
" {\n",
" \"q_net\": self.q_net.state_dict(),\n",
" \"target_net\": self.target_net.state_dict(),\n",
" \"optimizer\": self.optimizer.state_dict(),\n",
" \"update_step\": self.update_step,\n",
" \"n_features\": self.n_features,\n",
" \"action_space\": self.action_space,\n",
" },\n",
" filename,\n",
" )\n",
"\n",
" def load(self, filename: str) -> None:\n",
" \"\"\"Load agent state from a torch checkpoint.\"\"\"\n",
" checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n",
" self.q_net.load_state_dict(checkpoint[\"q_net\"])\n",
" self.target_net.load_state_dict(checkpoint[\"target_net\"])\n",
" self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n",
" self.update_step = checkpoint[\"update_step\"]\n",
" self.q_net.to(DEVICE)\n",
" self.target_net.to(DEVICE)\n",
" self.target_net.eval()\n"
]
},
{
@@ -1226,12 +1180,12 @@
"- **Random** — random baseline (no training needed)\n",
"- **SARSA** — linear approximation, semi-gradient TD(0)\n",
"- **Q-Learning** — linear approximation, off-policy\n",
"- **DQN** — NumPy MLP (2 hidden layers of 128, experience replay, target network)\n",
"- **DQN** — PyTorch MLP (2 hidden layers of 256, experience replay, target network, GPU-accelerated)\n",
"- **Monte Carlo** — linear approximation, first-visit returns\n",
"\n",
"**Workflow**:\n",
"1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n",
"2. Save its weights to `checkpoints/<agent>.pkl`\n",
"2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
"3. Repeat later for another agent without retraining previous ones\n",
"4. Load all saved checkpoints before the final evaluation"
]
@@ -1282,7 +1236,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 23,
"id": "4d449701",
"metadata": {},
"outputs": [
@@ -1290,39 +1244,45 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Selected agent: Q-Learning\n",
"Checkpoint path: checkpoints/q_learning.pkl\n",
"Selected agent: DQN\n",
"Checkpoint path: checkpoints/dqn.pt\n",
"\n",
"============================================================\n",
"Training: Q-Learning (2500 episodes)\n",
"Training: DQN (2500 episodes)\n",
"============================================================\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "02b31c90057641a09f7e6bb3f3c53f79",
"model_id": "47c9af1f459346dea335fb822a091b1b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training Q-Learning: 0%| | 0/2500 [00:00<?, ?it/s]"
"Training DQN: 0%| | 0/2500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"-> Q-Learning avg reward (last 100 eps): -0.19\n",
"Checkpoint saved.\n"
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_730/423872786.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{'='*60}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m training_histories[AGENT_TO_TRAIN] = train_agent(\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0magent\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/tmp/ipykernel_730/4063369955.py\u001b[0m in \u001b[0;36mtrain_agent\u001b[0;34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# Update agent with the transition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m agent.update(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, state, action, reward, next_state, done, next_action)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;31m# Sample minibatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 138\u001b[0;31m states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m )\n",
"\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, batch_size, rng)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mstates\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mrewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"AGENT_TO_TRAIN = \"Q-Learning\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n",
"AGENT_TO_TRAIN = \"DQN\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n",
"TRAINING_EPISODES = 2500\n",
"FORCE_RETRAIN = False\n",
"\n",
@@ -1332,8 +1292,7 @@
"\n",
"training_histories: dict[str, list[float]] = {}\n",
"agent = agents[AGENT_TO_TRAIN]\n",
"ckpt_name = AGENT_TO_TRAIN.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
"ckpt_path = CHECKPOINT_DIR / ckpt_name\n",
"ckpt_path = _ckpt_path(AGENT_TO_TRAIN)\n",
"\n",
"print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n",
"print(f\"Checkpoint path: {ckpt_path}\")\n",
@@ -1369,7 +1328,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": null,
"id": "a13a65df",
"metadata": {},
"outputs": [
@@ -1400,14 +1359,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded checkpoints: ['SARSA']\n",
"Missing checkpoints (untrained/not saved yet): ['Q-Learning', 'DQN', 'Monte Carlo']\n"
"Loaded checkpoints: ['SARSA', 'Q-Learning']\n",
"Missing checkpoints (untrained/not saved yet): ['DQN', 'Monte Carlo']\n"
]
}
],
"source": [
"# Load all available checkpoints before final evaluation\n",
"CHECKPOINT_DIR = Path(\"checkpoints\")\n",
"loaded_agents: list[str] = []\n",
"missing_agents: list[str] = []\n",
"\n",
@@ -1415,8 +1373,7 @@
" if name == \"Random\":\n",
" continue\n",
"\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = CHECKPOINT_DIR / ckpt_name\n",
" ckpt_path = _ckpt_path(name)\n",
"\n",
" if ckpt_path.exists():\n",
" agent.load(str(ckpt_path))\n",
@@ -1455,8 +1412,7 @@
" if name == \"Random\":\n",
" continue\n",
"\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = CHECKPOINT_DIR / ckpt_name\n",
" ckpt_path = _ckpt_path(name)\n",
"\n",
" if ckpt_path.exists():\n",
" agent.load(str(ckpt_path))\n",
@@ -1551,8 +1507,7 @@
" # Keep only agents that have a checkpoint\n",
" ready_names: list[str] = []\n",
" for name in candidate_names:\n",
" ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n",
" ckpt_path = checkpoint_dir / ckpt_name\n",
" ckpt_path = _ckpt_path(name)\n",
" if ckpt_path.exists():\n",
" agents[name].load(str(ckpt_path))\n",
" ready_names.append(name)\n",
@@ -1658,7 +1613,7 @@
"\n",
"This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n",
"\n",
"- Checkpoints are loaded from `checkpoints/*.pkl`\n",
"- Checkpoints are loaded from `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n",
"- `Random` is **excluded** from ranking\n",
"- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n",
"- A win-rate matrix and final ranking are produced"
@@ -1667,7 +1622,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "studies (3.13.9)",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -1681,7 +1636,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
"version": "3.12.12"
}
},
"nbformat": 4,

View File

@@ -7,48 +7,48 @@ requires-python = ">= 3.11"
dependencies = [
"accelerate>=1.12.0",
"ale-py>=0.11.2",
"catboost>=1.2.8",
"catboost>=1.2.10",
"datasets>=4.6.1",
"faiss-cpu>=1.13.2",
"folium>=0.20.0",
"geopandas>=1.1.2",
"google-api-python-client>=2.190.0",
"google-api-python-client>=2.191.0",
"google-auth-oauthlib>=1.3.0",
"google-generativeai>=0.8.6",
"gymnasium[toy-text]>=1.2.3",
"imblearn>=0.0",
"ipykernel>=7.2.0",
"ipywidgets>=8.1.8",
"langchain>=1.2.0",
"langchain>=1.2.10",
"langchain-community>=0.4.1",
"langchain-core>=1.2.16",
"langchain-huggingface>=1.2.0",
"langchain-core>=1.2.17",
"langchain-huggingface>=1.2.1",
"langchain-mistralai>=1.1.1",
"langchain-ollama>=1.0.1",
"langchain-openai>=1.1.10",
"langchain-text-splitters>=1.1.1",
"mapclassify>=2.10.0",
"marimo>=0.19.11",
"matplotlib>=3.10.1",
"marimo>=0.20.2",
"matplotlib>=3.10.8",
"nbformat>=5.10.4",
"numpy>=2.2.5",
"opencv-python>=4.11.0.86",
"numpy>=2.4.2",
"opencv-python>=4.13.0.92",
"openpyxl>=3.1.5",
"pandas>=2.2.3",
"pandas>=3.0.1",
"pandas-stubs>=3.0.0.260204",
"pettingzoo[atari]>=1.25.0",
"plotly>=6.3.0",
"polars>=1.37.0",
"plotly>=6.6.0",
"polars>=1.38.1",
"pyclustertend>=1.9.0",
"pypdf>=6.7.5",
"rasterio>=1.4.4",
"rasterio>=1.5.0",
"requests>=2.32.5",
"scikit-learn>=1.6.1",
"scipy>=1.15.2",
"scikit-learn>=1.8.0",
"scipy>=1.17.1",
"seaborn>=0.13.2",
"sentence-transformers>=5.2.3",
# "sequenzo>=0.1.20",
"shap>=0.49.1",
"shap>=0.50.0",
"spacy>=3.8.11",
"statsmodels>=0.14.6",
"supersuit>=3.10.0",
@@ -65,10 +65,7 @@ dependencies = [
]
[dependency-groups]
dev = [
"ipykernel>=7.2.0",
"uv>=0.10.7",
]
dev = ["ipykernel>=7.2.0", "uv>=0.10.7"]
[tool.ty.rules]
index-out-of-bounds = "ignore"
@@ -83,18 +80,18 @@ select = ["ALL"]
# Désactiver certaines règles
ignore = [
"E501", # line too long, géré par le formatter
"E402", # Imports in top of file
"T201", # Print
"E501", # line too long, géré par le formatter
"E402", # Imports in top of file
"T201", # Print
"N806",
"N803",
"N802",
"N816",
"PLR0913", # Too many arguments
"PLR2004", #
"FBT001", # Boolean positional argument used in function definition
"FBT001", # Boolean positional argument used in function definition
"EM101",
"TRY003", # Avoid specifying long messages outside the exception class
"TRY003", # Avoid specifying long messages outside the exception class
]
# Exclure certains fichiers ou répertoires
@@ -127,9 +124,25 @@ unfixable = []
[tool.ruff.lint.isort]
# Regrouper les imports par thématiques
section-order = ["future", "standard-library", "third-party", "data-science", "ml", "first-party", "local-folder"]
section-order = [
"future",
"standard-library",
"third-party",
"data-science",
"ml",
"first-party",
"local-folder",
]
[tool.ruff.lint.isort.sections]
# On sépare les outils de manipulation de données des frameworks de ML lourds
"data-science" = ["numpy", "pandas", "scipy", "matplotlib", "seaborn", "plotly"]
"ml" = ["tensorflow", "keras", "torch", "sklearn", "xgboost", "catboost", "shap"]
"ml" = [
"tensorflow",
"keras",
"torch",
"sklearn",
"xgboost",
"catboost",
"shap",
]