From 01784d2f98e6e1ff42ac7d6c7399bca5ddd7ea2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 9 Jun 2019 20:08:53 +0800 Subject: [PATCH] Do not use a layer as an activation function, especially with weights --- 11_training_deep_neural_networks.ipynb | 304 ++++++++++++++----------- 1 file changed, 170 insertions(+), 134 deletions(-) diff --git a/11_training_deep_neural_networks.ipynb b/11_training_deep_neural_networks.ipynb index d031822..ea8f71d 100644 --- a/11_training_deep_neural_networks.ipynb +++ b/11_training_deep_neural_networks.ipynb @@ -215,25 +215,6 @@ "[m for m in dir(keras.layers) if \"relu\" in m.lower()]" ] }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "leaky_relu = keras.layers.LeakyReLU(alpha=0.2)\n", - "layer = keras.layers.Dense(10, activation=leaky_relu)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "layer.activation" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -243,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -256,26 +237,70 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ + "tf.random.set_seed(42)\n", + "np.random.seed(42)\n", + "\n", "model = keras.models.Sequential([\n", " keras.layers.Flatten(input_shape=[28, 28]),\n", - " keras.layers.Dense(300, activation=leaky_relu),\n", - " keras.layers.Dense(100, activation=leaky_relu),\n", + " keras.layers.Dense(300, kernel_initializer=\"he_normal\"),\n", + " keras.layers.LeakyReLU(),\n", + " keras.layers.Dense(100, kernel_initializer=\"he_normal\"),\n", + " keras.layers.LeakyReLU(),\n", " keras.layers.Dense(10, activation=\"softmax\")\n", "])" ] }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "model.compile(loss=\"sparse_categorical_crossentropy\",\n", + " optimizer=keras.optimizers.SGD(lr=1e-3),\n", + " metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "history = model.fit(X_train, y_train, epochs=10,\n", + " validation_data=(X_valid, y_valid))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's try PReLU:" + ] + }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ - "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n", - " metrics=[\"accuracy\"])" + "tf.random.set_seed(42)\n", + "np.random.seed(42)\n", + "\n", + "model = keras.models.Sequential([\n", + " keras.layers.Flatten(input_shape=[28, 28]),\n", + " keras.layers.Dense(300, kernel_initializer=\"he_normal\"),\n", + " keras.layers.PReLU(),\n", + " keras.layers.Dense(100, kernel_initializer=\"he_normal\"),\n", + " keras.layers.PReLU(),\n", + " keras.layers.Dense(10, activation=\"softmax\")\n", + "])" ] }, { @@ -283,6 +308,17 @@ "execution_count": 16, "metadata": {}, "outputs": [], + "source": [ + "model.compile(loss=\"sparse_categorical_crossentropy\",\n", + " optimizer=keras.optimizers.SGD(lr=1e-3),\n", + " metrics=[\"accuracy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], "source": [ "history = model.fit(X_train, y_train, epochs=10,\n", " validation_data=(X_valid, y_valid))" @@ -297,7 +333,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -307,7 +343,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -332,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -355,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -369,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -379,7 +415,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -404,7 +440,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -428,7 +464,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -445,7 +481,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -455,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -471,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -488,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -501,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -518,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -528,7 +564,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -542,7 +578,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -552,7 +588,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -576,7 +612,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -593,7 +629,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -602,7 +638,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -612,7 +648,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -621,7 +657,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -631,7 +667,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -648,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -667,7 +703,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -677,7 +713,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -701,7 +737,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -710,7 +746,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -746,7 +782,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -767,7 +803,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -776,7 +812,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -785,7 +821,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -794,7 +830,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -803,7 +839,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -813,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -826,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -836,7 +872,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -846,7 +882,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -855,7 +891,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -868,7 +904,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -878,7 +914,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ @@ -888,7 +924,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -897,7 +933,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -908,7 +944,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -918,7 +954,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -931,7 +967,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ @@ -956,7 +992,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -965,7 +1001,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ @@ -981,7 +1017,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ @@ -1004,7 +1040,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ @@ -1020,7 +1056,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -1036,7 +1072,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -1052,7 +1088,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -1068,7 +1104,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -1084,7 +1120,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ @@ -1100,7 +1136,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -1131,7 +1167,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -1140,7 +1176,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -1155,7 +1191,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 76, "metadata": {}, "outputs": [], "source": [ @@ -1166,7 +1202,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ @@ -1202,7 +1238,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ @@ -1212,7 +1248,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ @@ -1226,7 +1262,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -1242,7 +1278,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -1254,7 +1290,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -1276,7 +1312,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -1293,7 +1329,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -1333,7 +1369,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -1344,7 +1380,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 86, "metadata": { "scrolled": true }, @@ -1368,7 +1404,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ @@ -1383,7 +1419,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 88, "metadata": {}, "outputs": [], "source": [ @@ -1399,7 +1435,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -1420,7 +1456,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 90, "metadata": {}, "outputs": [], "source": [ @@ -1442,7 +1478,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ @@ -1452,7 +1488,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ @@ -1474,7 +1510,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -1503,7 +1539,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ @@ -1531,7 +1567,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -1549,7 +1585,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -1589,7 +1625,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 97, "metadata": {}, "outputs": [], "source": [ @@ -1607,7 +1643,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 98, "metadata": {}, "outputs": [], "source": [ @@ -1618,7 +1654,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -1651,7 +1687,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -1678,7 +1714,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -1691,7 +1727,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -1714,7 +1750,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 103, "metadata": {}, "outputs": [], "source": [ @@ -1746,7 +1782,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 104, "metadata": {}, "outputs": [], "source": [ @@ -1774,7 +1810,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -1784,7 +1820,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 106, "metadata": {}, "outputs": [], "source": [ @@ -1806,7 +1842,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -1815,7 +1851,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -1824,7 +1860,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 109, "metadata": {}, "outputs": [], "source": [ @@ -1840,7 +1876,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -1850,7 +1886,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1862,7 +1898,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 112, "metadata": {}, "outputs": [], "source": [ @@ -1871,7 +1907,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 113, "metadata": {}, "outputs": [], "source": [ @@ -1880,7 +1916,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 114, "metadata": {}, "outputs": [], "source": [ @@ -1889,7 +1925,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 115, "metadata": {}, "outputs": [], "source": [ @@ -1899,7 +1935,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ @@ -1908,7 +1944,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -1918,7 +1954,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ @@ -1933,7 +1969,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -1943,7 +1979,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -1955,7 +1991,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ @@ -1964,7 +2000,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ @@ -1974,7 +2010,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 123, "metadata": {}, "outputs": [], "source": [ @@ -1990,7 +2026,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -2006,7 +2042,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -2016,7 +2052,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [