diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project.ipynb index 615e96f..c38b0d1 100644 --- a/M2/Reinforcement Learning/project/Project.ipynb +++ b/M2/Reinforcement Learning/project/Project.ipynb @@ -2,13 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "7e37429a", "metadata": {}, "outputs": [], "source": [ - "import ale_py\n", - "import gymnasium as gym\n" + "import torch\n", + "import numpy as np\n", + "import pickle\n", + "from pathlib import Path\n", + "\n" ] }, { @@ -16,35 +19,40 @@ "execution_count": null, "id": "85ff0eb4", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)\n", - "[Powered by Stella]\n" - ] - }, - { - "ename": "", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31mLe noyau s’est bloqué lors de l’exécution du code dans une cellule active ou une cellule précédente. \n", - "\u001b[1;31mVeuillez vérifier le code dans la ou les cellules pour identifier une cause possible de l’échec. \n", - "\u001b[1;31mCliquez ici pour plus d’informations. \n", - "\u001b[1;31mPour plus d’informations, consultez Jupyter log." 
- ] - } - ], + "outputs": [], "source": [ - "gym.register_envs(ale_py)\n", + "class Agent:\n", + " \"\"\"Base class for reinforcement learning agents.\"\"\"\n", "\n", - "env = gym.make(\"ALE/Tennis-v5\", render_mode=\"human\")\n", - "obs, info = env.reset()\n", - "obs, reward, terminated, truncated, info = env.step(env.action_space.sample())\n", - "env.close()\n", - "\n" + " def __init__(self, action_space) -> None:\n", + " \"\"\"Initialize the agent.\"\"\"\n", + " self.action_space = action_space\n", + "\n", + " def get_action(self, observation: np.ndarray, epsilon: float = 0.0):\n", + " \"\"\"Select an action based on the current observation.\"\"\"\n", + " raise NotImplementedError\n", + "\n", + " def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):\n", + " \"\"\"Update the agent's knowledge based on the experience tuple.\"\"\"\n", + " pass\n", + "\n", + " def save(self, filename: str) -> None:\n", + " \"\"\"Save the agent's state to a file.\"\"\"\n", + " with Path(filename).open(\"wb\") as f:\n", + " pickle.dump(self.__dict__, f)\n", + "\n", + " def load(self, filename: str) -> None:\n", + " \"\"\"Load the agent's state from a file.\"\"\"\n", + " with Path(filename).open(\"rb\") as f:\n", + " self.__dict__.update(pickle.load(f)) # noqa: S301\n" ] }, { "cell_type": "markdown", "id": "2459be52", "metadata": {}, "source": [ "## Random Agent" ] }, { "cell_type": "code", @@ -53,12 +61,17 @@ "id": "89633751", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "class RandomAgent(Agent):\n", + " \"\"\"A simple agent that selects actions randomly.\"\"\"\n", + " def get_action(self, observation, epsilon=0.0):\n", + " return self.action_space.sample()\n" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "studies (3.13.9)", "language": "python", "name": "python3" },