Update notebook to latest Colab library versions

This commit is contained in:
Aurélien Geron
2024-10-05 19:53:13 +13:00
parent 9143aa2fc4
commit 779848d82e

View File

@@ -841,6 +841,13 @@
"y_pred = model.predict(new_set) # or you could just pass a NumPy array"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the `mean_squared_error()` function no longer exists. You must now use an instance of the `MeanSquaredError` class instead."
]
},
{
"cell_type": "code",
"execution_count": 30,
@@ -857,7 +864,7 @@
"source": [
"# extra code defines the optimizer and loss function for training\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"loss_fn = tf.keras.losses.MeanSquaredError()\n",
"\n",
"n_epochs = 5\n",
"for epoch in range(n_epochs):\n",
@@ -898,7 +905,7 @@
" optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"loss_fn = tf.keras.losses.MeanSquaredError()\n",
"for epoch in range(n_epochs):\n",
" print(\"\\rEpoch {}/{}\".format(epoch + 1, n_epochs), end=\"\")\n",
" train_one_epoch(model, optimizer, loss_fn, train_set)"
@@ -2719,10 +2726,10 @@
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"X_train_num = np.random.rand(10_000, 8)\n",
"X_train_cat = np.random.choice(ocean_prox, size=10_000)\n",
"X_train_cat = np.random.choice(ocean_prox, size=10_000).astype(object)\n",
"y_train = np.random.rand(10_000, 1)\n",
"X_valid_num = np.random.rand(2_000, 8)\n",
"X_valid_cat = np.random.choice(ocean_prox, size=2_000)\n",
"X_valid_cat = np.random.choice(ocean_prox, size=2_000).astype(object)\n",
"y_valid = np.random.rand(2_000, 1)\n",
"\n",
"num_input = tf.keras.layers.Input(shape=[8], name=\"num\")\n",
@@ -2768,37 +2775,6 @@
" validation_data=valid_set)"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"313/313 [==============================] - 1s 1ms/step - loss: 0.0829 - val_loss: 0.0830\n",
"Epoch 2/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0829 - val_loss: 0.0830\n",
"Epoch 3/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0830\n",
"Epoch 4/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0829\n",
"Epoch 5/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0829\n"
]
}
],
"source": [
"# extra code shows that the dataset can contain dictionaries\n",
"train_set = tf.data.Dataset.from_tensor_slices(\n",
" ({\"num\": X_train_num, \"cat\": X_train_cat}, y_train)).batch(32)\n",
"valid_set = tf.data.Dataset.from_tensor_slices(\n",
" ({\"num\": X_valid_num, \"cat\": X_valid_cat}, y_valid)).batch(32)\n",
"history = model.fit(train_set, epochs=5, validation_data=valid_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -2808,7 +2784,7 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 100,
"metadata": {},
"outputs": [
{
@@ -2819,7 +2795,7 @@
" [6, 2, 1, 2]])>"
]
},
"execution_count": 101,
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
@@ -2833,7 +2809,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 101,
"metadata": {},
"outputs": [
{
@@ -2842,7 +2818,7 @@
"<tf.RaggedTensor [[2, 1], [6, 2, 1, 2]]>"
]
},
"execution_count": 102,
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
@@ -2855,7 +2831,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 102,
"metadata": {},
"outputs": [
{
@@ -2868,7 +2844,7 @@
" 1.0986123 ]], dtype=float32)>"
]
},
"execution_count": 103,
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
@@ -2881,7 +2857,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 103,
"metadata": {},
"outputs": [
{
@@ -2890,7 +2866,7 @@
"1.3862943611198906"
]
},
"execution_count": 104,
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
@@ -2901,7 +2877,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 104,
"metadata": {},
"outputs": [
{
@@ -2910,7 +2886,7 @@
"1.0986122886681098"
]
},
"execution_count": 105,
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
@@ -2930,7 +2906,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 105,
"metadata": {},
"outputs": [
{
@@ -2950,7 +2926,7 @@
" -0.1 , -0.18, -0.13, -0.04, 0.15]], dtype=float32)"
]
},
"execution_count": 106,
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
@@ -2972,7 +2948,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
@@ -2985,7 +2961,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 107,
"metadata": {},
"outputs": [
{
@@ -3009,7 +2985,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 108,
"metadata": {},
"outputs": [
{
@@ -3040,7 +3016,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
@@ -3052,7 +3028,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
@@ -3064,7 +3040,7 @@
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
@@ -3075,7 +3051,7 @@
},
{
"cell_type": "code",
"execution_count": 113,
"execution_count": 112,
"metadata": {},
"outputs": [
{
@@ -3151,7 +3127,7 @@
},
{
"cell_type": "code",
"execution_count": 114,
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
@@ -3162,7 +3138,7 @@
},
{
"cell_type": "code",
"execution_count": 115,
"execution_count": 114,
"metadata": {},
"outputs": [
{
@@ -3184,7 +3160,7 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
@@ -3201,7 +3177,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 116,
"metadata": {},
"outputs": [
{
@@ -3244,7 +3220,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
@@ -3265,7 +3241,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
@@ -3284,7 +3260,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@@ -3314,7 +3290,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
@@ -3325,7 +3301,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 121,
"metadata": {},
"outputs": [
{
@@ -3352,7 +3328,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
@@ -3377,7 +3353,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 123,
"metadata": {},
"outputs": [
{
@@ -3441,7 +3417,7 @@
"<keras.callbacks.History at 0x7fa3e08af370>"
]
},
"execution_count": 124,
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
@@ -3460,7 +3436,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 124,
"metadata": {},
"outputs": [
{
@@ -3516,7 +3492,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 125,
"metadata": {},
"outputs": [
{
@@ -3534,7 +3510,7 @@
"PosixPath('datasets/aclImdb')"
]
},
"execution_count": 126,
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
@@ -3559,7 +3535,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@@ -3582,7 +3558,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 127,
"metadata": {},
"outputs": [
{
@@ -3636,7 +3612,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 128,
"metadata": {},
"outputs": [
{
@@ -3645,7 +3621,7 @@
"(12500, 12500, 12500, 12500)"
]
},
"execution_count": 129,
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
}
@@ -3672,7 +3648,7 @@
},
{
"cell_type": "code",
"execution_count": 130,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
@@ -3701,7 +3677,7 @@
},
{
"cell_type": "code",
"execution_count": 131,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
@@ -3719,7 +3695,7 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": 131,
"metadata": {},
"outputs": [
{
@@ -3747,7 +3723,7 @@
},
{
"cell_type": "code",
"execution_count": 133,
"execution_count": 132,
"metadata": {},
"outputs": [
{
@@ -3778,7 +3754,7 @@
},
{
"cell_type": "code",
"execution_count": 134,
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
@@ -3794,7 +3770,7 @@
},
{
"cell_type": "code",
"execution_count": 135,
"execution_count": 134,
"metadata": {},
"outputs": [
{
@@ -3818,7 +3794,7 @@
},
{
"cell_type": "code",
"execution_count": 136,
"execution_count": 135,
"metadata": {},
"outputs": [
{
@@ -3835,7 +3811,7 @@
},
{
"cell_type": "code",
"execution_count": 137,
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
@@ -3864,7 +3840,7 @@
},
{
"cell_type": "code",
"execution_count": 138,
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
@@ -3884,7 +3860,7 @@
},
{
"cell_type": "code",
"execution_count": 139,
"execution_count": 138,
"metadata": {},
"outputs": [
{
@@ -3893,7 +3869,7 @@
"['[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i']"
]
},
"execution_count": 139,
"execution_count": 138,
"metadata": {},
"output_type": "execute_result"
}
@@ -3918,7 +3894,7 @@
},
{
"cell_type": "code",
"execution_count": 140,
"execution_count": 139,
"metadata": {},
"outputs": [
{
@@ -3943,7 +3919,7 @@
"<keras.callbacks.History at 0x7fa401b8f9d0>"
]
},
"execution_count": 140,
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
}
@@ -3984,7 +3960,7 @@
},
{
"cell_type": "code",
"execution_count": 141,
"execution_count": 140,
"metadata": {},
"outputs": [
{
@@ -3995,7 +3971,7 @@
" [6. , 0. , 0. ]], dtype=float32)>"
]
},
"execution_count": 141,
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
@@ -4021,7 +3997,7 @@
},
{
"cell_type": "code",
"execution_count": 142,
"execution_count": 141,
"metadata": {},
"outputs": [
{
@@ -4030,7 +4006,7 @@
"<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[3.535534 , 4.9497476, 2.1213202]], dtype=float32)>"
]
},
"execution_count": 142,
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
@@ -4048,7 +4024,7 @@
},
{
"cell_type": "code",
"execution_count": 143,
"execution_count": 142,
"metadata": {},
"outputs": [
{
@@ -4057,7 +4033,7 @@
"<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[6., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 143,
"execution_count": 142,
"metadata": {},
"output_type": "execute_result"
}
@@ -4075,7 +4051,7 @@
},
{
"cell_type": "code",
"execution_count": 144,
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
@@ -4107,7 +4083,7 @@
},
{
"cell_type": "code",
"execution_count": 145,
"execution_count": 144,
"metadata": {},
"outputs": [
{
@@ -4132,7 +4108,7 @@
"<keras.callbacks.History at 0x7fa3a0bf9460>"
]
},
"execution_count": 145,
"execution_count": 144,
"metadata": {},
"output_type": "execute_result"
}
@@ -4160,7 +4136,7 @@
},
{
"cell_type": "code",
"execution_count": 146,
"execution_count": 145,
"metadata": {},
"outputs": [],
"source": [
@@ -4172,7 +4148,7 @@
},
{
"cell_type": "code",
"execution_count": 147,
"execution_count": 146,
"metadata": {},
"outputs": [
{
@@ -4214,7 +4190,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.10"
},
"nav_menu": {
"height": "264px",