def save_fig(fig_id, tight_layout=True):
    """Save the current matplotlib figure as a 300-dpi PNG.

    The image is written to
    ``<PROJECT_ROOT_DIR>/images/<CHAPTER_ID>/<fig_id>.png``; the target
    directory is created first if it does not exist yet.

    fig_id: base name of the image file, without the ".png" extension.
    tight_layout: when true, call ``plt.tight_layout()`` before saving.
    """
    # Imported locally: the notebook's setup cell does not import `os`,
    # so without this the os.path calls below raise NameError.
    import os

    directory = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
    # plt.savefig does not create missing directories; isdir/makedirs is
    # used (rather than exist_ok=True) to stay Python 2 compatible, in line
    # with the notebook's __future__ imports.
    if not os.path.isdir(directory):
        os.makedirs(directory)
    path = os.path.join(directory, fig_id + ".png")
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)
def plot_digits(instances, images_per_row=10, **options):
    """Show a batch of flattened 28x28 digits as one tiled image.

    instances: sequence of arrays, each reshapeable to (28, 28).
    images_per_row: maximum number of digits per grid row; capped at the
        number of instances.
    **options: forwarded unchanged to ``plt.imshow``.
    """
    side = 28
    n_instances = len(instances)
    per_row = min(n_instances, images_per_row)
    tiles = [digit.reshape(side, side) for digit in instances]
    row_count = (n_instances - 1) // per_row + 1
    # One blank strip, as wide as all the missing tiles together, is appended
    # so that the last row ends up exactly as wide as the full rows.
    missing = row_count * per_row - n_instances
    tiles.append(np.zeros((side, side * missing)))
    rows = [
        np.concatenate(tiles[start:start + per_row], axis=1)
        for start in range(0, row_count * per_row, per_row)
    ]
    grid = np.concatenate(rows, axis=0)
    plt.imshow(grid, cmap=matplotlib.cm.binary, **options)
    plt.axis("off")
images_per_row=10)\n", + "save_fig(\"more_digits\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "y[some_digit_index]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "shuffle_index = rnd.permutation(60000)\n", + "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Binary classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "y_train_5 = (y_train == 5)\n", + "y_test_5 = (y_test == 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.linear_model import SGDClassifier\n", + "\n", + "sgd_clf = SGDClassifier(random_state=42)\n", + "sgd_clf.fit(X_train, y_train_5)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sgd_clf.predict([some_digit])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.cross_validation import cross_val_score\n", + "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.cross_validation import StratifiedKFold\n", + "from sklearn.base import clone\n", + "\n", + "skfolds = StratifiedKFold(y_train_5, n_folds=3, 
class Never5Classifier(BaseEstimator):
    """Baseline classifier that answers False for every instance.

    In this notebook it is cross-validated against ``y_train_5`` to show the
    accuracy reached without learning anything.
    """

    def fit(self, X, y=None):
        """Learn nothing and return the estimator itself.

        Returning ``self`` follows the scikit-learn estimator convention, so
        calls such as ``clf.fit(X, y).predict(X)`` work.
        """
        return self

    def predict(self, X):
        """Return an all-False boolean array of shape ``(len(X), 1)``."""
        return np.zeros((len(X), 1), dtype=bool)
def cross_val_predict_future(clf, X, y, cv, method=None):
    """Cross-validated output of an arbitrary estimator method.

    Stand-in for the ``method`` argument of ``cross_val_predict`` (see
    scikit-learn pull request 6671), for installs that do not have it yet.

    clf: estimator; it is cloned, so the object passed in is never fitted here.
    X, y: training data and labels; both are indexed with integer index arrays.
    cv: number of folds (passed as ``n_folds`` to ``StratifiedKFold``, or as
        ``cv`` to ``cross_val_predict`` when ``method`` is None).
    method: name of the estimator method to call on each test fold, e.g.
        "decision_function" or "predict_proba". None falls back to the
        regular ``cross_val_predict``.

    Returns an array with ``len(X)`` rows (default ``np.empty`` dtype) in
    which every row holds the output computed while that instance was in the
    test fold.
    """
    estimator = clone(clf)  # keep the caller's estimator intact
    if method is None:
        return cross_val_predict(clf, X, y, cv=cv)

    fold_method = getattr(estimator, method)
    fold_outputs = []
    for train_idx, test_idx in StratifiedKFold(y, n_folds=cv):
        estimator.fit(X[train_idx], y[train_idx])
        fold_outputs.append((test_idx, fold_method(X[test_idx])))

    # The result has the per-fold output shape, except that its first
    # dimension covers every instance of X.
    dims = list(fold_outputs[0][1].shape)
    dims[0] = len(X)
    result = np.empty(tuple(dims))
    for test_idx, output in fold_outputs:
        result[test_idx] = output
    return result
def plot_roc_curve(fpr, tpr, **options):
    """Plot ``tpr`` against ``fpr`` plus a dashed diagonal reference line.

    fpr, tpr: x and y values of the curve.
    **options: extra keyword arguments forwarded to ``plt.plot`` for the
        curve itself (for instance ``label=...``).
    """
    plt.plot(fpr, tpr, linewidth=2, **options)
    diagonal = [0, 1]
    plt.plot(diagonal, diagonal, 'k--')  # dashed black line from (0, 0) to (1, 1)
    plt.axis([0, 1, 0, 1])
    axis_labels = (
        (plt.xlabel, 'False Positive Rate'),
        (plt.ylabel, 'True Positive Rate'),
    )
    for set_label, text in axis_labels:
        set_label(text, fontsize=16)
sklearn.metrics import roc_auc_score\n", + "\n", + "roc_auc_score(y_train_5, y_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier\n", + "forest_clf = RandomForestClassifier(random_state=42)\n", + "y_probas_forest = cross_val_predict_future(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", + "y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n", + "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n", + "\n", + "plt.figure(figsize=(8, 6))\n", + "plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n", + "plot_roc_curve(fpr_forest, tpr_forest, label=\"Random Forest\")\n", + "plt.legend(loc=\"lower right\", fontsize=16)\n", + "save_fig(\"roc_curve_comparison_plot\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "roc_auc_score(y_train_5, y_scores_forest)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n", + "precision_score(y_train_5, y_train_pred_forest)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "recall_score(y_train_5, y_train_pred_forest)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multiclass classification" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sgd_clf.fit(X_train, y_train)\n", + "sgd_clf.predict([some_digit])" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + 
"some_digit_scores = sgd_clf.decision_function([some_digit])\n", + "some_digit_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "np.argmax(some_digit_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sgd_clf.classes_" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.multiclass import OneVsOneClassifier\n", + "ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))\n", + "ovo_clf.fit(X_train, y_train)\n", + "ovo_clf.predict([some_digit])" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "len(ovo_clf.estimators_)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "forest_clf.fit(X_train, y_train)\n", + "forest_clf.predict([some_digit])" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "forest_clf.predict_proba([some_digit])" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.preprocessing import StandardScaler\n", + "scaler = StandardScaler()\n", + "X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n", + "cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": { + "collapsed": false + }, + "outputs": [], + 
def plot_confusion_matrix(matrix):
    """Show ``matrix`` in color with a colorbar, in a new 8x8-inch figure.

    matrix: 2-D array to display (for example a confusion matrix).
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    # Plot the argument itself; the previous code ignored it and always drew
    # the global `conf_mx`.
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
from sklearn.neighbors import KNeighborsClassifier

# Two binary labels per digit: "large" (7, 8 or 9) and "odd".
y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)

knn_clf.predict([some_digit])

# Cross-validate against the multilabel targets. Using y_train here would
# silently score a 10-class digit classifier instead of the two-label one
# trained above.
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)

# Macro average: unweighted mean of the per-label F1 scores.
f1_score(y_multilabel, y_train_knn_pred, average="macro")
"execution_count": 67, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "clean_digit = knn_clf.predict([X_test_mod[some_index]])\n", + "plot_digit(clean_digit)\n", + "save_fig(\"cleaned_digit_example\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Extra material" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dummy (i.e. random) classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.dummy import DummyClassifier\n", + "dmy_clf = DummyClassifier()\n", + "y_probas_dmy = cross_val_predict_future(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", + "y_scores_dmy = y_probas_dmy[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "collapsed": false, + "scrolled": true + }, + "outputs": [], + "source": [ + "fprr, tprr, thresholdsr = roc_curve(y_train_5, y_scores_dmy)\n", + "plot_roc_curve(fprr, tprr)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## KNN classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.neighbors import KNeighborsClassifier\n", + "knn_clf = KNeighborsClassifier(n_jobs=-1, weights='distance', n_neighbors=4)\n", + "knn_clf.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "y_knn_pred = knn_clf.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score\n", + "accuracy_score(y_test, y_knn_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": { + "collapsed": false + }, + "outputs": [], +
def shift_digit(digit_array, dx, dy, new=0):
    """Translate a flattened 28x28 digit and return it flattened again.

    digit_array: array with 784 values, reshaped to (28, 28) internally.
    dx, dy: horizontal and vertical offsets in pixels.
    new: value used for the pixels that the shift uncovers (passed as
        ``cval`` to ``shift``).
    """
    image = digit_array.reshape(28, 28)
    # `shift` expects one offset per axis in (row, column) order, hence
    # the vertical offset comes first.
    moved = shift(image, [dy, dx], cval=new)
    return moved.reshape(784)
"ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.1" + }, + "toc": { + "toc_cell": false, + "toc_number_sections": true, + "toc_threshold": 6, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}