From 84c512121d9d9a0acd76111e954522471ec83003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Tue, 23 Nov 2021 18:21:10 +1300 Subject: [PATCH] Use @ operator, fix animation (shorten, improve code, and zoom) --- extra_gradient_descent_comparison.ipynb | 111 +++++++++++++----------- 1 file changed, 60 insertions(+), 51 deletions(-) diff --git a/extra_gradient_descent_comparison.ipynb b/extra_gradient_descent_comparison.ipynb index a84570e..da344aa 100644 --- a/extra_gradient_descent_comparison.ipynb +++ b/extra_gradient_descent_comparison.ipynb @@ -20,10 +20,10 @@ "source": [ "\n", " \n", " \n", "
\n", - " \"Open\n", + " \"Open\n", " \n", - " \n", + " \n", "
" ] @@ -34,11 +34,11 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "\n", - "%matplotlib nbagg\n", + "import matplotlib\n", "import matplotlib.pyplot as plt\n", - "from matplotlib.animation import FuncAnimation" + "from matplotlib.animation import FuncAnimation\n", + "\n", + "matplotlib.rc('animation', html='jshtml')" ] }, { @@ -47,10 +47,12 @@ "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", + "\n", "m = 100\n", - "X = 2*np.random.rand(m, 1)\n", + "X = 2 * np.random.rand(m, 1)\n", "X_b = np.c_[np.ones((m, 1)), X]\n", - "y = 4 + 3*X + np.random.rand(m, 1)" + "y = 4 + 3 * X + np.random.rand(m, 1)" ] }, { @@ -65,8 +67,8 @@ " thetas = np.random.randn(2, 1)\n", " thetas_path = [thetas]\n", " for i in range(n_iterations):\n", - " gradients = 2*X_b.T.dot(X_b.dot(thetas) - y)/m\n", - " thetas = thetas - learning_rate*gradients\n", + " gradients = 2 * X_b.T @ (X_b @ thetas - y) / m\n", + " thetas = thetas - learning_rate * gradients\n", " thetas_path.append(thetas)\n", "\n", " return thetas_path" @@ -88,9 +90,9 @@ " random_index = np.random.randint(m)\n", " xi = X_b[random_index:random_index+1]\n", " yi = y[random_index:random_index+1]\n", - " gradients = 2*xi.T.dot(xi.dot(thetas) - yi)\n", - " eta = learning_schedule(epoch*m + i, t0, t1)\n", - " thetas = thetas - eta*gradients\n", + " gradients = 2 * xi.T @ (xi @ thetas - yi)\n", + " eta = learning_schedule(epoch * m + i, t0, t1)\n", + " thetas = thetas - eta * gradients\n", " thetas_path.append(thetas)\n", "\n", " return thetas_path" @@ -115,11 +117,11 @@ " y_shuffled = y[shuffled_indices]\n", " for i in range(0, m, minibatch_size):\n", " t += 1\n", - " xi = X_b_shuffled[i:i+minibatch_size]\n", - " yi = y_shuffled[i:i+minibatch_size]\n", - " gradients = 2*xi.T.dot(xi.dot(thetas) - yi)/minibatch_size\n", + " xi = X_b_shuffled[i : i + minibatch_size]\n", + " yi = y_shuffled[i : i + minibatch_size]\n", + " gradients = 2 * xi.T @ (xi @ thetas - yi) / minibatch_size\n", " eta = learning_schedule(t, t0, t1)\n", - " thetas = thetas - eta*gradients\n", + " thetas = thetas - eta * gradients\n", " thetas_path.append(thetas)\n", "\n", " return thetas_path" @@ -132,7 +134,7 @@ "outputs": [], "source": [ "def compute_mse(theta):\n", - " return np.sum((np.dot(X_b, theta) - y)**2)/m" + " return ((X_b @ theta - y) ** 2).sum() / m" ] }, { @@ -142,7 +144,7 @@ "outputs": [], "source": [ "def learning_schedule(t, t0, t1):\n", - " return t0/(t+t1)" + " return t0 / (t + t1)" ] }, { @@ -166,7 +168,7 @@ "metadata": {}, "outputs": [], "source": [ - "exact_solution = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)\n", + "exact_solution = np.linalg.inv(X_b.T @ X_b) @ X_b.T @ y\n", "bgd_thetas = np.array(batch_gradient_descent())\n", "sgd_thetas = np.array(stochastic_gradient_descent())\n", "mbgd_thetas = np.array(mini_batch_gradient_descent())" @@ -194,9 +196,38 @@ "data_ax = fig.add_subplot(121)\n", "cost_ax = fig.add_subplot(122)\n", "\n", + "data_ax.plot(X, y, 'k.')\n", + "\n", "cost_ax.plot(exact_solution[0,0], exact_solution[1,0], 'y*')\n", - "cost_img = cost_ax.pcolor(theta0, theta1, cost_map)\n", - "fig.colorbar(cost_img)" + "cost_ax.pcolor(theta0, theta1, cost_map, shading='auto')\n", + "\n", + "i = -1\n", + "[bgd_data_plot] = data_ax.plot(X, X_b @ bgd_thetas[i,:], 'r-')\n", + "[bgd_cost_plot] = cost_ax.plot(bgd_thetas[:i,0], bgd_thetas[:i,1], 'r--')\n", + "\n", + "[sgd_data_plot] = data_ax.plot(X, X_b @ sgd_thetas[i,:], 'g-')\n", + "[sgd_cost_plot] = cost_ax.plot(sgd_thetas[:i,0], sgd_thetas[:i,1], 'g--')\n", + "\n", + "[mbgd_data_plot] = data_ax.plot(X, X_b @ mbgd_thetas[i,:], 'b-')\n", + "[mbgd_cost_plot] = cost_ax.plot(mbgd_thetas[:i,0], mbgd_thetas[:i,1], 'b--')\n", + "\n", + "data_ax.set_xlim([0, 2])\n", + "data_ax.set_ylim([0, 15])\n", + "cost_ax.set_xlim([3, 5])\n", + "cost_ax.set_ylim([2, 5])\n", + "\n", + "data_ax.set_xlabel(r'$x_1$')\n", + "data_ax.set_ylabel(r'$y$', rotation=0)\n", + "cost_ax.set_xlabel(r'$\\theta_0$')\n", + "cost_ax.set_ylabel(r'$\\theta_1$')\n", + "\n", + "data_ax.legend(('Data', 'BGD', 'SGD', 'MBGD'), loc=\"upper left\")\n", + "cost_ax.legend(('Normal Equation', 'BGD', 'SGD', 'MBGD'), loc=\"upper left\")\n", + "\n", + "cost_ax.plot(exact_solution[0,0], exact_solution[1,0], 'y*')\n", + "cost_img = cost_ax.pcolor(theta0, theta1, cost_map, shading='auto')\n", + "fig.colorbar(cost_img)\n", + "plt.show()" ] }, { @@ -206,35 +237,14 @@ "outputs": [], "source": [ "def animate(i):\n", - " data_ax.cla()\n", - " cost_ax.cla()\n", + " bgd_data_plot.set_data(X, X_b @ bgd_thetas[i,:])\n", + " bgd_cost_plot.set_data(bgd_thetas[:i,0], bgd_thetas[:i,1])\n", "\n", - " data_ax.plot(X, y, 'k.')\n", + " sgd_data_plot.set_data(X, X_b @ sgd_thetas[i,:])\n", + " sgd_cost_plot.set_data(sgd_thetas[:i,0], sgd_thetas[:i,1])\n", "\n", - " cost_ax.plot(exact_solution[0,0], exact_solution[1,0], 'y*')\n", - " cost_ax.pcolor(theta0, theta1, cost_map)\n", - "\n", - " data_ax.plot(X, X_b.dot(bgd_thetas[i,:]), 'r-')\n", - " cost_ax.plot(bgd_thetas[:i,0], bgd_thetas[:i,1], 'r--')\n", - "\n", - " data_ax.plot(X, X_b.dot(sgd_thetas[i,:]), 'g-')\n", - " cost_ax.plot(sgd_thetas[:i,0], sgd_thetas[:i,1], 'g--')\n", - "\n", - " data_ax.plot(X, X_b.dot(mbgd_thetas[i,:]), 'b-')\n", - " cost_ax.plot(mbgd_thetas[:i,0], mbgd_thetas[:i,1], 'b--')\n", - "\n", - " data_ax.set_xlim([0, 2])\n", - " data_ax.set_ylim([0, 15])\n", - " cost_ax.set_xlim([0, 5])\n", - " cost_ax.set_ylim([0, 5])\n", - "\n", - " data_ax.set_xlabel(r'$x_1$')\n", - " data_ax.set_ylabel(r'$y$', rotation=0)\n", - " cost_ax.set_xlabel(r'$\\theta_0$')\n", - " cost_ax.set_ylabel(r'$\\theta_1$')\n", - "\n", - " data_ax.legend(('Data', 'BGD', 'SGD', 'MBGD'), loc=\"upper left\")\n", - " cost_ax.legend(('Normal Equation', 'BGD', 'SGD', 'MBGD'), loc=\"upper left\")" + " mbgd_data_plot.set_data(X, X_b @ mbgd_thetas[i,:])\n", + " mbgd_cost_plot.set_data(mbgd_thetas[:i,0], mbgd_thetas[:i,1])" ] }, { @@ -243,8 +253,7 @@ "metadata": {}, "outputs": [], "source": [ - "animation = FuncAnimation(fig, animate, frames=n_iter)\n", - "plt.show()" + "FuncAnimation(fig, animate, frames=n_iter // 3)" ] }, {