Sync notebook with book's code examples, and better identify extra code

This commit is contained in:
Aurélien Geron
2022-02-19 18:17:36 +13:00
parent 1c2421fc88
commit b63019fd28
9 changed files with 318 additions and 301 deletions

View File

@@ -142,7 +142,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book it's a bit too long\n",
"# extra code it's a bit too long\n",
"print(mnist.DESCR)"
]
},
@@ -152,7 +152,7 @@
"metadata": {},
"outputs": [],
"source": [
"mnist.keys() # not in the book we only use data and target in this notebook"
"mnist.keys() # extra code we only use data and target in this notebook"
]
},
{
@@ -216,7 +216,7 @@
"\n",
"some_digit = X[0]\n",
"plot_digit(some_digit)\n",
"save_fig(\"some_digit_plot\") # not in the book\n",
"save_fig(\"some_digit_plot\") # extra code\n",
"plt.show()"
]
},
@@ -235,7 +235,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code generates Figure 32\n",
"# extra code this cell generates and saves Figure 32\n",
"plt.figure(figsize=(9, 9))\n",
"for idx, image_data in enumerate(X[:100]):\n",
" plt.subplot(10, 10, idx + 1)\n",
@@ -427,7 +427,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code also computes the precision: TP / (FP + TP)\n",
"# extra code this cell also computes the precision: TP / (FP + TP)\n",
"cm[1, 1] / (cm[0, 1] + cm[1, 1])"
]
},
@@ -446,7 +446,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code also computes the recall: TP / (FN + TP)\n",
"# extra code this cell also computes the recall: TP / (FN + TP)\n",
"cm[1, 1] / (cm[1, 0] + cm[1, 1])"
]
},
@@ -467,7 +467,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code also computes the f1 score\n",
"# extra code this cell also computes the f1 score\n",
"cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)"
]
},
@@ -513,8 +513,8 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code just shows that y_scores > 0 produces the same\n",
"# result as calling predict()\n",
"# extra code just shows that y_scores > 0 produces the same result as\n",
"# calling predict()\n",
"y_scores > 0"
]
},
@@ -556,12 +556,12 @@
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8, 4)) # not in the book it's not needed, just formatting\n",
"plt.figure(figsize=(8, 4)) # extra code it's not needed, just formatting\n",
"plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
"plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
"plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n",
"\n",
"# not in the book this section just beautifies and saves Figure 35\n",
"# extra code this section just beautifies and saves Figure 35\n",
"idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n",
"plt.plot(thresholds[idx], precisions[idx], \"bo\")\n",
"plt.plot(thresholds[idx], recalls[idx], \"go\")\n",
@@ -580,13 +580,13 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.patches as patches # not in the book for the curved arrow\n",
"import matplotlib.patches as patches # extra code for the curved arrow\n",
"\n",
"plt.figure(figsize=(6, 5)) # not in the book not needed, just formatting\n",
"plt.figure(figsize=(6, 5)) # extra code not needed, just formatting\n",
"\n",
"plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n",
"\n",
"# not in the book just beautifies and saves Figure 36\n",
"# extra code just beautifies and saves Figure 36\n",
"plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n",
"plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n",
"plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n",
@@ -673,12 +673,12 @@
"idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n",
"tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n",
"\n",
"plt.figure(figsize=(6, 5)) # not in the book not needed, just formatting\n",
"plt.figure(figsize=(6, 5)) # extra code not needed, just formatting\n",
"plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n",
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n",
"plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n",
"\n",
"# not in the book just beautifies and saves Figure 37\n",
"# extra code just beautifies and saves Figure 37\n",
"plt.gca().add_patch(patches.FancyArrowPatch(\n",
" (0.20, 0.89), (0.07, 0.70),\n",
" connectionstyle=\"arc3,rad=.4\",\n",
@@ -778,13 +778,13 @@
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(6, 5)) # not in the book not needed, just formatting\n",
"plt.figure(figsize=(6, 5)) # extra code not needed, just formatting\n",
"\n",
"plt.plot(recalls_forest, precisions_forest, \"b-\", linewidth=2,\n",
" label=\"Random Forest\")\n",
"plt.plot(recalls, precisions, \"--\", linewidth=2, label=\"SGD\")\n",
"\n",
"# not in the book just beautifies and saves Figure 38\n",
"# extra code just beautifies and saves Figure 38\n",
"plt.xlabel(\"Recall\")\n",
"plt.ylabel(\"Precision\")\n",
"plt.axis([0, 1, 0, 1])\n",
@@ -925,7 +925,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code shows how to get all 45 OvO scores if needed\n",
"# extra code shows how to get all 45 OvO scores if needed\n",
"svm_clf.decision_function_shape = \"ovo\"\n",
"some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n",
"some_digit_scores_ovo.round(2)"
@@ -1033,7 +1033,7 @@
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n",
"plt.rc('font', size=9) # not in the book make the text smaller\n",
"plt.rc('font', size=9) # extra code make the text smaller\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred)\n",
"plt.show()"
]
@@ -1044,7 +1044,7 @@
"metadata": {},
"outputs": [],
"source": [
"plt.rc('font', size=10) # not in the book\n",
"plt.rc('font', size=10) # extra code\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,\n",
" normalize=\"true\", values_format=\".0%\")\n",
"plt.show()"
@@ -1057,7 +1057,7 @@
"outputs": [],
"source": [
"sample_weight = (y_train_pred != y_train)\n",
"plt.rc('font', size=10) # not in the book\n",
"plt.rc('font', size=10) # extra code\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,\n",
" sample_weight=sample_weight,\n",
" normalize=\"true\", values_format=\".0%\")\n",
@@ -1077,7 +1077,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code generates Figure 39\n",
"# extra code this cell generates and saves Figure 39\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))\n",
"plt.rc('font', size=9)\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0])\n",
@@ -1096,7 +1096,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code generates Figure 310\n",
"# extra code this cell generates and saves Figure 310\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))\n",
"plt.rc('font', size=10)\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0],\n",
@@ -1131,7 +1131,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code generates Figure 311\n",
"# extra code this cell generates and saves Figure 311\n",
"size = 5\n",
"pad = 0.2\n",
"plt.figure(figsize=(size, size))\n",
@@ -1224,9 +1224,9 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code shows that we get a negligible performance\n",
"# improvement when we set average=\"weighted\" because the\n",
"# classes are already pretty well balanced.\n",
"# extra code shows that we get a negligible performance improvement when we\n",
"# set average=\"weighted\" because the classes are already pretty\n",
"# well balanced.\n",
"f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")"
]
},
@@ -1279,7 +1279,7 @@
"metadata": {},
"outputs": [],
"source": [
"# not in the book this code generates Figure 312\n",
"# extra code this cell generates and saves Figure 312\n",
"plt.subplot(121); plot_digit(X_test_mod[0])\n",
"plt.subplot(122); plot_digit(y_test_mod[0])\n",
"save_fig(\"noisy_digit_example_plot\")\n",
@@ -1296,7 +1296,7 @@
"knn_clf.fit(X_train_mod, y_train_mod)\n",
"clean_digit = knn_clf.predict([X_test_mod[0]])\n",
"plot_digit(clean_digit)\n",
"save_fig(\"cleaned_digit_example_plot\") # not in the book saves Figure 313\n",
"save_fig(\"cleaned_digit_example_plot\") # extra code saves Figure 313\n",
"plt.show()"
]
},