Drop Python 2 (woohoo!) and import matplotlib as mpl

This commit is contained in:
Aurélien Geron
2019-01-16 23:42:00 +08:00
parent ca6eb8c147
commit d2518a679b
8 changed files with 223 additions and 259 deletions

View File

@@ -20,7 +20,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead)."
]
},
{
@@ -29,8 +29,9 @@
"metadata": {},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
@@ -41,11 +42,11 @@
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
@@ -79,13 +80,10 @@
"metadata": {},
"outputs": [],
"source": [
"def sort_by_target(mnist):\n",
" reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]\n",
" reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]\n",
" mnist.data[:60000] = mnist.data[reorder_train]\n",
" mnist.target[:60000] = mnist.target[reorder_train]\n",
" mnist.data[60000:] = mnist.data[reorder_test + 60000]\n",
" mnist.target[60000:] = mnist.target[reorder_test + 60000]"
"from sklearn.datasets import fetch_openml\n",
"mnist = fetch_openml('mnist_784', version=1)\n",
"mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings\n",
"mnist[\"data\"], mnist[\"target\"]"
]
},
{
@@ -93,30 +91,13 @@
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" from sklearn.datasets import fetch_openml\n",
" mnist = fetch_openml('mnist_784', version=1, cache=True)\n",
" mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings\n",
" sort_by_target(mnist) # fetch_openml() returns an unsorted dataset\n",
"except ImportError:\n",
" from sklearn.datasets import fetch_mldata\n",
" mnist = fetch_mldata('MNIST original')\n",
"mnist[\"data\"], mnist[\"target\"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"mnist.data.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -126,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -135,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -144,18 +125,17 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"some_digit = X[36000]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n",
" interpolation=\"nearest\")\n",
"plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n",
"plt.axis(\"off\")\n",
"\n",
"save_fig(\"some_digit_plot\")\n",
@@ -164,20 +144,20 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def plot_digit(data):\n",
" image = data.reshape(28, 28)\n",
" plt.imshow(image, cmap = matplotlib.cm.binary,\n",
" plt.imshow(image, cmap = mpl.cm.binary,\n",
" interpolation=\"nearest\")\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -194,13 +174,13 @@
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
" row_images.append(np.concatenate(rimages, axis=1))\n",
" image = np.concatenate(row_images, axis=0)\n",
" plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
" plt.imshow(image, cmap = mpl.cm.binary, **options)\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -213,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -222,7 +202,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -231,7 +211,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -250,7 +230,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -267,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -279,7 +259,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -288,7 +268,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -298,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -322,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@@ -336,7 +316,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -346,7 +326,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -357,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@@ -368,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -377,7 +357,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@@ -386,7 +366,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@@ -397,16 +377,16 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"4344 / (4344 + 1307)"
"4432 / (4432 + 1607)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@@ -415,16 +395,16 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"4344 / (4344 + 1077)"
"4432 / (4432 + 989)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@@ -434,16 +414,16 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"4344 / (4344 + (1077 + 1307)/2)"
"4432 / (4432 + (1607 + 989)/2)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@@ -453,7 +433,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@@ -463,7 +443,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@@ -472,18 +452,18 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"threshold = 200000\n",
"threshold = -200000\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
@@ -491,36 +471,9 @@
" method=\"decision_function\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: there was an [issue](https://github.com/scikit-learn/scikit-learn/issues/9589) in Scikit-Learn 0.19.0 (fixed in 0.19.1) where the result of `cross_val_predict()` was incorrect in the binary classification case when using `method=\"decision_function\"`, as in the code above. The resulting array had an extra first dimension full of 0s. Just in case you are using 0.19.0, we need to add this small hack to work around this issue:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"y_scores.shape"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# hack to work around issue #9589 in Scikit-Learn 0.19.0\n",
"if y_scores.ndim == 2:\n",
" y_scores = y_scores[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
@@ -531,7 +484,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@@ -551,7 +504,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@@ -560,7 +513,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@@ -569,7 +522,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
@@ -578,7 +531,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@@ -587,7 +540,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
@@ -612,7 +565,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
@@ -623,7 +576,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
@@ -642,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
@@ -660,7 +613,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
@@ -672,7 +625,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@@ -2531,7 +2484,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 - tf2",
"language": "python",
"name": "python3"
},
@@ -2545,7 +2498,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.6.8"
},
"nav_menu": {},
"toc": {