mirror of
https://github.com/ArthurDanjou/handson-ml3.git
synced 2026-01-31 03:57:55 +01:00
Drop Python 2 (woohoo!) and import matplotlib as mpl
This commit is contained in:
@@ -20,7 +20,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
|
||||
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -29,8 +29,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# To support both python 2 and python 3\n",
|
||||
"from __future__ import division, print_function, unicode_literals\n",
|
||||
"# Python ≥3.5 is required\n",
|
||||
"import sys\n",
|
||||
"assert sys.version_info >= (3, 5)\n",
|
||||
"\n",
|
||||
"# Common imports\n",
|
||||
"import numpy as np\n",
|
||||
@@ -41,11 +42,11 @@
|
||||
"\n",
|
||||
"# To plot pretty figures\n",
|
||||
"%matplotlib inline\n",
|
||||
"import matplotlib\n",
|
||||
"import matplotlib as mpl\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"plt.rcParams['axes.labelsize'] = 14\n",
|
||||
"plt.rcParams['xtick.labelsize'] = 12\n",
|
||||
"plt.rcParams['ytick.labelsize'] = 12\n",
|
||||
"mpl.rc('axes', labelsize=14)\n",
|
||||
"mpl.rc('xtick', labelsize=12)\n",
|
||||
"mpl.rc('ytick', labelsize=12)\n",
|
||||
"\n",
|
||||
"# Where to save the figures\n",
|
||||
"PROJECT_ROOT_DIR = \".\"\n",
|
||||
@@ -79,13 +80,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def sort_by_target(mnist):\n",
|
||||
" reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]\n",
|
||||
" reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]\n",
|
||||
" mnist.data[:60000] = mnist.data[reorder_train]\n",
|
||||
" mnist.target[:60000] = mnist.target[reorder_train]\n",
|
||||
" mnist.data[60000:] = mnist.data[reorder_test + 60000]\n",
|
||||
" mnist.target[60000:] = mnist.target[reorder_test + 60000]"
|
||||
"from sklearn.datasets import fetch_openml\n",
|
||||
"mnist = fetch_openml('mnist_784', version=1)\n",
|
||||
"mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings\n",
|
||||
"mnist[\"data\"], mnist[\"target\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -93,30 +91,13 @@
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"try:\n",
|
||||
" from sklearn.datasets import fetch_openml\n",
|
||||
" mnist = fetch_openml('mnist_784', version=1, cache=True)\n",
|
||||
" mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings\n",
|
||||
" sort_by_target(mnist) # fetch_openml() returns an unsorted dataset\n",
|
||||
"except ImportError:\n",
|
||||
" from sklearn.datasets import fetch_mldata\n",
|
||||
" mnist = fetch_mldata('MNIST original')\n",
|
||||
"mnist[\"data\"], mnist[\"target\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mnist.data.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -126,7 +107,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -135,7 +116,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -144,18 +125,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"import matplotlib\n",
|
||||
"import matplotlib as mpl\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"some_digit = X[36000]\n",
|
||||
"some_digit_image = some_digit.reshape(28, 28)\n",
|
||||
"plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n",
|
||||
" interpolation=\"nearest\")\n",
|
||||
"plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n",
|
||||
"plt.axis(\"off\")\n",
|
||||
"\n",
|
||||
"save_fig(\"some_digit_plot\")\n",
|
||||
@@ -164,20 +144,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_digit(data):\n",
|
||||
" image = data.reshape(28, 28)\n",
|
||||
" plt.imshow(image, cmap = matplotlib.cm.binary,\n",
|
||||
" plt.imshow(image, cmap = mpl.cm.binary,\n",
|
||||
" interpolation=\"nearest\")\n",
|
||||
" plt.axis(\"off\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -194,13 +174,13 @@
|
||||
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
|
||||
" row_images.append(np.concatenate(rimages, axis=1))\n",
|
||||
" image = np.concatenate(row_images, axis=0)\n",
|
||||
" plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
|
||||
" plt.imshow(image, cmap = mpl.cm.binary, **options)\n",
|
||||
" plt.axis(\"off\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -213,7 +193,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -222,7 +202,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -231,7 +211,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -250,7 +230,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -267,7 +247,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -279,7 +259,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -288,7 +268,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -298,7 +278,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -322,7 +302,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -336,7 +316,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -346,7 +326,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -357,7 +337,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -368,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -377,7 +357,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -386,7 +366,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -397,16 +377,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"4344 / (4344 + 1307)"
|
||||
"4432 / (4432 + 1607)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -415,16 +395,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"4344 / (4344 + 1077)"
|
||||
"4432 / (4432 + 989)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -434,16 +414,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"4344 / (4344 + (1077 + 1307)/2)"
|
||||
"4432 / (4432 + (1607 + 989)/2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -453,7 +433,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -463,7 +443,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -472,18 +452,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"threshold = 200000\n",
|
||||
"threshold = -200000\n",
|
||||
"y_some_digit_pred = (y_scores > threshold)\n",
|
||||
"y_some_digit_pred"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -491,36 +471,9 @@
|
||||
" method=\"decision_function\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: there was an [issue](https://github.com/scikit-learn/scikit-learn/issues/9589) in Scikit-Learn 0.19.0 (fixed in 0.19.1) where the result of `cross_val_predict()` was incorrect in the binary classification case when using `method=\"decision_function\"`, as in the code above. The resulting array had an extra first dimension full of 0s. Just in case you are using 0.19.0, we need to add this small hack to work around this issue:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_scores.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# hack to work around issue #9589 in Scikit-Learn 0.19.0\n",
|
||||
"if y_scores.ndim == 2:\n",
|
||||
" y_scores = y_scores[:, 1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -531,7 +484,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -551,7 +504,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -560,7 +513,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -569,7 +522,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -578,7 +531,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -587,7 +540,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -612,7 +565,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -623,7 +576,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -642,7 +595,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -660,7 +613,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": 52,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -672,7 +625,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 53,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -2531,7 +2484,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 - tf2",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -2545,7 +2498,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.6"
|
||||
"version": "3.6.8"
|
||||
},
|
||||
"nav_menu": {},
|
||||
"toc": {
|
||||
|
||||
Reference in New Issue
Block a user