Mirror of https://github.com/ArthurDanjou/ArtStudies.git (synced 2026-01-14 15:54:13 +01:00)
Refactor code for improved readability and consistency across multiple Jupyter notebooks
- Added missing commas in various print statements and function calls for better syntax.
- Reformatted code to enhance clarity, including breaking long lines and aligning parameters.
- Updated function signatures to use float type for sigma parameters instead of int for better precision.
- Cleaned up comments and documentation strings for clarity and consistency.
- Ensured consistent formatting in plotting functions and data handling.
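As a concrete illustration of the sigma change described above: the affected function is not part of this excerpt, so the name gaussian_kernel and its body below are hypothetical, a minimal sketch of the before/after style.

    import numpy as np

    # Before: sigma annotated as int, arguments packed on one line
    def gaussian_kernel(x: np.ndarray, sigma: int = 1) -> np.ndarray:
        return np.exp(-(x**2) / (2 * sigma**2))

    # After: sigma annotated as float for better precision, parameters
    # split one per line with trailing commas, as applied across this commit
    def gaussian_kernel(
        x: np.ndarray,
        sigma: float = 1.0,
    ) -> np.ndarray:
        return np.exp(-(x**2) / (2 * sigma**2))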
@@ -23,11 +23,12 @@
 "import numpy as np\n",
 "\n",
 "np.set_printoptions(\n",
-"    precision=3, suppress=True\n",
+"    precision=3,\n",
+"    suppress=True,\n",
 ")  # (not mandatory) This line is for limiting floats to 3 decimal places, avoiding scientific notation (like 1.23e-04) for small numbers.\n",
 "\n",
 "# For reproducibility\n",
-"rng = np.random.default_rng(seed=42)  # This line creates a random number generator.\n"
+"rng = np.random.default_rng(seed=42)  # This line creates a random number generator."
 ]
},
{
@@ -102,7 +103,7 @@
 "    \"S\",\n",
 "    \"G\",\n",
 "    \"X\",\n",
-"}  # The vector Free represents cells that the agent is allowed to move into.\n"
+"}  # The vector Free represents cells that the agent is allowed to move into."
 ]
},
{
@@ -164,7 +165,7 @@
 "print(\"Number of states (non-wall cells):\", n_states)\n",
 "print(\"Start state:\", start_state, \"at\", state_to_pos[start_state])\n",
 "print(\"Goal states:\", goal_states, \"at\", state_to_pos[goal_states[0]])\n",
-"print(\"Trap states:\", trap_states, \"at\", state_to_pos[trap_states[0]])\n"
+"print(\"Trap states:\", trap_states, \"at\", state_to_pos[trap_states[0]])"
 ]
},
{
@@ -188,7 +189,7 @@
 "def plot_maze_with_states():\n",
 "    \"\"\"Plot the maze with state indices.\"\"\"\n",
 "    grid = np.ones(\n",
-"        (n_rows, n_cols)\n",
+"        (n_rows, n_cols),\n",
 "    )  # Start with a matrix of ones. Here 1 means “free cell”\n",
 "    for i in range(n_rows):\n",
 "        for j in range(n_cols):\n",
@@ -316,7 +317,7 @@
 "        # If the next cell is a wall, the robot stays in place.\n",
 "        return i, j\n",
 "\n",
-"    return candidate_i, candidate_j  # Otherwise, return the new position\n"
+"    return candidate_i, candidate_j  # Otherwise, return the new position"
 ]
},
{
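The hunk above shows only the tail of the notebook's move function. Below is a minimal self-contained sketch of the wall-collision logic it records, assuming a maze given as strings in which "#" marks walls; the grid and the helper are illustrative, not the notebook's exact code.

    maze = [
        "#####",
        "#S G#",
        "#####",
    ]

    def move_deterministic(i: int, j: int, a: int) -> tuple[int, int]:
        """Sketch: action a indexes the deltas for (up, down, left, right)."""
        deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        di, dj = deltas[a]
        candidate_i, candidate_j = i + di, j + dj
        if maze[candidate_i][candidate_j] == "#":
            # If the next cell is a wall, the robot stays in place.
            return i, j
        return candidate_i, candidate_j  # Otherwise, return the new position

For example, move_deterministic(1, 1, 0) tries to move up into a wall and returns (1, 1) unchanged.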
@@ -335,7 +336,7 @@
 "outputs": [],
 "source": [
 "gamma = 0.95\n",
-"p_error = 0.1  # probability of the error to a random other direction\n"
+"p_error = 0.1  # probability of the error to a random other direction"
 ]
},
{
@@ -360,7 +361,7 @@
 "# Set rewards for each state\n",
 "step_penalty = -0.01\n",
 "goal_reward = 1.0\n",
-"trap_reward = -1.0\n"
+"trap_reward = -1.0"
 ]
},
{
@@ -376,7 +377,7 @@
 "    elif s in trap_states:\n",
 "        R[s] = trap_reward\n",
 "    else:\n",
-"        R[s] = step_penalty\n"
+"        R[s] = step_penalty"
 ]
},
{
@@ -391,7 +392,7 @@
 "\n",
 "def is_terminal(s: int) -> bool:\n",
 "    \"\"\"Check if a state is terminal (goal or trap).\"\"\"\n",
-"    return s in terminal_states\n"
+"    return s in terminal_states"
 ]
},
{
@@ -437,9 +438,9 @@
 "            error_i, error_j = move_deterministic(i, j, a2)\n",
 "            s_error = pos_to_state[(error_i, error_j)]  # get its state index s_error\n",
 "            P[a, s, s_error] += p_error / len(\n",
-"                other_actions\n",
+"                other_actions,\n",
 "            )  # add p_error / 3 to P[a, s, s_error]\n",
-"# So for each (s,a), probabilities over all s_next sum to 1.\n"
+"# So for each (s,a), probabilities over all s_next sum to 1."
 ]
},
{
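A standalone sketch of the noise model this hunk implements: the intended action keeps probability 1 - p_error and each of the other three actions receives p_error / 3, so every (s, a) row of P is a valid distribution. The numbers below are illustrative.

    import numpy as np

    p_error = 0.1
    n_actions = 4
    intended = 2  # say the agent chose action 2

    probs = np.full(n_actions, p_error / (n_actions - 1))  # p_error / 3 for each wrong move
    probs[intended] = 1.0 - p_error  # the intended move keeps the rest
    assert np.isclose(probs.sum(), 1.0)  # the row sums to one
    print(probs)  # -> [0.0333... 0.0333... 0.9 0.0333...]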
@@ -476,7 +477,7 @@
 "    # If everything is correct, they should be very close to 1.\n",
 "\n",
 "    probs = P[a].sum(axis=1)\n",
-"    print(f\"Action {action_names[a]}:\", probs)\n"
+"    print(f\"Action {action_names[a]}:\", probs)"
 ]
},
{
@@ -520,7 +521,7 @@
 "    for _it in range(max_iter):  # Main iterative loop\n",
 "        V_new = np.zeros_like(\n",
-"            V\n",
+"            V,\n",
 "        )  # Create a new value vector and we will compute an updated value for each state.\n",
 "\n",
 "        # Now we update each state using the Bellman expectation equation\n",
@@ -529,7 +530,7 @@
 "            V_new[s] = R[s] + gamma * np.sum(P[a, s, :] * V)\n",
 "\n",
 "        delta = np.max(\n",
-"            np.abs(V_new - V)\n",
+"            np.abs(V_new - V),\n",
 "        )  # This measures how much the value function changed in this iteration:\n",
 "        # If delta is small, the values start to converge; otherwise, we need to keep iterating.\n",
 "        V = V_new  # Update V, i.e. Set the new values for the next iteration.\n",
@@ -537,7 +538,7 @@
 "        if delta < theta:  # Check convergence: When changes are tiny, we stop.\n",
 "            break\n",
 "\n",
-"    return V  # Return the final value function, this is our estimate for V^{pi}(s), s in the state set.\n"
+"    return V  # Return the final value function, this is our estimate for V^{pi}(s), s in the state set."
 ]
},
{
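Pieced together from the three hunks above, the evaluation loop looks roughly as sketched below; the theta and max_iter defaults are assumptions, and the update is the Bellman expectation equation V(s) <- R(s) + gamma * sum_{s'} P(s' | s, a) V(s') with a = policy(s).

    import numpy as np

    def policy_evaluation(policy, P, R, gamma, theta=1e-8, max_iter=10_000):
        """Iterative policy evaluation, reconstructed as a sketch from the diff."""
        V = np.zeros(R.shape[0])
        for _it in range(max_iter):
            V_new = np.zeros_like(V)
            for s in range(len(V)):
                a = policy[s]  # deterministic policy: one action per state
                V_new[s] = R[s] + gamma * np.sum(P[a, s, :] * V)
            delta = np.max(np.abs(V_new - V))  # largest change in this sweep
            V = V_new
            if delta < theta:  # values have converged, stop early
                break
        return V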
@@ -550,7 +551,8 @@
 "def plot_values(V: np.ndarray, title=\"Value function\") -> None:\n",
 "    \"\"\"Plot the value function V on the maze as a heatmap.\"\"\"\n",
 "    grid_values = np.full(\n",
-"        (n_rows, n_cols), np.nan\n",
+"        (n_rows, n_cols),\n",
+"        np.nan,\n",
 "    )  # Initializes a grid the same size as the maze. Every cell starts as NaN.\n",
 "    for (\n",
 "        s,\n",
@@ -575,14 +577,20 @@
 "\n",
 "    for s, (i, j) in state_to_pos.items():\n",
 "        ax.text(\n",
-"            j, i, f\"{V[s]:.2f}\", ha=\"center\", va=\"center\", color=\"white\", fontsize=9\n",
+"            j,\n",
+"            i,\n",
+"            f\"{V[s]:.2f}\",\n",
+"            ha=\"center\",\n",
+"            va=\"center\",\n",
+"            color=\"white\",\n",
+"            fontsize=9,\n",
 "        )\n",
 "\n",
 "    # Remove axis ticks and set title\n",
 "    ax.set_xticks([])\n",
 "    ax.set_yticks([])\n",
 "    ax.set_title(title)\n",
-"    plt.show()\n"
+"    plt.show()"
 ]
},
{
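The two plot_values hunks above show only fragments of the heatmap code. A minimal standalone version in the same spirit, assuming matplotlib; the grid shape and values are made up for illustration.

    import matplotlib.pyplot as plt
    import numpy as np

    grid_values = np.full((3, 5), np.nan)  # NaN marks wall cells, drawn blank
    grid_values[1, 1:4] = [0.2, 0.5, 1.0]  # made-up values for three free cells

    fig, ax = plt.subplots()
    ax.imshow(grid_values, cmap="viridis")
    for (i, j), v in np.ndenumerate(grid_values):
        if not np.isnan(v):
            ax.text(j, i, f"{v:.2f}", ha="center", va="center", color="white", fontsize=9)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("Value function (illustrative)")
    plt.show()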
@@ -659,7 +667,7 @@
 "    ax.set_yticklabels([])\n",
 "    ax.grid(True)\n",
 "    ax.set_title(title)\n",
-"    plt.show()\n"
+"    plt.show()"
 ]
},
{
@@ -716,7 +724,7 @@
 "source": [
 "V_random = policy_evaluation(policy=random_policy, P=P, R=R, gamma=gamma)\n",
 "print(\"Value function under random policy:\")\n",
-"print(V_random)\n"
+"print(V_random)"
 ]
},
{
@@ -748,7 +756,7 @@
 ],
 "source": [
 "plot_values(V_random, title=\"Value function: random policy\")\n",
-"plot_policy(policy=random_policy, title=\"Random Policy\")\n"
+"plot_policy(policy=random_policy, title=\"Random Policy\")"
 ]
},
{
@@ -847,7 +855,7 @@
 "V_my_policy = policy_evaluation(policy=my_policy, P=P, R=R, gamma=gamma)\n",
 "\n",
 "plot_values(V=V_my_policy, title=\"Value function: my policy\")\n",
-"plot_policy(policy=my_policy, title=\"My policy\")\n"
+"plot_policy(policy=my_policy, title=\"My policy\")"
 ]
},
{