From f505977c2d40fe501a837f3f9137082a5b106141 Mon Sep 17 00:00:00 2001 From: Arthur DANJOU Date: Wed, 18 Feb 2026 18:25:41 +0100 Subject: [PATCH] =?UTF-8?q?R=C3=A9organiser=20le=20code=20pour=20d=C3=A9fi?= =?UTF-8?q?nir=20la=20classe=20Agent=20et=20am=C3=A9liorer=20la=20structur?= =?UTF-8?q?e=20du=20projet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../project/Project.ipynb | 77 +++++++++++-------- 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project.ipynb index 615e96f..c38b0d1 100644 --- a/M2/Reinforcement Learning/project/Project.ipynb +++ b/M2/Reinforcement Learning/project/Project.ipynb @@ -2,13 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "7e37429a", "metadata": {}, "outputs": [], "source": [ - "import ale_py\n", - "import gymnasium as gym\n" + "import torch\n", + "import numpy as np\n", + "import pickle\n", + "from pathlib import Path\n", + "\n" ] }, { @@ -16,35 +19,40 @@ "execution_count": null, "id": "85ff0eb4", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)\n", - "[Powered by Stella]\n" - ] - }, - { - "ename": "", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31mLe noyau s’est bloqué lors de l’exécution du code dans une cellule active ou une cellule précédente. \n", - "\u001b[1;31mVeuillez vérifier le code dans la ou les cellules pour identifier une cause possible de l’échec. \n", - "\u001b[1;31mCliquez ici pour plus d’informations. \n", - "\u001b[1;31mPour plus d’informations, consultez Jupyter log." 
class Agent:
    """Base class for reinforcement learning agents.

    Parameters
    ----------
    action_space:
        The environment's action space. Agents that sample actions
        (e.g. ``RandomAgent``) call ``action_space.sample()``, so this
        must expose a ``sample()`` method, as gymnasium ``Space`` objects
        do. (The previous ``int`` annotation was incorrect: an ``int``
        has no ``sample()`` method.)
    """

    def __init__(self, action_space) -> None:
        """Store the environment's action space on the agent."""
        self.action_space = action_space

    def get_action(self, observation: np.ndarray, epsilon: float = 0.0):
        """Select an action for ``observation``; subclasses must override."""
        raise NotImplementedError

    def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:
        """Update the agent from one ``(s, a, r, s', done)`` transition.

        The base implementation is a no-op so purely reactive agents
        (such as ``RandomAgent``) need not override it.
        """

    def save(self, filename: str) -> None:
        """Pickle the agent's attribute dict to ``filename``."""
        with Path(filename).open("wb") as f:
            pickle.dump(self.__dict__, f)

    def load(self, filename: str) -> None:
        """Restore the agent's attributes from a pickle written by ``save``.

        Only load files you trust: unpickling can execute arbitrary code.
        """
        with Path(filename).open("rb") as f:
            self.__dict__.update(pickle.load(f))  # noqa: S301


class RandomAgent(Agent):
    """Agent that ignores the observation and samples a random action."""

    def get_action(self, observation, epsilon: float = 0.0):
        """Return a uniformly random action drawn from the action space."""
        return self.action_space.sample()