mirror of
https://github.com/ArthurDanjou/handson-ml3.git
synced 2026-01-14 12:14:36 +01:00
Update libraries to latest version, including TensorFlow 2.4.1 and Scikit-Learn 0.24.1
This commit is contained in:
@@ -84,11 +84,7 @@
|
||||
" print(\"Saving figure\", fig_id)\n",
|
||||
" if tight_layout:\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.savefig(path, format=fig_extension, dpi=resolution)\n",
|
||||
"\n",
|
||||
"# Ignore useless warnings (see SciPy issue #5998)\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings(action=\"ignore\", message=\"^internal gelsd\")"
|
||||
" plt.savefig(path, format=fig_extension, dpi=resolution)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -735,13 +731,21 @@
|
||||
"y_proba.round(2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Warning**: `model.predict_classes(X_new)` is deprecated. It is replaced with `np.argmax(model.predict(X_new), axis=-1)`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_pred = model.predict_classes(X_new)\n",
|
||||
"#y_pred = model.predict_classes(X_new) # deprecated\n",
|
||||
"y_pred = np.argmax(model.predict(X_new), axis=-1)\n",
|
||||
"y_pred"
|
||||
]
|
||||
},
|
||||
@@ -1514,7 +1518,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Warning**: the following cell crashes at the end of training. This seems to be caused by [Keras issue #13586](https://github.com/keras-team/keras/issues/13586), which was triggered by a recent change in Scikit-Learn. [Pull Request #13598](https://github.com/keras-team/keras/pull/13598) seems to fix the issue, so this problem should be resolved soon."
|
||||
"**Warning**: the following cell crashes at the end of training. This seems to be caused by [Keras issue #13586](https://github.com/keras-team/keras/issues/13586), which was triggered by a recent change in Scikit-Learn. [Pull Request #13598](https://github.com/keras-team/keras/pull/13598) seems to fix the issue, so this problem should be resolved soon. In the meantime, I've added `.tolist()` and `.rvs(1000).tolist()` as workarounds."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1528,8 +1532,8 @@
|
||||
"\n",
|
||||
"param_distribs = {\n",
|
||||
" \"n_hidden\": [0, 1, 2, 3],\n",
|
||||
" \"n_neurons\": np.arange(1, 100),\n",
|
||||
" \"learning_rate\": reciprocal(3e-4, 3e-2),\n",
|
||||
" \"n_neurons\": np.arange(1, 100) .tolist(),\n",
|
||||
" \"learning_rate\": reciprocal(3e-4, 3e-2) .rvs(1000).tolist(),\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"rnd_search_cv = RandomizedSearchCV(keras_reg, param_distribs, n_iter=10, cv=3, verbose=2)\n",
|
||||
@@ -1888,6 +1892,7 @@
|
||||
"plt.gca().set_xscale('log')\n",
|
||||
"plt.hlines(min(expon_lr.losses), min(expon_lr.rates), max(expon_lr.rates))\n",
|
||||
"plt.axis([min(expon_lr.rates), max(expon_lr.rates), 0, expon_lr.losses[0]])\n",
|
||||
"plt.grid()\n",
|
||||
"plt.xlabel(\"Learning rate\")\n",
|
||||
"plt.ylabel(\"Loss\")"
|
||||
]
|
||||
@@ -1896,7 +1901,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The loss starts shooting back up violently around 3e-1, so let's try using 2e-1 as our learning rate:"
|
||||
"The loss starts shooting back up violently when the learning rate goes over 6e-1, so let's try using half of that, at 3e-1:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1931,7 +1936,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
|
||||
" optimizer=keras.optimizers.SGD(lr=2e-1),\n",
|
||||
" optimizer=keras.optimizers.SGD(lr=3e-1),\n",
|
||||
" metrics=[\"accuracy\"])"
|
||||
]
|
||||
},
|
||||
@@ -1958,7 +1963,7 @@
|
||||
"\n",
|
||||
"history = model.fit(X_train, y_train, epochs=100,\n",
|
||||
" validation_data=(X_valid, y_valid),\n",
|
||||
" callbacks=[early_stopping_cb, checkpoint_cb, tensorboard_cb])"
|
||||
" callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -2011,7 +2016,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
"version": "3.7.9"
|
||||
},
|
||||
"nav_menu": {
|
||||
"height": "264px",
|
||||
|
||||
Reference in New Issue
Block a user