From 7f2f820650262e4fb1688ccf0ec5ab36c4be8129 Mon Sep 17 00:00:00 2001 From: Arthur DANJOU Date: Tue, 3 Mar 2026 10:59:12 +0100 Subject: [PATCH] Update dependencies in pyproject.toml for improved compatibility and features - Bump catboost from 1.2.8 to 1.2.10 - Update google-api-python-client from 2.190.0 to 2.191.0 - Upgrade langchain from 1.2.0 to 1.2.10 - Update langchain-core from 1.2.16 to 1.2.17 - Upgrade langchain-huggingface from 1.2.0 to 1.2.1 - Bump marimo from 0.19.11 to 0.20.2 - Update matplotlib from 3.10.1 to 3.10.8 - Upgrade numpy from 2.2.5 to 2.4.2 - Update opencv-python from 4.11.0.86 to 4.13.0.92 - Bump pandas from 2.2.3 to 3.0.1 - Update plotly from 6.3.0 to 6.6.0 - Upgrade polars from 1.37.0 to 1.38.1 - Bump rasterio from 1.4.4 to 1.5.0 - Update scikit-learn from 1.6.1 to 1.8.0 - Upgrade scipy from 1.15.2 to 1.17.1 - Bump shap from 0.49.1 to 0.50.0 - Adjust isort section order for better readability --- .../project/Project.ipynb | 409 ++++++++---------- pyproject.toml | 67 +-- 2 files changed, 222 insertions(+), 254 deletions(-) diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project.ipynb index 69f0145..01e5188 100644 --- a/M2/Reinforcement Learning/project/Project.ipynb +++ b/M2/Reinforcement Learning/project/Project.ipynb @@ -5,13 +5,13 @@ "id": "fef45687", "metadata": {}, "source": [ - "# RL Project: Atari Tennis Tournament (NumPy Agents)\n", + "# RL Project: Atari Tennis Tournament\n", "\n", - "This notebook implements four Reinforcement Learning algorithms **without PyTorch** to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n", + "This notebook implements four Reinforcement Learning algorithms to play Atari Tennis (`ALE/Tennis-v5` via Gymnasium):\n", "\n", "1. **SARSA** — Semi-gradient SARSA with linear approximation (inspired by Lab 7, on-policy update from Lab 5B)\n", "2. **Q-Learning** — Off-policy linear approximation (inspired by Lab 5B)\n", - "3. **DQN** — Deep Q-Network with pure numpy MLP, experience replay and target network (inspired by Lab 6A + classic DQN)\n", + "3. **DQN** — Deep Q-Network with PyTorch MLP, experience replay and target network (inspired by Lab 6A + classic DQN), GPU-accelerated via MPS\n", "4. **Monte Carlo** — First-visit MC control with linear approximation (inspired by Lab 4)\n", "\n", "Each agent is **pre-trained independently** against the built-in Atari AI opponent, then evaluated in a comparative tournament." @@ -19,30 +19,22 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "id": "b50d7174", "metadata": {}, "outputs": [ { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'pygame'", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[39]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msupersuit\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mss\u001b[39;00m\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mwrappers\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m FrameStackObservation, ResizeObservation\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tennis_v3\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtqdm\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mauto\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/__init__.py:5\u001b[39m, in \u001b[36m__getattr__\u001b[39m\u001b[34m(env_name)\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__getattr__\u001b[39m(env_name):\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdeprecated_handler\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m__path__\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/utils/deprecated_module.py:70\u001b[39m, in \u001b[36mdeprecated_handler\u001b[39m\u001b[34m(env_name, module_path, module_name)\u001b[39m\n\u001b[32m 67\u001b[39m \u001b[38;5;66;03m# This executes the module and will raise any exceptions\u001b[39;00m\n\u001b[32m 68\u001b[39m \u001b[38;5;66;03m# that would typically be raised by just `import blah`\u001b[39;00m\n\u001b[32m 69\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m spec.loader\n\u001b[32m---> \u001b[39m\u001b[32m70\u001b[39m \u001b[43mspec\u001b[49m\u001b[43m.\u001b[49m\u001b[43mloader\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexec_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 71\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m module\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/tennis_v3.py:1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtennis\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtennis\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m env, parallel_env, raw_env\n\u001b[32m 3\u001b[39m __all__ = [\u001b[33m\"\u001b[39m\u001b[33menv\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mparallel_env\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mraw_env\u001b[39m\u001b[33m\"\u001b[39m]\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/tennis/tennis.py:77\u001b[39m\n\u001b[32m 74\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mos\u001b[39;00m\n\u001b[32m 75\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mglob\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m glob\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpettingzoo\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01matari\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase_atari_env\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[32m 78\u001b[39m BaseAtariEnv,\n\u001b[32m 79\u001b[39m base_env_wrapper_fn,\n\u001b[32m 80\u001b[39m parallel_wrapper_fn,\n\u001b[32m 81\u001b[39m )\n\u001b[32m 84\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mraw_env\u001b[39m(**kwargs):\n\u001b[32m 85\u001b[39m name = os.path.basename(\u001b[34m__file__\u001b[39m).split(\u001b[33m\"\u001b[39m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m)[\u001b[32m0\u001b[39m]\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Workspace/studies/.venv/lib/python3.13/site-packages/pettingzoo/atari/base_atari_env.py:6\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmulti_agent_ale_py\u001b[39;00m\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpygame\u001b[39;00m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m spaces\n\u001b[32m 8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgymnasium\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m EzPickle, seeding\n", - "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'pygame'" + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch device: cuda\n" ] } ], "source": [ "import itertools\n", "import pickle\n", + "from collections import deque\n", "from pathlib import Path\n", "\n", "import ale_py # noqa: F401 — registers ALE environments\n", @@ -54,7 +46,18 @@ "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import seaborn as sns\n" + "import seaborn as sns\n", + "\n", + "import torch\n", + "from torch import nn, optim\n", + "\n", + "if torch.cuda.is_available():\n", + " DEVICE = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " DEVICE = torch.device(\"mps\")\n", + "else:\n", + " DEVICE = torch.device(\"cpu\")\n", + "print(f\"PyTorch device: {DEVICE}\")\n" ] }, { @@ -65,7 +68,14 @@ "outputs": [], "source": [ "CHECKPOINT_DIR = Path(\"checkpoints\")\n", - "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n" + "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + "def _ckpt_path(name: str) -> Path:\n", + " \"\"\"Return the checkpoint path for an agent (DQN uses .pt, others use .pkl).\"\"\"\n", + " base = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\")\n", + " ext = \".pt\" if name == \"DQN\" else \".pkl\"\n", + " return CHECKPOINT_DIR / (base + ext)\n" ] }, { @@ -511,17 +521,18 @@ "id": "f644e2ef", "metadata": {}, "source": [ - "## DQN Agent — NumPy MLP with Experience Replay and Target Network\n", + "## DQN Agent — PyTorch MLP with Experience Replay and Target Network\n", "\n", - "This agent implements the Deep Q-Network (DQN) **entirely in numpy**, without PyTorch.\n", + "This agent implements the Deep Q-Network (DQN) using **PyTorch** for GPU-accelerated training (MPS on Apple Silicon).\n", "\n", - "**Network architecture** (identical to the original DQNAgent structure):\n", - "$$\\text{Input}(n\\_features) \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(128) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n", + "**Network architecture** (same structure as before, now as `torch.nn.Module`):\n", + "$$\\text{Input}(n\\_features) \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(256) \\to \\text{ReLU} \\to \\text{Linear}(n\\_actions)$$\n", "\n", "**Key techniques** (inspired by Lab 6A Dyna-Q + classic DQN):\n", - "- **Experience Replay**: like the `self.model` dict in Dyna-Q (Lab 6A) that stores past transitions, we store transitions in a circular buffer and sample minibatches for updates.\n", - "- **Target Network**: periodically synchronized copy of the network, stabilizes learning.\n", - "- **Manual backpropagation**: MSE loss gradient backpropagated layer by layer." + "- **Experience Replay**: circular buffer of transitions, sampled as minibatches for off-policy updates\n", + "- **Target Network**: periodically synchronized copy of the Q-network, stabilizes learning\n", + "- **Gradient clipping**: prevents exploding gradients in deep networks\n", + "- **GPU acceleration**: tensors on MPS/CUDA device for fast forward/backward passes" ] }, { @@ -531,58 +542,64 @@ "metadata": {}, "outputs": [], "source": [ - "def _relu(x: np.ndarray) -> np.ndarray:\n", - " \"\"\"ReLU activation: max(0, x).\"\"\"\n", - " return np.maximum(0, x)\n", + "class QNetwork(nn.Module):\n", + " \"\"\"MLP Q-network: Input -> 256 -> ReLU -> 256 -> ReLU -> n_actions.\"\"\"\n", + "\n", + " def __init__(self, n_features: int, n_actions: int) -> None:\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(n_features, 256),\n", + " nn.ReLU(),\n", + " nn.Linear(256, 256),\n", + " nn.ReLU(),\n", + " nn.Linear(256, n_actions),\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " return self.net(x)\n", "\n", "\n", - "def _relu_grad(x: np.ndarray) -> np.ndarray:\n", - " \"\"\"Compute the ReLU derivative: return 1 where x > 0, else 0.\"\"\"\n", - " return (x > 0).astype(np.float64)\n", + "class ReplayBuffer:\n", + " \"\"\"Fixed-size circular replay buffer storing (s, a, r, s', done) transitions.\"\"\"\n", "\n", + " def __init__(self, capacity: int) -> None:\n", + " self.buffer: deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(maxlen=capacity)\n", "\n", - "def _init_layer(\n", - " fan_in: int, fan_out: int, rng: np.random.Generator,\n", - ") -> tuple[np.ndarray, np.ndarray]:\n", - " \"\"\"Initialize a linear layer with He initialization.\n", + " def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:\n", + " self.buffer.append((state, action, reward, next_state, done))\n", "\n", - " He init: W ~ N(0, sqrt(2/fan_in)) — standard for ReLU networks.\n", + " def sample(self, batch_size: int, rng: np.random.Generator) -> tuple[np.ndarray, ...]:\n", + " indices = rng.choice(len(self.buffer), size=batch_size, replace=False)\n", + " batch = [self.buffer[i] for i in indices]\n", + " states = np.array([t[0] for t in batch])\n", + " actions = np.array([t[1] for t in batch])\n", + " rewards = np.array([t[2] for t in batch])\n", + " next_states = np.array([t[3] for t in batch])\n", + " dones = np.array([t[4] for t in batch], dtype=np.float32)\n", + " return states, actions, rewards, next_states, dones\n", "\n", - " Args:\n", - " fan_in: Number of input features.\n", - " fan_out: Number of output features.\n", - " rng: Random number generator.\n", - "\n", - " Returns:\n", - " Weight matrix W of shape (fan_in, fan_out) and bias vector b of shape (fan_out,).\n", - "\n", - " \"\"\"\n", - " scale = np.sqrt(2.0 / fan_in)\n", - " W = rng.normal(0, scale, size=(fan_in, fan_out)).astype(np.float64)\n", - " b = np.zeros(fan_out, dtype=np.float64)\n", - " return W, b\n", + " def __len__(self) -> int:\n", + " return len(self.buffer)\n", "\n", "\n", "class DQNAgent(Agent):\n", - " \"\"\"Deep Q-Network agent with a 2-hidden-layer MLP implemented in pure numpy.\n", + " \"\"\"Deep Q-Network agent using PyTorch with GPU acceleration (MPS/CUDA).\n", "\n", - " Architecture: Input -> Linear(128) -> ReLU -> Linear(128) -> ReLU -> Linear(n_actions)\n", - " Matches the original PyTorch DQNAgent structure.\n", + " Inspired by:\n", + " - Lab 6A Dyna-Q: experience replay (store transitions, sample for updates)\n", + " - Classic DQN (Mnih et al., 2015): target network, minibatch SGD\n", "\n", - " Features:\n", - " - Experience replay buffer (like Dyna-Q model in Lab 6A: store past transitions, sample for updates)\n", - " - Target network (periodically synchronized copy for stable targets)\n", - " - Manual forward pass and backpropagation through the MLP\n", + " Uses Adam optimizer and Huber loss (smooth L1) for stable training.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " n_features: int,\n", " n_actions: int,\n", - " lr: float = 0.0001,\n", + " lr: float = 1e-4,\n", " gamma: float = 0.99,\n", - " buffer_size: int = 10000,\n", - " batch_size: int = 64,\n", + " buffer_size: int = 50_000,\n", + " batch_size: int = 128,\n", " target_update_freq: int = 1000,\n", " seed: int = 42,\n", " ) -> None:\n", @@ -591,9 +608,9 @@ " Args:\n", " n_features: Input feature dimension.\n", " n_actions: Number of discrete actions.\n", - " lr: Learning rate for gradient descent.\n", + " lr: Learning rate for Adam optimizer.\n", " gamma: Discount factor.\n", - " buffer_size: Maximum size of the replay buffer.\n", + " buffer_size: Maximum replay buffer capacity.\n", " batch_size: Minibatch size for updates.\n", " target_update_freq: Steps between target network syncs.\n", " seed: RNG seed.\n", @@ -603,111 +620,33 @@ " self.n_features = n_features\n", " self.lr = lr\n", " self.gamma = gamma\n", - " self.buffer_size = buffer_size\n", " self.batch_size = batch_size\n", " self.target_update_freq = target_update_freq\n", " self.update_step = 0\n", "\n", - " # Initialize network parameters: 3 layers\n", - " # Layer 1: n_features -> 128\n", - " # Layer 2: 128 -> 128\n", - " # Layer 3: 128 -> n_actions\n", - " self.params = self._init_params(self.rng)\n", + " # Q-network and target network on GPU\n", + " torch.manual_seed(seed)\n", + " self.q_net = QNetwork(n_features, n_actions).to(DEVICE)\n", + " self.target_net = QNetwork(n_features, n_actions).to(DEVICE)\n", + " self.target_net.load_state_dict(self.q_net.state_dict())\n", + " self.target_net.eval()\n", "\n", - " # Target network: deep copy of params (like Dyna-Q's model copy concept)\n", - " self.target_params = {k: v.copy() for k, v in self.params.items()}\n", + " self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)\n", + " self.loss_fn = nn.SmoothL1Loss() # Huber loss — more robust than MSE\n", "\n", - " # Experience replay buffer: list of (s, a, r, s', done) tuples\n", - " # Analogous to Dyna-Q's self.model dict + self.observed_sa in Lab 6A,\n", - " # but storing raw transitions for off-policy sampling\n", - " self.replay_buffer: list[tuple[np.ndarray, int, float, np.ndarray, bool]] = []\n", - "\n", - " def _init_params(self, rng: np.random.Generator) -> dict[str, np.ndarray]:\n", - " \"\"\"Initialize all MLP parameters with He initialization.\"\"\"\n", - " W1, b1 = _init_layer(self.n_features, 128, rng)\n", - " W2, b2 = _init_layer(128, 128, rng)\n", - " W3, b3 = _init_layer(128, self.action_space, rng)\n", - " return {\"W1\": W1, \"b1\": b1, \"W2\": W2, \"b2\": b2, \"W3\": W3, \"b3\": b3}\n", - "\n", - " def _forward(\n", - " self, x: np.ndarray, params: dict[str, np.ndarray],\n", - " ) -> tuple[np.ndarray, dict[str, np.ndarray]]:\n", - " \"\"\"Forward pass through the MLP. Returns output and cached activations for backprop.\n", - "\n", - " Architecture: x -> Linear -> ReLU -> Linear -> ReLU -> Linear -> output\n", - "\n", - " Args:\n", - " x: Input array of shape (batch, n_features) or (n_features,).\n", - " params: Dictionary of network parameters.\n", - "\n", - " Returns:\n", - " output: Q-values of shape (batch, n_actions) or (n_actions,).\n", - " cache: Dictionary of intermediate values for backpropagation.\n", - "\n", - " \"\"\"\n", - " # Layer 1\n", - " z1 = x @ params[\"W1\"] + params[\"b1\"]\n", - " h1 = _relu(z1)\n", - " # Layer 2\n", - " z2 = h1 @ params[\"W2\"] + params[\"b2\"]\n", - " h2 = _relu(z2)\n", - " # Layer 3 (output, no activation)\n", - " out = h2 @ params[\"W3\"] + params[\"b3\"]\n", - " cache = {\"x\": x, \"z1\": z1, \"h1\": h1, \"z2\": z2, \"h2\": h2}\n", - " return out, cache\n", - "\n", - " def _backward(\n", - " self,\n", - " d_out: np.ndarray,\n", - " cache: dict[str, np.ndarray],\n", - " params: dict[str, np.ndarray],\n", - " ) -> dict[str, np.ndarray]:\n", - " \"\"\"Backward pass: compute gradients of loss w.r.t. all parameters.\n", - "\n", - " Args:\n", - " d_out: Gradient of loss w.r.t. output, shape (batch, n_actions).\n", - " cache: Cached activations from forward pass.\n", - " params: Current network parameters.\n", - "\n", - " Returns:\n", - " Dictionary of gradients for each parameter.\n", - "\n", - " \"\"\"\n", - " batch = d_out.shape[0] if d_out.ndim > 1 else 1\n", - " if d_out.ndim == 1:\n", - " d_out = d_out.reshape(1, -1)\n", - "\n", - " x = cache[\"x\"] if cache[\"x\"].ndim > 1 else cache[\"x\"].reshape(1, -1)\n", - " h1 = cache[\"h1\"] if cache[\"h1\"].ndim > 1 else cache[\"h1\"].reshape(1, -1)\n", - " h2 = cache[\"h2\"] if cache[\"h2\"].ndim > 1 else cache[\"h2\"].reshape(1, -1)\n", - " z1 = cache[\"z1\"] if cache[\"z1\"].ndim > 1 else cache[\"z1\"].reshape(1, -1)\n", - " z2 = cache[\"z2\"] if cache[\"z2\"].ndim > 1 else cache[\"z2\"].reshape(1, -1)\n", - "\n", - " # Layer 3 gradients\n", - " dW3 = h2.T @ d_out / batch\n", - " db3 = d_out.mean(axis=0)\n", - " dh2 = d_out @ params[\"W3\"].T\n", - "\n", - " # Layer 2 gradients (through ReLU)\n", - " dz2 = dh2 * _relu_grad(z2)\n", - " dW2 = h1.T @ dz2 / batch\n", - " db2 = dz2.mean(axis=0)\n", - " dh1 = dz2 @ params[\"W2\"].T\n", - "\n", - " # Layer 1 gradients (through ReLU)\n", - " dz1 = dh1 * _relu_grad(z1)\n", - " dW1 = x.T @ dz1 / batch\n", - " db1 = dz1.mean(axis=0)\n", - "\n", - " return {\"W1\": dW1, \"b1\": db1, \"W2\": dW2, \"b2\": db2, \"W3\": dW3, \"b3\": db3}\n", + " # Experience replay buffer\n", + " self.replay_buffer = ReplayBuffer(buffer_size)\n", "\n", " def get_action(self, observation: np.ndarray, epsilon: float = 0.0) -> int:\n", - " \"\"\"Select action using ε-greedy policy over MLP Q-values.\"\"\"\n", + " \"\"\"Select action using ε-greedy policy over Q-network outputs.\"\"\"\n", + " if self.rng.random() < epsilon:\n", + " return int(self.rng.integers(0, self.action_space))\n", + "\n", " phi = normalize_obs(observation)\n", - " q_vals, _ = self._forward(phi, self.params)\n", - " if q_vals.ndim > 1:\n", - " q_vals = q_vals.squeeze(0)\n", - " return epsilon_greedy(q_vals, epsilon, self.rng)\n", + " with torch.no_grad():\n", + " state_t = torch.from_numpy(phi).float().unsqueeze(0).to(DEVICE)\n", + " q_vals = self.q_net(state_t).cpu().numpy().squeeze(0)\n", + " return epsilon_greedy(q_vals, 0.0, self.rng)\n", "\n", " def update(\n", " self,\n", @@ -721,66 +660,81 @@ " \"\"\"Store transition and perform a minibatch DQN update.\n", "\n", " Steps:\n", - " 1. Add transition to replay buffer (like Dyna-Q model update in Lab 6A)\n", + " 1. Add transition to replay buffer\n", " 2. If buffer has enough samples, sample a minibatch\n", - " 3. Compute targets using target network (Q-learning: max_a' Q_target(s', a'))\n", - " 4. Backpropagate MSE loss through the MLP\n", - " 5. Update weights with gradient descent\n", + " 3. Compute targets using target network (max_a' Q_target(s', a'))\n", + " 4. Compute Huber loss and backpropagate\n", + " 5. Clip gradients and update weights with Adam\n", " 6. Periodically sync target network\n", "\n", " \"\"\"\n", " _ = next_action # DQN is off-policy\n", "\n", - " # Store transition in replay buffer\n", - " # Analogous to Lab 6A: self.model[(s,a)] = (r, sp)\n", + " # Store transition\n", " phi_s = normalize_obs(state)\n", " phi_sp = normalize_obs(next_state)\n", - " self.replay_buffer.append((phi_s, action, reward, phi_sp, done))\n", - " if len(self.replay_buffer) > self.buffer_size:\n", - " self.replay_buffer.pop(0)\n", + " self.replay_buffer.push(phi_s, action, reward, phi_sp, done)\n", "\n", - " # Don't update until we have enough samples\n", " if len(self.replay_buffer) < self.batch_size:\n", " return\n", "\n", - " # Sample minibatch from replay buffer\n", - " # Like Dyna-Q planning: sample from observed transitions (Lab 6A)\n", - " idx = self.rng.choice(len(self.replay_buffer), size=self.batch_size, replace=False)\n", - " batch = [self.replay_buffer[i] for i in idx]\n", + " # Sample minibatch\n", + " states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n", + " self.batch_size, self.rng,\n", + " )\n", "\n", - " states_b = np.array([t[0] for t in batch]) # (batch, n_features)\n", - " actions_b = np.array([t[1] for t in batch]) # (batch,)\n", - " rewards_b = np.array([t[2] for t in batch]) # (batch,)\n", - " next_states_b = np.array([t[3] for t in batch]) # (batch, n_features)\n", - " dones_b = np.array([t[4] for t in batch], dtype=np.float64)\n", + " # Convert to tensors on device\n", + " states_t = torch.from_numpy(states_b).float().to(DEVICE)\n", + " actions_t = torch.from_numpy(actions_b).long().to(DEVICE)\n", + " rewards_t = torch.from_numpy(rewards_b).float().to(DEVICE)\n", + " next_states_t = torch.from_numpy(next_states_b).float().to(DEVICE)\n", + " dones_t = torch.from_numpy(dones_b).float().to(DEVICE)\n", "\n", - " # Forward pass on current states\n", - " q_all, cache = self._forward(states_b, self.params)\n", - " # Gather Q-values for the taken actions\n", - " q_curr = q_all[np.arange(self.batch_size), actions_b] # (batch,)\n", + " # Current Q-values for taken actions\n", + " q_values = self.q_net(states_t)\n", + " q_curr = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)\n", "\n", - " # Compute targets using target network (off-policy: max over actions)\n", - " # Follows Lab 5B Q-learning: td_target = r + gamma * max(Q[s2]) * (0 if terminated else 1)\n", - " q_next_all, _ = self._forward(next_states_b, self.target_params)\n", - " q_next_max = np.max(q_next_all, axis=1) # (batch,)\n", - " targets = rewards_b + (1.0 - dones_b) * self.gamma * q_next_max\n", + " # Target Q-values (off-policy: max over actions in next state)\n", + " with torch.no_grad():\n", + " q_next = self.target_net(next_states_t).max(dim=1).values\n", + " targets = rewards_t + (1.0 - dones_t) * self.gamma * q_next\n", "\n", - " # MSE loss gradient: d_loss/d_q_curr = 2 * (q_curr - targets) / batch\n", - " # We only backprop through the action that was taken\n", - " d_out = np.zeros_like(q_all)\n", - " d_out[np.arange(self.batch_size), actions_b] = 2.0 * (q_curr - targets) / self.batch_size\n", - "\n", - " # Backward pass\n", - " grads = self._backward(d_out, cache, self.params)\n", - "\n", - " # Gradient descent update\n", - " for key in self.params:\n", - " self.params[key] -= self.lr * grads[key]\n", + " # Compute loss and update\n", + " loss = self.loss_fn(q_curr, targets)\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=10.0)\n", + " self.optimizer.step()\n", "\n", " # Sync target network periodically\n", " self.update_step += 1\n", " if self.update_step % self.target_update_freq == 0:\n", - " self.target_params = {k: v.copy() for k, v in self.params.items()}\n" + " self.target_net.load_state_dict(self.q_net.state_dict())\n", + "\n", + " def save(self, filename: str) -> None:\n", + " \"\"\"Save agent state using torch.save (networks + optimizer + metadata).\"\"\"\n", + " torch.save(\n", + " {\n", + " \"q_net\": self.q_net.state_dict(),\n", + " \"target_net\": self.target_net.state_dict(),\n", + " \"optimizer\": self.optimizer.state_dict(),\n", + " \"update_step\": self.update_step,\n", + " \"n_features\": self.n_features,\n", + " \"action_space\": self.action_space,\n", + " },\n", + " filename,\n", + " )\n", + "\n", + " def load(self, filename: str) -> None:\n", + " \"\"\"Load agent state from a torch checkpoint.\"\"\"\n", + " checkpoint = torch.load(filename, map_location=DEVICE, weights_only=False)\n", + " self.q_net.load_state_dict(checkpoint[\"q_net\"])\n", + " self.target_net.load_state_dict(checkpoint[\"target_net\"])\n", + " self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n", + " self.update_step = checkpoint[\"update_step\"]\n", + " self.q_net.to(DEVICE)\n", + " self.target_net.to(DEVICE)\n", + " self.target_net.eval()\n" ] }, { @@ -1226,12 +1180,12 @@ "- **Random** — random baseline (no training needed)\n", "- **SARSA** — linear approximation, semi-gradient TD(0)\n", "- **Q-Learning** — linear approximation, off-policy\n", - "- **DQN** — NumPy MLP (2 hidden layers of 128, experience replay, target network)\n", + "- **DQN** — PyTorch MLP (2 hidden layers of 256, experience replay, target network, GPU-accelerated)\n", "- **Monte Carlo** — linear approximation, first-visit returns\n", "\n", "**Workflow**:\n", "1. Train **one** selected agent (`AGENT_TO_TRAIN`)\n", - "2. Save its weights to `checkpoints/.pkl`\n", + "2. Save its weights to `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n", "3. Repeat later for another agent without retraining previous ones\n", "4. Load all saved checkpoints before the final evaluation" ] @@ -1282,7 +1236,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 23, "id": "4d449701", "metadata": {}, "outputs": [ @@ -1290,39 +1244,45 @@ "name": "stdout", "output_type": "stream", "text": [ - "Selected agent: Q-Learning\n", - "Checkpoint path: checkpoints/q_learning.pkl\n", + "Selected agent: DQN\n", + "Checkpoint path: checkpoints/dqn.pt\n", "\n", "============================================================\n", - "Training: Q-Learning (2500 episodes)\n", + "Training: DQN (2500 episodes)\n", "============================================================\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "02b31c90057641a09f7e6bb3f3c53f79", + "model_id": "47c9af1f459346dea335fb822a091b1b", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training Q-Learning: 0%| | 0/2500 [00:00 Q-Learning avg reward (last 100 eps): -0.19\n", - "Checkpoint saved.\n" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_730/423872786.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{'='*60}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m training_histories[AGENT_TO_TRAIN] = train_agent(\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0magent\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_730/4063369955.py\u001b[0m in \u001b[0;36mtrain_agent\u001b[0;34m(env, agent, name, episodes, epsilon_start, epsilon_end, epsilon_decay, max_steps)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# Update agent with the transition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m agent.update(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, state, action, reward, next_state, done, next_action)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;31m# Sample minibatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 138\u001b[0;31m states_b, actions_b, rewards_b, next_states_b, dones_b = self.replay_buffer.sample(\n\u001b[0m\u001b[1;32m 139\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m )\n", + "\u001b[0;32m/tmp/ipykernel_730/4071728682.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, batch_size, rng)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrng\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mstates\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mrewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ - "AGENT_TO_TRAIN = \"Q-Learning\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n", + "AGENT_TO_TRAIN = \"DQN\" # TODO: change to: \"Q-Learning\", \"DQN\", \"Monte Carlo\", \"Random\"\n", "TRAINING_EPISODES = 2500\n", "FORCE_RETRAIN = False\n", "\n", @@ -1332,8 +1292,7 @@ "\n", "training_histories: dict[str, list[float]] = {}\n", "agent = agents[AGENT_TO_TRAIN]\n", - "ckpt_name = AGENT_TO_TRAIN.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n", - "ckpt_path = CHECKPOINT_DIR / ckpt_name\n", + "ckpt_path = _ckpt_path(AGENT_TO_TRAIN)\n", "\n", "print(f\"Selected agent: {AGENT_TO_TRAIN}\")\n", "print(f\"Checkpoint path: {ckpt_path}\")\n", @@ -1369,7 +1328,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "id": "a13a65df", "metadata": {}, "outputs": [ @@ -1400,14 +1359,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loaded checkpoints: ['SARSA']\n", - "Missing checkpoints (untrained/not saved yet): ['Q-Learning', 'DQN', 'Monte Carlo']\n" + "Loaded checkpoints: ['SARSA', 'Q-Learning']\n", + "Missing checkpoints (untrained/not saved yet): ['DQN', 'Monte Carlo']\n" ] } ], "source": [ "# Load all available checkpoints before final evaluation\n", - "CHECKPOINT_DIR = Path(\"checkpoints\")\n", "loaded_agents: list[str] = []\n", "missing_agents: list[str] = []\n", "\n", @@ -1415,8 +1373,7 @@ " if name == \"Random\":\n", " continue\n", "\n", - " ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n", - " ckpt_path = CHECKPOINT_DIR / ckpt_name\n", + " ckpt_path = _ckpt_path(name)\n", "\n", " if ckpt_path.exists():\n", " agent.load(str(ckpt_path))\n", @@ -1455,8 +1412,7 @@ " if name == \"Random\":\n", " continue\n", "\n", - " ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n", - " ckpt_path = CHECKPOINT_DIR / ckpt_name\n", + " ckpt_path = _ckpt_path(name)\n", "\n", " if ckpt_path.exists():\n", " agent.load(str(ckpt_path))\n", @@ -1551,8 +1507,7 @@ " # Keep only agents that have a checkpoint\n", " ready_names: list[str] = []\n", " for name in candidate_names:\n", - " ckpt_name = name.lower().replace(\" \", \"_\").replace(\"-\", \"_\") + \".pkl\"\n", - " ckpt_path = checkpoint_dir / ckpt_name\n", + " ckpt_path = _ckpt_path(name)\n", " if ckpt_path.exists():\n", " agents[name].load(str(ckpt_path))\n", " ready_names.append(name)\n", @@ -1658,7 +1613,7 @@ "\n", "This tournament uses `from pettingzoo.atari import tennis_v3` to make trained agents play against each other directly.\n", "\n", - "- Checkpoints are loaded from `checkpoints/*.pkl`\n", + "- Checkpoints are loaded from `checkpoints/` (`.pkl` for linear agents, `.pt` for DQN)\n", "- `Random` is **excluded** from ranking\n", "- Each pair plays in both seat positions (`first_0` and `second_0`) to reduce position bias\n", "- A win-rate matrix and final ranking are produced" @@ -1667,7 +1622,7 @@ ], "metadata": { "kernelspec": { - "display_name": "studies (3.13.9)", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -1681,7 +1636,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.9" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 6a8de78..a396d9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,48 +7,48 @@ requires-python = ">= 3.11" dependencies = [ "accelerate>=1.12.0", "ale-py>=0.11.2", - "catboost>=1.2.8", + "catboost>=1.2.10", "datasets>=4.6.1", "faiss-cpu>=1.13.2", "folium>=0.20.0", "geopandas>=1.1.2", - "google-api-python-client>=2.190.0", + "google-api-python-client>=2.191.0", "google-auth-oauthlib>=1.3.0", "google-generativeai>=0.8.6", "gymnasium[toy-text]>=1.2.3", "imblearn>=0.0", "ipykernel>=7.2.0", "ipywidgets>=8.1.8", - "langchain>=1.2.0", + "langchain>=1.2.10", "langchain-community>=0.4.1", - "langchain-core>=1.2.16", - "langchain-huggingface>=1.2.0", + "langchain-core>=1.2.17", + "langchain-huggingface>=1.2.1", "langchain-mistralai>=1.1.1", "langchain-ollama>=1.0.1", "langchain-openai>=1.1.10", "langchain-text-splitters>=1.1.1", "mapclassify>=2.10.0", - "marimo>=0.19.11", - "matplotlib>=3.10.1", + "marimo>=0.20.2", + "matplotlib>=3.10.8", "nbformat>=5.10.4", - "numpy>=2.2.5", - "opencv-python>=4.11.0.86", + "numpy>=2.4.2", + "opencv-python>=4.13.0.92", "openpyxl>=3.1.5", - "pandas>=2.2.3", + "pandas>=3.0.1", "pandas-stubs>=3.0.0.260204", "pettingzoo[atari]>=1.25.0", - "plotly>=6.3.0", - "polars>=1.37.0", + "plotly>=6.6.0", + "polars>=1.38.1", "pyclustertend>=1.9.0", "pypdf>=6.7.5", - "rasterio>=1.4.4", + "rasterio>=1.5.0", "requests>=2.32.5", - "scikit-learn>=1.6.1", - "scipy>=1.15.2", + "scikit-learn>=1.8.0", + "scipy>=1.17.1", "seaborn>=0.13.2", "sentence-transformers>=5.2.3", # "sequenzo>=0.1.20", - "shap>=0.49.1", + "shap>=0.50.0", "spacy>=3.8.11", "statsmodels>=0.14.6", "supersuit>=3.10.0", @@ -65,10 +65,7 @@ dependencies = [ ] [dependency-groups] -dev = [ - "ipykernel>=7.2.0", - "uv>=0.10.7", -] +dev = ["ipykernel>=7.2.0", "uv>=0.10.7"] [tool.ty.rules] index-out-of-bounds = "ignore" @@ -83,18 +80,18 @@ select = ["ALL"] # Désactiver certaines règles ignore = [ - "E501", # line too long, géré par le formatter - "E402", # Imports in top of file - "T201", # Print + "E501", # line too long, géré par le formatter + "E402", # Imports in top of file + "T201", # Print "N806", "N803", "N802", "N816", "PLR0913", # Too many arguments "PLR2004", # - "FBT001", # Boolean positional argument used in function definition + "FBT001", # Boolean positional argument used in function definition "EM101", - "TRY003", # Avoid specifying long messages outside the exception class + "TRY003", # Avoid specifying long messages outside the exception class ] # Exclure certains fichiers ou répertoires @@ -127,9 +124,25 @@ unfixable = [] [tool.ruff.lint.isort] # Regrouper les imports par thématiques -section-order = ["future", "standard-library", "third-party", "data-science", "ml", "first-party", "local-folder"] +section-order = [ + "future", + "standard-library", + "third-party", + "data-science", + "ml", + "first-party", + "local-folder", +] [tool.ruff.lint.isort.sections] # On sépare les outils de manipulation de données des frameworks de ML lourds "data-science" = ["numpy", "pandas", "scipy", "matplotlib", "seaborn", "plotly"] -"ml" = ["tensorflow", "keras", "torch", "sklearn", "xgboost", "catboost", "shap"] +"ml" = [ + "tensorflow", + "keras", + "torch", + "sklearn", + "xgboost", + "catboost", + "shap", +]