diff --git a/M2/Reinforcement Learning/project/Project.ipynb b/M2/Reinforcement Learning/project/Project.ipynb
index 615e96f..c38b0d1 100644
--- a/M2/Reinforcement Learning/project/Project.ipynb
+++ b/M2/Reinforcement Learning/project/Project.ipynb
@@ -2,13 +2,16 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"id": "7e37429a",
"metadata": {},
"outputs": [],
"source": [
- "import ale_py\n",
- "import gymnasium as gym\n"
+ "import torch\n",
+ "import numpy as np\n",
+ "import pickle\n",
+ "from pathlib import Path\n",
+ "\n"
]
},
{
@@ -16,35 +19,40 @@
"execution_count": null,
"id": "85ff0eb4",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)\n",
- "[Powered by Stella]\n"
- ]
- },
- {
- "ename": "",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[1;31mLe noyau s’est bloqué lors de l’exécution du code dans une cellule active ou une cellule précédente. \n",
- "\u001b[1;31mVeuillez vérifier le code dans la ou les cellules pour identifier une cause possible de l’échec. \n",
- "\u001b[1;31mCliquez ici pour plus d’informations. \n",
- "\u001b[1;31mPour plus d’informations, consultez Jupyter log."
- ]
- }
- ],
+ "outputs": [],
"source": [
- "gym.register_envs(ale_py)\n",
+ "class Agent:\n",
+ " \"\"\"Base class for reinforcement learning agents.\"\"\"\n",
"\n",
- "env = gym.make(\"ALE/Tennis-v5\", render_mode=\"human\")\n",
- "obs, info = env.reset()\n",
- "obs, reward, terminated, truncated, info = env.step(env.action_space.sample())\n",
- "env.close()\n",
- "\n"
+    "    def __init__(self, action_space) -> None:\n",
+    "        \"\"\"Initialize the agent with its action space (gym-style Space; must support .sample()).\"\"\"\n",
+    "        self.action_space = action_space\n",
+    "\n",
+    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0):\n",
+    "        \"\"\"Select an action based on the current observation.\"\"\"\n",
+    "        raise NotImplementedError\n",
+    "\n",
+    "    def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):\n",
+    "        \"\"\"Update the agent's knowledge based on the experience tuple.\"\"\"\n",
+    "        pass\n",
+    "\n",
+    "    def save(self, filename: str) -> None:\n",
+    "        \"\"\"Save the agent's state to a file.\"\"\"\n",
+    "        with Path(filename).open(\"wb\") as f:\n",
+    "            pickle.dump(self.__dict__, f)\n",
+    "\n",
+    "    def load(self, filename: str) -> None:\n",
+    "        \"\"\"Load the agent's state from a file (trusted files only: pickle can execute code).\"\"\"\n",
+    "        with Path(filename).open(\"rb\") as f:\n",
+    "            self.__dict__.update(pickle.load(f))  # noqa: S301\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2459be52",
+ "metadata": {},
+ "source": [
+ "## Random Agent"
]
},
{
@@ -53,12 +61,17 @@
"id": "89633751",
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+    "class RandomAgent(Agent):\n",
+    "    \"\"\"Agent that ignores observation and epsilon and samples a uniformly random action.\"\"\"\n",
+    "    def get_action(self, observation: np.ndarray, epsilon: float = 0.0):\n",
+    "        return self.action_space.sample()\n"
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "studies (3.13.9)",
"language": "python",
"name": "python3"
},