From de0eb33694d94c78837a232d7b62ce0de6df5966 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Tue, 14 Nov 2023 18:20:45 +1300 Subject: [PATCH] Use AdamW from tf.keras.optimizers instead of TensorFlow-Addons --- 11_training_deep_neural_networks.ipynb | 190 ++++++++++++------------- 1 file changed, 89 insertions(+), 101 deletions(-) diff --git a/11_training_deep_neural_networks.ipynb b/11_training_deep_neural_networks.ipynb index 5237514..823936a 100644 --- a/11_training_deep_neural_networks.ipynb +++ b/11_training_deep_neural_networks.ipynb @@ -1780,36 +1780,24 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "On Colab or Kaggle, we need to install the TensorFlow-Addons library:" + "Note: Since TF 2.12, `AdamW` is no longer experimental. It is available at `tf.keras.optimizers.AdamW` instead of `tf.keras.optimizers.experimental.AdamW`." ] }, { "cell_type": "code", "execution_count": 62, - "metadata": {}, - "outputs": [], - "source": [ - "if \"google.colab\" in sys.modules:\n", - " %pip install -q -U tensorflow-addons" - ] - }, - { - "cell_type": "code", - "execution_count": 63, "metadata": { "tags": [] }, "outputs": [], "source": [ - "import tensorflow_addons as tfa\n", - "\n", - "optimizer = tfa.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n", - " beta_1=0.9, beta_2=0.999)" + "optimizer = tf.keras.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n", + " beta_1=0.9, beta_2=0.999)" ] }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 63, "metadata": {}, "outputs": [ { @@ -1845,7 +1833,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 64, "metadata": {}, "outputs": [ { @@ -1917,7 +1905,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ @@ -1926,7 +1914,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 66, "metadata": {}, "outputs": [ { @@ -1962,7 +1950,7 @@ }, {
"cell_type": "code", - "execution_count": 68, + "execution_count": 67, "metadata": {}, "outputs": [ { @@ -2017,7 +2005,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -2027,7 +2015,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -2041,7 +2029,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -2056,7 +2044,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 71, "metadata": {}, "outputs": [ { @@ -2125,7 +2113,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 72, "metadata": {}, "outputs": [ { @@ -2162,7 +2150,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -2179,7 +2167,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -2203,7 +2191,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -2216,7 +2204,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 76, "metadata": {}, "outputs": [ { @@ -2288,7 +2276,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 77, "metadata": { "scrolled": true }, @@ -2330,7 +2318,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ @@ -2345,7 +2333,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ @@ -2364,7 +2352,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 80, "metadata": {}, "outputs": [ { @@ -2442,7 +2430,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 81, "metadata": {}, 
"outputs": [ { @@ -2479,7 +2467,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -2493,7 +2481,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 83, "metadata": {}, "outputs": [ { @@ -2562,7 +2550,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 84, "metadata": {}, "outputs": [ { @@ -2606,7 +2594,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -2622,7 +2610,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 86, "metadata": {}, "outputs": [ { @@ -2666,7 +2654,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ @@ -2692,7 +2680,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 88, "metadata": {}, "outputs": [], "source": [ @@ -2727,7 +2715,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -2755,7 +2743,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 90, "metadata": {}, "outputs": [], "source": [ @@ -2779,7 +2767,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ @@ -2798,7 +2786,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 92, "metadata": {}, "outputs": [ { @@ -2844,7 +2832,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -2885,7 +2873,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 94, "metadata": {}, "outputs": [ { @@ -2974,7 +2962,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -2992,7 +2980,7 @@ }, { "cell_type": "code", - "execution_count": 
97, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -3001,7 +2989,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 97, "metadata": {}, "outputs": [], "source": [ @@ -3022,7 +3010,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 98, "metadata": {}, "outputs": [ { @@ -3054,7 +3042,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -3063,7 +3051,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -3082,7 +3070,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 101, "metadata": {}, "outputs": [ { @@ -3130,7 +3118,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 102, "metadata": {}, "outputs": [ { @@ -3146,7 +3134,7 @@ "[0.30816400051116943, 0.8849090933799744]" ] }, - "execution_count": 103, + "execution_count": 102, "metadata": {}, "output_type": "execute_result" } @@ -3157,7 +3145,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 103, "metadata": {}, "outputs": [ { @@ -3173,7 +3161,7 @@ "[0.3628920316696167, 0.8700000047683716]" ] }, - "execution_count": 104, + "execution_count": 103, "metadata": {}, "output_type": "execute_result" } @@ -3198,7 +3186,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 104, "metadata": {}, "outputs": [], "source": [ @@ -3207,7 +3195,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -3218,7 +3206,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 106, "metadata": {}, "outputs": [ { @@ -3228,7 +3216,7 @@ " 0.844]], dtype=float32)" ] }, - "execution_count": 107, + "execution_count": 106, "metadata": {}, "output_type": "execute_result" } @@ -3239,7 +3227,7 @@ }, { "cell_type": "code", - 
"execution_count": 108, + "execution_count": 107, "metadata": {}, "outputs": [ { @@ -3249,7 +3237,7 @@ " 0.723], dtype=float32)" ] }, - "execution_count": 108, + "execution_count": 107, "metadata": {}, "output_type": "execute_result" } @@ -3260,7 +3248,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 108, "metadata": {}, "outputs": [ { @@ -3270,7 +3258,7 @@ " 0.183], dtype=float32)" ] }, - "execution_count": 109, + "execution_count": 108, "metadata": {}, "output_type": "execute_result" } @@ -3282,7 +3270,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 109, "metadata": {}, "outputs": [ { @@ -3291,7 +3279,7 @@ "0.8717" ] }, - "execution_count": 110, + "execution_count": 109, "metadata": {}, "output_type": "execute_result" } @@ -3304,7 +3292,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -3315,7 +3303,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -3330,7 +3318,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 112, "metadata": {}, "outputs": [ { @@ -3375,7 +3363,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 113, "metadata": {}, "outputs": [ { @@ -3385,7 +3373,7 @@ " dtype=float32)" ] }, - "execution_count": 114, + "execution_count": 113, "metadata": {}, "output_type": "execute_result" } @@ -3406,7 +3394,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 114, "metadata": {}, "outputs": [], "source": [ @@ -3417,7 +3405,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 115, "metadata": {}, "outputs": [ { @@ -3512,7 +3500,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ @@ -3543,7 +3531,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 
117, "metadata": {}, "outputs": [], "source": [ @@ -3559,7 +3547,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ @@ -3578,7 +3566,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -3600,7 +3588,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -3616,7 +3604,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 121, "metadata": {}, "outputs": [ { @@ -3653,7 +3641,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 122, "metadata": {}, "outputs": [ { @@ -3739,7 +3727,7 @@ "" ] }, - "execution_count": 123, + "execution_count": 122, "metadata": {}, "output_type": "execute_result" } @@ -3752,7 +3740,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 123, "metadata": {}, "outputs": [ { @@ -3768,7 +3756,7 @@ "[1.5061508417129517, 0.4675999879837036]" ] }, - "execution_count": 124, + "execution_count": 123, "metadata": {}, "output_type": "execute_result" } @@ -3805,7 +3793,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 124, "metadata": {}, "outputs": [ { @@ -3891,7 +3879,7 @@ "[1.4236289262771606, 0.5073999762535095]" ] }, - "execution_count": 125, + "execution_count": 124, "metadata": {}, "output_type": "execute_result" } @@ -3948,7 +3936,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 125, "metadata": { "scrolled": true }, @@ -4042,7 +4030,7 @@ "[1.4607702493667603, 0.5026000142097473]" ] }, - "execution_count": 126, + "execution_count": 125, "metadata": {}, "output_type": "execute_result" } @@ -4103,7 +4091,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 126, "metadata": {}, "outputs": [ { @@ -4183,7 +4171,7 @@ "[1.4779616594314575, 0.498199999332428]" ] }, - 
"execution_count": 127, + "execution_count": 126, "metadata": {}, "output_type": "execute_result" } @@ -4244,7 +4232,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -4262,7 +4250,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [ @@ -4285,7 +4273,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -4307,7 +4295,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 130, "metadata": {}, "outputs": [ { @@ -4316,7 +4304,7 @@ "0.4984" ] }, - "execution_count": 131, + "execution_count": 130, "metadata": {}, "output_type": "execute_result" } @@ -4348,7 +4336,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 131, "metadata": {}, "outputs": [], "source": [ @@ -4372,7 +4360,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 132, "metadata": {}, "outputs": [ { @@ -4404,7 +4392,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 133, "metadata": {}, "outputs": [], "source": [ @@ -4428,7 +4416,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 134, "metadata": {}, "outputs": [ { @@ -4501,7 +4489,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.10" }, "nav_menu": { "height": "360px",