Update notebook to latest Colab library versions

This commit is contained in:
Aurélien Geron
2024-10-05 19:53:13 +13:00
parent 9143aa2fc4
commit 779848d82e

View File

@@ -841,6 +841,13 @@
"y_pred = model.predict(new_set) # or you could just pass a NumPy array" "y_pred = model.predict(new_set) # or you could just pass a NumPy array"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the `mean_squared_error()` function no longer exists. You must now use an instance of the `MeanSquaredError` class instead."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 30,
@@ -857,7 +864,7 @@
"source": [ "source": [
"# extra code defines the optimizer and loss function for training\n", "# extra code defines the optimizer and loss function for training\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n", "optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n", "loss_fn = tf.keras.losses.MeanSquaredError()\n",
"\n", "\n",
"n_epochs = 5\n", "n_epochs = 5\n",
"for epoch in range(n_epochs):\n", "for epoch in range(n_epochs):\n",
@@ -898,7 +905,7 @@
" optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n", " optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
"\n", "\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n", "optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n", "loss_fn = tf.keras.losses.MeanSquaredError()\n",
"for epoch in range(n_epochs):\n", "for epoch in range(n_epochs):\n",
" print(\"\\rEpoch {}/{}\".format(epoch + 1, n_epochs), end=\"\")\n", " print(\"\\rEpoch {}/{}\".format(epoch + 1, n_epochs), end=\"\")\n",
" train_one_epoch(model, optimizer, loss_fn, train_set)" " train_one_epoch(model, optimizer, loss_fn, train_set)"
@@ -2719,10 +2726,10 @@
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"X_train_num = np.random.rand(10_000, 8)\n", "X_train_num = np.random.rand(10_000, 8)\n",
"X_train_cat = np.random.choice(ocean_prox, size=10_000)\n", "X_train_cat = np.random.choice(ocean_prox, size=10_000).astype(object)\n",
"y_train = np.random.rand(10_000, 1)\n", "y_train = np.random.rand(10_000, 1)\n",
"X_valid_num = np.random.rand(2_000, 8)\n", "X_valid_num = np.random.rand(2_000, 8)\n",
"X_valid_cat = np.random.choice(ocean_prox, size=2_000)\n", "X_valid_cat = np.random.choice(ocean_prox, size=2_000).astype(object)\n",
"y_valid = np.random.rand(2_000, 1)\n", "y_valid = np.random.rand(2_000, 1)\n",
"\n", "\n",
"num_input = tf.keras.layers.Input(shape=[8], name=\"num\")\n", "num_input = tf.keras.layers.Input(shape=[8], name=\"num\")\n",
@@ -2768,37 +2775,6 @@
" validation_data=valid_set)" " validation_data=valid_set)"
] ]
}, },
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"313/313 [==============================] - 1s 1ms/step - loss: 0.0829 - val_loss: 0.0830\n",
"Epoch 2/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0829 - val_loss: 0.0830\n",
"Epoch 3/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0830\n",
"Epoch 4/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0829\n",
"Epoch 5/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0829\n"
]
}
],
"source": [
"# extra code shows that the dataset can contain dictionaries\n",
"train_set = tf.data.Dataset.from_tensor_slices(\n",
" ({\"num\": X_train_num, \"cat\": X_train_cat}, y_train)).batch(32)\n",
"valid_set = tf.data.Dataset.from_tensor_slices(\n",
" ({\"num\": X_valid_num, \"cat\": X_valid_cat}, y_valid)).batch(32)\n",
"history = model.fit(train_set, epochs=5, validation_data=valid_set)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@@ -2808,7 +2784,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 101, "execution_count": 100,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -2819,7 +2795,7 @@
" [6, 2, 1, 2]])>" " [6, 2, 1, 2]])>"
] ]
}, },
"execution_count": 101, "execution_count": 100,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -2833,7 +2809,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 102, "execution_count": 101,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -2842,7 +2818,7 @@
"<tf.RaggedTensor [[2, 1], [6, 2, 1, 2]]>" "<tf.RaggedTensor [[2, 1], [6, 2, 1, 2]]>"
] ]
}, },
"execution_count": 102, "execution_count": 101,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -2855,7 +2831,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 103, "execution_count": 102,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -2868,7 +2844,7 @@
" 1.0986123 ]], dtype=float32)>" " 1.0986123 ]], dtype=float32)>"
] ]
}, },
"execution_count": 103, "execution_count": 102,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -2881,7 +2857,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 104, "execution_count": 103,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -2890,7 +2866,7 @@
"1.3862943611198906" "1.3862943611198906"
] ]
}, },
"execution_count": 104, "execution_count": 103,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -2901,7 +2877,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 105, "execution_count": 104,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -2910,7 +2886,7 @@
"1.0986122886681098" "1.0986122886681098"
] ]
}, },
"execution_count": 105, "execution_count": 104,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -2930,7 +2906,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 106, "execution_count": 105,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -2950,7 +2926,7 @@
" -0.1 , -0.18, -0.13, -0.04, 0.15]], dtype=float32)" " -0.1 , -0.18, -0.13, -0.04, 0.15]], dtype=float32)"
] ]
}, },
"execution_count": 106, "execution_count": 105,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -2972,7 +2948,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 107, "execution_count": 106,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -2985,7 +2961,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 108, "execution_count": 107,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3009,7 +2985,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 109, "execution_count": 108,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3040,7 +3016,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 110, "execution_count": 109,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3052,7 +3028,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 111, "execution_count": 110,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3064,7 +3040,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 112, "execution_count": 111,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3075,7 +3051,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 113, "execution_count": 112,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3151,7 +3127,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 114, "execution_count": 113,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3162,7 +3138,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 115, "execution_count": 114,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3184,7 +3160,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 116, "execution_count": 115,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3201,7 +3177,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 117, "execution_count": 116,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3244,7 +3220,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 118, "execution_count": 117,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3265,7 +3241,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 119, "execution_count": 118,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3284,7 +3260,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 120, "execution_count": 119,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3314,7 +3290,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 121, "execution_count": 120,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3325,7 +3301,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 122, "execution_count": 121,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3352,7 +3328,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 123, "execution_count": 122,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3377,7 +3353,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 124, "execution_count": 123,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3441,7 +3417,7 @@
"<keras.callbacks.History at 0x7fa3e08af370>" "<keras.callbacks.History at 0x7fa3e08af370>"
] ]
}, },
"execution_count": 124, "execution_count": 123,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -3460,7 +3436,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 125, "execution_count": 124,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3516,7 +3492,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 126, "execution_count": 125,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3534,7 +3510,7 @@
"PosixPath('datasets/aclImdb')" "PosixPath('datasets/aclImdb')"
] ]
}, },
"execution_count": 126, "execution_count": 125,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -3559,7 +3535,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 127, "execution_count": 126,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3582,7 +3558,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 128, "execution_count": 127,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3636,7 +3612,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 129, "execution_count": 128,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3645,7 +3621,7 @@
"(12500, 12500, 12500, 12500)" "(12500, 12500, 12500, 12500)"
] ]
}, },
"execution_count": 129, "execution_count": 128,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -3672,7 +3648,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 130, "execution_count": 129,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3701,7 +3677,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 130,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3719,7 +3695,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 131,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3747,7 +3723,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 133, "execution_count": 132,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3778,7 +3754,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 134, "execution_count": 133,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3794,7 +3770,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 135, "execution_count": 134,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3818,7 +3794,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 136, "execution_count": 135,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3835,7 +3811,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 137, "execution_count": 136,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3864,7 +3840,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 138, "execution_count": 137,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -3884,7 +3860,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 139, "execution_count": 138,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3893,7 +3869,7 @@
"['[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i']" "['[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i']"
] ]
}, },
"execution_count": 139, "execution_count": 138,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -3918,7 +3894,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 140, "execution_count": 139,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3943,7 +3919,7 @@
"<keras.callbacks.History at 0x7fa401b8f9d0>" "<keras.callbacks.History at 0x7fa401b8f9d0>"
] ]
}, },
"execution_count": 140, "execution_count": 139,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -3984,7 +3960,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 141, "execution_count": 140,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -3995,7 +3971,7 @@
" [6. , 0. , 0. ]], dtype=float32)>" " [6. , 0. , 0. ]], dtype=float32)>"
] ]
}, },
"execution_count": 141, "execution_count": 140,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -4021,7 +3997,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 142, "execution_count": 141,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -4030,7 +4006,7 @@
"<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[3.535534 , 4.9497476, 2.1213202]], dtype=float32)>" "<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[3.535534 , 4.9497476, 2.1213202]], dtype=float32)>"
] ]
}, },
"execution_count": 142, "execution_count": 141,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -4048,7 +4024,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 143, "execution_count": 142,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -4057,7 +4033,7 @@
"<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[6., 0., 0.]], dtype=float32)>" "<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[6., 0., 0.]], dtype=float32)>"
] ]
}, },
"execution_count": 143, "execution_count": 142,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -4075,7 +4051,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 144, "execution_count": 143,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -4107,7 +4083,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 145, "execution_count": 144,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -4132,7 +4108,7 @@
"<keras.callbacks.History at 0x7fa3a0bf9460>" "<keras.callbacks.History at 0x7fa3a0bf9460>"
] ]
}, },
"execution_count": 145, "execution_count": 144,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -4160,7 +4136,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 146, "execution_count": 145,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -4172,7 +4148,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 147, "execution_count": 146,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -4214,7 +4190,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.6" "version": "3.9.10"
}, },
"nav_menu": { "nav_menu": {
"height": "264px", "height": "264px",