Refactor code for improved readability and consistency across multiple Jupyter notebooks

- Added trailing commas to multi-line print statements and function calls for consistent formatting.
- Reformatted code to enhance clarity, breaking long lines and putting one argument per line (a minimal before/after sketch follows this list).
- Updated function signatures to type sigma parameters as float instead of int for better precision.
- Cleaned up comments and documentation strings for clarity and consistency.
- Ensured consistent formatting in plotting functions and data handling.
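For illustration, here is a minimal before/after sketch of the conventions applied: float-typed sigma, one argument per line, and trailing commas. The gaussian_kernel function and its parameters are hypothetical, not taken from the notebooks.

import numpy as np

# Before: sigma typed as int, arguments crammed onto one line
def gaussian_kernel_old(size: int, sigma: int = 1):
    """Return a 1-D Gaussian kernel (hypothetical example)."""
    x = np.arange(size) - (size - 1) / 2
    return np.exp(-x**2 / (2 * sigma**2))

# After: sigma typed as float for precision, one argument per line,
# trailing commas so future diffs only touch the changed argument
def gaussian_kernel(
    size: int,
    sigma: float = 1.0,
):
    """Return a 1-D Gaussian kernel (hypothetical example)."""
    x = np.arange(size) - (size - 1) / 2
    return np.exp(-x**2 / (2 * sigma**2))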
2025-12-13 23:38:17 +01:00
parent f89ff4a016
commit d5a6bfd339
50 changed files with 779 additions and 449 deletions


@@ -23,11 +23,12 @@
"import numpy as np\n",
"\n",
"np.set_printoptions(\n",
" precision=3, suppress=True\n",
" precision=3,\n",
" suppress=True,\n",
") # (not mandatory) This line is for limiting floats to 3 decimal places, avoiding scientific notation (like 1.23e-04) for small numbers.\n",
"\n",
"# For reproducibility\n",
"rng = np.random.default_rng(seed=42) # This line creates a random number generator.\n"
"rng = np.random.default_rng(seed=42) # This line creates a random number generator."
]
},
{
@@ -102,7 +103,7 @@
" \"S\",\n",
" \"G\",\n",
" \"X\",\n",
"} # The vector Free represents cells that the agent is allowed to move into.\n"
"} # The vector Free represents cells that the agent is allowed to move into."
]
},
{
@@ -164,7 +165,7 @@
"print(\"Number of states (non-wall cells):\", n_states)\n",
"print(\"Start state:\", start_state, \"at\", state_to_pos[start_state])\n",
"print(\"Goal states:\", goal_states, \"at\", state_to_pos[goal_states[0]])\n",
"print(\"Trap states:\", trap_states, \"at\", state_to_pos[trap_states[0]])\n"
"print(\"Trap states:\", trap_states, \"at\", state_to_pos[trap_states[0]])"
]
},
{
@@ -188,7 +189,7 @@
"def plot_maze_with_states():\n",
" \"\"\"Plot the maze with state indices.\"\"\"\n",
" grid = np.ones(\n",
" (n_rows, n_cols)\n",
" (n_rows, n_cols),\n",
" ) # Start with a matrix of ones. Here 1 means “free cell”\n",
" for i in range(n_rows):\n",
" for j in range(n_cols):\n",
@@ -316,7 +317,7 @@
" # If the next cell is a wall, the robot stays in place.\n",
" return i, j\n",
"\n",
" return candidate_i, candidate_j # Otherwise, return the new position\n"
" return candidate_i, candidate_j # Otherwise, return the new position"
]
},
{
@@ -335,7 +336,7 @@
"outputs": [],
"source": [
"gamma = 0.95\n",
"p_error = 0.1 # probability of the error to a random other direction\n"
"p_error = 0.1 # probability of the error to a random other direction"
]
},
{
@@ -360,7 +361,7 @@
"# Set rewards for each state\n",
"step_penalty = -0.01\n",
"goal_reward = 1.0\n",
"trap_reward = -1.0\n"
"trap_reward = -1.0"
]
},
{
@@ -376,7 +377,7 @@
" elif s in trap_states:\n",
" R[s] = trap_reward\n",
" else:\n",
" R[s] = step_penalty\n"
" R[s] = step_penalty"
]
},
{
@@ -391,7 +392,7 @@
"\n",
"def is_terminal(s: int) -> bool:\n",
" \"\"\"Check if a state is terminal (goal or trap).\"\"\"\n",
" return s in terminal_states\n"
" return s in terminal_states"
]
},
{
@@ -437,9 +438,9 @@
" error_i, error_j = move_deterministic(i, j, a2)\n",
" s_error = pos_to_state[(error_i, error_j)] # get its state index s_error\n",
" P[a, s, s_error] += p_error / len(\n",
" other_actions\n",
" other_actions,\n",
" ) # add p_error / 3 to P[a, s, s_error]\n",
"# So for each (s,a), probabilities over all s_next sum to 1.\n"
"# So for each (s,a), probabilities over all s_next sum to 1."
]
},
{
@@ -476,7 +477,7 @@
" # If everything is correct, they should be very close to 1.\n",
"\n",
" probs = P[a].sum(axis=1)\n",
" print(f\"Action {action_names[a]}:\", probs)\n"
" print(f\"Action {action_names[a]}:\", probs)"
]
},
{
@@ -520,7 +521,7 @@
"\n",
" for _it in range(max_iter): # Main iterative loop\n",
" V_new = np.zeros_like(\n",
" V\n",
" V,\n",
" ) # Create a new value vector and we will compute an updated value for each state.\n",
"\n",
" # Now we update each state using the Bellman expectation equation\n",
@@ -529,7 +530,7 @@
" V_new[s] = R[s] + gamma * np.sum(P[a, s, :] * V)\n",
"\n",
" delta = np.max(\n",
" np.abs(V_new - V)\n",
" np.abs(V_new - V),\n",
" ) # This measures how much the value function changed in this iteration:\n",
" # If delta is small, the values start to converge; otherwise, we need to keep iterating.\n",
" V = V_new # Update V, i.e. Set the new values for the next iteration.\n",
@@ -537,7 +538,7 @@
" if delta < theta: # Check convergence: When changes are tiny, we stop.\n",
" break\n",
"\n",
" return V # Return the final value function, this is our estimate for V^{pi}(s), s in the state set.\n"
" return V # Return the final value function, this is our estimate for V^{pi}(s), s in the state set."
]
},
{
@@ -550,7 +551,8 @@
"def plot_values(V: np.ndarray, title=\"Value function\") -> None:\n",
" \"\"\"Plot the value function V on the maze as a heatmap.\"\"\"\n",
" grid_values = np.full(\n",
" (n_rows, n_cols), np.nan\n",
" (n_rows, n_cols),\n",
" np.nan,\n",
" ) # Initializes a grid the same size as the maze. Every cell starts as NaN.\n",
" for (\n",
" s,\n",
@@ -575,14 +577,20 @@
"\n",
" for s, (i, j) in state_to_pos.items():\n",
" ax.text(\n",
" j, i, f\"{V[s]:.2f}\", ha=\"center\", va=\"center\", color=\"white\", fontsize=9\n",
" j,\n",
" i,\n",
" f\"{V[s]:.2f}\",\n",
" ha=\"center\",\n",
" va=\"center\",\n",
" color=\"white\",\n",
" fontsize=9,\n",
" )\n",
"\n",
" # Remove axis ticks and set title\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
" ax.set_title(title)\n",
" plt.show()\n"
" plt.show()"
]
},
{
@@ -659,7 +667,7 @@
" ax.set_yticklabels([])\n",
" ax.grid(True)\n",
" ax.set_title(title)\n",
" plt.show()\n"
" plt.show()"
]
},
{
@@ -716,7 +724,7 @@
"source": [
"V_random = policy_evaluation(policy=random_policy, P=P, R=R, gamma=gamma)\n",
"print(\"Value function under random policy:\")\n",
"print(V_random)\n"
"print(V_random)"
]
},
{
@@ -748,7 +756,7 @@
],
"source": [
"plot_values(V_random, title=\"Value function: random policy\")\n",
"plot_policy(policy=random_policy, title=\"Random Policy\")\n"
"plot_policy(policy=random_policy, title=\"Random Policy\")"
]
},
{
@@ -847,7 +855,7 @@
"V_my_policy = policy_evaluation(policy=my_policy, P=P, R=R, gamma=gamma)\n",
"\n",
"plot_values(V=V_my_policy, title=\"Value function: my policy\")\n",
"plot_policy(policy=my_policy, title=\"My policy\")\n"
"plot_policy(policy=my_policy, title=\"My policy\")"
]
},
{