mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-01-25 09:52:48 +01:00
Refactor code for improved readability and consistency across multiple Jupyter notebooks
- Added missing commas in various print statements and function calls for better syntax. - Reformatted code to enhance clarity, including breaking long lines and aligning parameters. - Updated function signatures to use float type for sigma parameters instead of int for better precision. - Cleaned up comments and documentation strings for clarity and consistency. - Ensured consistent formatting in plotting functions and data handling.
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from sklearn.cluster import KMeans"
|
||||
]
|
||||
},
|
||||
@@ -77,11 +78,11 @@
|
||||
" return eigenvalues[idx_L], eigenvectors[:, idx_L]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def compute_W_matrix(sigma: int, X: np.ndarray) -> np.ndarray:\n",
|
||||
"def compute_W_matrix(sigma: float, X: np.ndarray) -> np.ndarray:\n",
|
||||
" \"\"\"Fill the similarity matrix W.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" sigma (int): Parameter for the Gaussian kernel.\n",
|
||||
" sigma (float): Parameter for the Gaussian kernel.\n",
|
||||
" X (np.ndarray): Input data.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
@@ -94,7 +95,7 @@
|
||||
" W[i, j] = (\n",
|
||||
" 0 if i == j else np.exp(-(np.abs(X[i] - X[j]) ** 2) / (2 * sigma**2))\n",
|
||||
" )\n",
|
||||
" return W\n"
|
||||
" return W"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -104,11 +105,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_X(sigma: int, n: int, m: int) -> np.ndarray:\n",
|
||||
"def create_X(sigma: float, n: int, m: int) -> np.ndarray:\n",
|
||||
" \"\"\"Create a dataset with 4 Gaussian clusters.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" sigma (int): Standard deviation of the clusters.\n",
|
||||
" sigma (float): Standard deviation of the clusters.\n",
|
||||
" n (int): Total number of data points.\n",
|
||||
" m (int): Number of clusters.\n",
|
||||
"\n",
|
||||
@@ -125,11 +126,11 @@
|
||||
" return np.concatenate([norm_1, norm_2, norm_3, norm_4])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def plot_eigenvalues(sigma: int, n: int, m: int, k: int) -> None:\n",
|
||||
"def plot_eigenvalues(sigma: float, n: int, m: int, k: int) -> None:\n",
|
||||
" \"\"\"Plot the eigenvalues of the Laplacian for different sigma values.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" sigma (int): Standard deviation of the clusters.\n",
|
||||
" sigma (float): Standard deviation of the clusters.\n",
|
||||
" n (int): Total number of data points.\n",
|
||||
" m (int): Number of clusters.\n",
|
||||
" k (int): Number of eigenvalues to compute.\n",
|
||||
@@ -153,7 +154,7 @@
|
||||
" plt.xlabel(\"Index\")\n",
|
||||
" plt.ylabel(\"Eigenvalue\")\n",
|
||||
"\n",
|
||||
" plt.show()\n"
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -236,7 +237,7 @@
|
||||
"n = 200\n",
|
||||
"m = 4\n",
|
||||
"for sigma in [0.1, 1 / 4, 0.5, 1]:\n",
|
||||
" plot_eigenvalues(sigma, n=n, m=m, k=k)\n"
|
||||
" plot_eigenvalues(sigma, n=n, m=m, k=k)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -246,14 +247,14 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def spectral_clustering(sigma, n: int, m: int, k: int) -> np.ndarray:\n",
|
||||
"def spectral_clustering(sigma: float, n: int, m: int, k: int) -> np.ndarray:\n",
|
||||
" \"\"\"Perform spectral clustering on the data.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" sigma: The sigma value for the similarity matrix.\n",
|
||||
" n: Number of data points.\n",
|
||||
" m: Number of clusters.\n",
|
||||
" k: Number of eigenvectors to use.\n",
|
||||
" sigma (float): The sigma value for the similarity matrix.\n",
|
||||
" n (int): Number of data points.\n",
|
||||
" m (int): Number of clusters.\n",
|
||||
" k (int): Number of eigenvectors to use.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" X: The data points.\n",
|
||||
@@ -269,7 +270,7 @@
|
||||
"\n",
|
||||
" kmeans = KMeans(n_clusters=k)\n",
|
||||
"\n",
|
||||
" return X, kmeans.fit_predict(U_normalized)\n"
|
||||
" return X, kmeans.fit_predict(U_normalized)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -312,7 +313,7 @@
|
||||
"\n",
|
||||
"plt.scatter(X, np.zeros_like(X), c=clusters, cmap=\"magma\")\n",
|
||||
"plt.title(\"Spectral Clustering Results\")\n",
|
||||
"plt.xlabel(\"Data Points\")\n"
|
||||
"plt.xlabel(\"Data Points\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user