Files
ml_exercises/notebooks/7_mnist_keras.ipynb
2022-08-13 18:02:20 +02:00

640 lines
52 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analyze (F)MNIST with `tensorflow` / `keras`\n",
"\n",
"Careful: do **not** hit 'Kernel' > 'Restart & Run All', since some of the cells below take a long time to execute if you are not running the code on a GPU; we have therefore already executed them for you.\n",
"\n",
"In this notebook we compare different types of neural network architectures on the MNIST and Fashion MNIST datasets to see how the performance improves when using more complex architectures. Additionally, we compare the networks to a simple logistic regression classifier from `sklearn`, which should have approximately the same accuracy as a linear feed-forward neural network (FFNN), i.e., an FFNN with only one layer that maps the input directly to the output without any hidden layers and therefore has the same number of trainable parameters as the logistic regression model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:18:20.020015Z",
"start_time": "2019-06-28T18:17:33.894451Z"
}
},
"outputs": [],
"source": [
"import os\n",
"import gzip\n",
"import random\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score\n",
"# neural network libraries\n",
"import tensorflow as tf\n",
"# set random seeds before importing keras to get (at least more or less) reproducable results\n",
"random.seed(28)\n",
"np.random.seed(28)\n",
"tf.random.set_seed(28)\n",
"from tensorflow import keras\n",
"from tensorflow.keras.datasets import mnist, fashion_mnist\n",
"from tensorflow.keras import Sequential\n",
"from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D\n",
"from tensorflow.keras import backend as K"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def _read_idx_labels(path):\n",
"    \"\"\"Read a gzipped label file into a 1-D uint8 array (skips the 8-byte header).\"\"\"\n",
"    with gzip.open(path, 'rb') as f:\n",
"        return np.frombuffer(f.read(), np.uint8, offset=8)\n",
"\n",
"def _read_idx_images(path, n_images, rows=28, cols=28):\n",
"    \"\"\"Read a gzipped image file into a (n_images, rows, cols) uint8 array (skips the 16-byte header).\"\"\"\n",
"    with gzip.open(path, 'rb') as f:\n",
"        return np.frombuffer(f.read(), np.uint8, offset=16).reshape(n_images, rows, cols)\n",
"\n",
"def fashion_mnist_load_local_data():\n",
"    \"\"\"Load Fashion-MNIST from the local ../data/fashion-mnist folder.\n",
"\n",
"    Missing files are downloaded from the TensorFlow dataset bucket into that\n",
"    folder (instead of ~/.keras/datasets); existing files are reused.\n",
"\n",
"    Returns:\n",
"        (x_train, y_train), (x_test, y_test): uint8 arrays; images have shape\n",
"        (n, 28, 28), labels have shape (n,).\n",
"    \"\"\"\n",
"    base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'\n",
"    dirname = os.path.abspath(\"../data/fashion-mnist\")\n",
"    files = [\n",
"        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',\n",
"        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'\n",
"    ]\n",
"    # use the public keras API: the private tensorflow.python.keras module is\n",
"    # deprecated and may be missing in recent TensorFlow releases. An absolute\n",
"    # cache_subdir makes get_file store the files directly in that folder.\n",
"    paths = [keras.utils.get_file(fname, origin=base + fname, cache_subdir=dirname)\n",
"             for fname in files]\n",
"\n",
"    y_train = _read_idx_labels(paths[0])\n",
"    x_train = _read_idx_images(paths[1], len(y_train))\n",
"    y_test = _read_idx_labels(paths[2])\n",
"    x_test = _read_idx_images(paths[3], len(y_test))\n",
"    return (x_train, y_train), (x_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and look at the data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:18:20.049855Z",
"start_time": "2019-06-28T18:18:20.021568Z"
}
},
"outputs": [],
"source": [
"# (Fashion-)MNIST images are 28x28 pixels with a single channel\n",
"img_rows = img_cols = 28\n",
"# length of an image once it is flattened into a feature vector\n",
"n_features = img_rows * img_cols\n",
"\n",
"def load_data(use_fashion=False, reshape=False):\n",
"    \"\"\"Load (Fashion-)MNIST from the local ../data folder, scaled to [0, 1].\n",
"\n",
"    Args:\n",
"        use_fashion: load Fashion-MNIST instead of the original MNIST digits.\n",
"        reshape: return every image as a flat vector of n_features pixels\n",
"            (e.g. for sklearn models) instead of an image tensor with a\n",
"            channel axis.\n",
"\n",
"    Returns:\n",
"        x_train, x_test (float32 arrays), y_train, y_test (integer label arrays).\n",
"    \"\"\"\n",
"    # the data, split between train and test sets (load from local data folder, not ~/.keras)\n",
"    if use_fashion:\n",
"        (x_train, y_train), (x_test, y_test) = fashion_mnist_load_local_data() # fashion_mnist.load_data()\n",
"    else:\n",
"        # might need to use mnist.load_data(path=\"mnist.npz\") when executing on Google Colab\n",
"        (x_train, y_train), (x_test, y_test) = mnist.load_data(path=os.path.join(os.path.abspath(\"../data/\"), \"mnist.npz\"))\n",
"\n",
"    # add the single channel axis where the keras backend expects it\n",
"    if K.image_data_format() == 'channels_first':\n",
"        image_shape = (1, img_rows, img_cols)\n",
"    else:\n",
"        image_shape = (img_rows, img_cols, 1)\n",
"    x_train = x_train.reshape((x_train.shape[0],) + image_shape)\n",
"    x_test = x_test.reshape((x_test.shape[0],) + image_shape)\n",
"\n",
"    # normalize (data is ints from 0 to 255, get them to [0, 1]); stays float32\n",
"    x_train = x_train.astype('float32') / 255.\n",
"    x_test = x_test.astype('float32') / 255.\n",
"\n",
"    if reshape:\n",
"        # transform images into regular feature vectors\n",
"        x_train = x_train.reshape(x_train.shape[0], n_features)\n",
"        x_test = x_test.reshape(x_test.shape[0], n_features)\n",
"\n",
"    return x_train, x_test, y_train, y_test\n",
"\n",
"def convert_cat(y_train, y_test, num_classes=10):\n",
"    \"\"\"One-hot encode the train and test labels.\n",
"\n",
"    Returns a (y_train_cat, y_test_cat) tuple of binary class matrices\n",
"    with num_classes columns each.\n",
"    \"\"\"\n",
"    return tuple(keras.utils.to_categorical(labels, num_classes)\n",
"                 for labels in (y_train, y_test))\n",
"\n",
"def plot_images(x, n=10):\n",
"    \"\"\"Show the images x[1], ..., x[n] side by side in one row of grayscale plots.\n",
"\n",
"    Args:\n",
"        x: images, either as image tensors or as flat feature vectors of\n",
"            length n_features (each one is reshaped to img_rows x img_cols).\n",
"        n: number of images to display (default 10).\n",
"    \"\"\"\n",
"    # a single row of n panels, each 2x2 inches\n",
"    plt.figure(figsize=(2 * n, 2))\n",
"    # also makes gray the default colormap for all following plots\n",
"    plt.gray()\n",
"    for i in range(1, n + 1):\n",
"        ax = plt.subplot(1, n, i)\n",
"        plt.imshow(x[i].reshape(img_rows, img_cols))\n",
"        ax.get_xaxis().set_visible(False)\n",
"        ax.get_yaxis().set_visible(False)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:18:22.818959Z",
"start_time": "2019-06-28T18:18:20.051562Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABG0AAABwCAYAAACkaY2RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAaSUlEQVR4nO3df7zO9f3H8fcxOjFOhJVqYQ7jTCFskdFNRDKsRuR3a4xbp9UirQyFiWE7MSG/Ztwmm1/ZNKxEP3CjstvthIYVqfxackQdP3a+f3T7vnq93p3r+Jxzrh+f6zqP+1/Pz97v63O9c53P9eOz9+v9TisoKHAAAAAAAAAIl3KJHgAAAAAAAAC+jps2AAAAAAAAIcRNGwAAAAAAgBDipg0AAAAAAEAIcdMGAAAAAAAghLhpAwAAAAAAEELli9M5LS2N/cETpKCgIC0a5+E1TKgTBQUFNaNxIl7HxOFaTAlciymAazElcC2mAK7FlMC1mAK4FlNCodciM22A+DmY6AEAcM5xLQJhwbUIhAPXIhAOhV6L3LQBAAAAAAAIIW7aAAAAAAAAhBA3bQAAAAAAAEKImzYAAAAAAAAhxE0bAAAAAACAEOKmDQAAAAAAQAhx0wYAAAAAACCEuGkDAAAAAAAQQty0AQAAAAAACCFu2gAAAAAAAIQQN20AAAAAAABCqHyiB5AozZs3l/zAAw+YtgEDBkhevHix5BkzZph+b731VoxGBwAA8KWcnBzJDz74oOTc3FzTr2vXrpIPHjwY+4EBAErspZdekpyWlia5ffv2iRhOqGVlZZlj/Xk3ZMgQyTt27DD93n777Yjn/P3vfy/53LlzpR1iTDHTBgAAAAAAIIS4aQMAAAAAABBCZaY8qmnTpuZ448aNkjMyMkxbQUGB5P79+0vu1q2b6Ve9evVoDhEJcNttt0leunSpaWvXrp3kd999N25jQuFGjx4t+cknnzRt5cp9df/51ltvNW2bN2+O6biAVFGlShXJlStXNm133nmn5Jo1a0qePn266Zefnx+j0ZUtderUMcf9+vWT/L///U9yo0aNTL+GDRtKpjwq8Ro0aGCOK1SoILlt27aSZ82aZfrp17ik1qxZI7l3796mLexlAGGmX8PWrVtL/s1vfmP63XLLLXEbE5LH7373O3Os/4b0khz40tChQyVPnTrVtPnfU/5fvXr1zLH//qfpUqpNmzaVZIhxw0wbAAAAAACAEOKmDQAAAAAAQAildHnU97//fckrVqwwbVdccYVkXQ7lnHOnT5+WrKeQ+uVQN998s2R/J6lUm3qqp/Hqf4dVq1YlYjhR07JlS8n+auNIvEGDBkkeNWqU5KKmjvvXM4Cv6LIbfU0551yrVq0kN27cOND5atWqZY71zkYouePHj5vjLVu2SPZLtZF43/ve9yTrz62ePXuafrqU95prrpHsf6ZF43NM/53Mnj3btD300EOS8/LySv1cZYn+/aDLKY4cOWL6XX311RHbULY8/fTTkn/+85+btvPnz0vWO0nhS3/5y18kP/XUU6YtUnlUcaxcuVLyPffcI3nDhg2lPne0MdMGAAAAAAAghLhpAwAAAAAAEELctAEAAAAAAAihpF/TplKlSub4pptukrxkyRLJft19Ufbt2yd5ypQpkpctW2b6vf7665L1dsTOOTdp0qTAz5cM9DbK9evXl5xsa9roenLnnKtbt67k2rVrm7a0tLS4jAmR6dfk8ssvT+BIyqYf/OAHkvWWw+3atTP99HoOvhEjRkj+6KOPJLdp08b00+/X27dvL/5gIfS2z87Z9Sv69u0ruWLFiqaffs/74IMPTJte601vM92rVy/TT29dvHfv3uIMG8qZM2fMMdt3h5v+ztelS5cEjqRwAwYMMMfz58+XrL/LouT0Gjb+MWvalG16DVS9Zbxzzr322muSly9fHrcxJYtPPvlE8tixY03btGnTJOv7AYcOHTL9rr
/++ojnr1q1quTOnTtLZk0bAAAAAAAABMJNGwAAAAAAgBBK+vKoOXPmmOM+ffqU+py6xEpvJ7Z582bTT5cM3XjjjaV+3jDTU2u3bt2awJGUjl8m97Of/UyyLs9wjqn9idChQwdznJ2dXWg//7Xp2rWr5KNHj0Z/YGWE3u7QOedycnIk16hRQ7JfOvjKK69Irlmzpmn77W9/W+hz+efQj+vdu3ewAZdxeuvZyZMnS/ZfxypVqgQ6ny4N7tSpk2nTU7r19af/Lgo7RsnoKdvOOdekSZMEjQRBbNy4UXJR5VHHjh2TrEuU/NJtfwtwrXXr1pL9UlUkDiX1yaVt27aSn3jiCcn+70hdnhOUf47GjRtLPnDggGnTJeQo2uzZs82x3j5df0bm5eWV6PwzZ84s2cDihJk2AAAAAAAAIcRNGwAAAAAAgBDipg0AAAAAAEAIJeWaNs2bN5d85513mrZINaX+ejRr166VPHXqVNOmt6V9++23JZ88edL0a9++/SWfN1X49dbJat68eRHb9HoOiB+99fPChQtNm16zQ/PXSWE73OIpX/6rt/4WLVpIfu6550w/vYXili1bJI8fP97001tWpqenmza9heXtt98ecUw7d+681LDh+fGPfyz5/vvvL/bj/dr6jh07Sva3/M7MzCz2+VFy+tpzrugtS7WWLVtK9tf+4n0ydp599lnJq1evjtjv/Pnzkku6DXRGRobk3Nxcyddcc03Ex/hj4v02+goKCszx5ZdfnqCRIIi5c+dKrl+/vuSsrCzTT3+/Cerxxx83x9WrV5es19J0zrl//etfxT4/vjRhwgTJel2ipk2bluh8l112WanHFEup8UscAAAAAAAgxXDTBgAAAAAAIISSpjxKT3XSWyvqaaLO2emJL774omR/+zW9TeLo0aNNmy6hOX78uGR/CpvektEv09Lbhr/11lsu2fhbmF911VUJGkl0RSq3cc7+XSF+Bg4cKLmo6d16W+nFixfHckgpr1+/fpKLKhnU14TeRrqo7RT97aYjlUQdPnzYHP/xj3+MeE4UrmfPnoH6vf/++5J37NghedSoUaafXxKlNWrUqHiDQ6noMm3nnFu0aJHkcePGRXycbvv0009NW9i3M01mFy5ckFzUdRQNnTp1klytWrVAj/Hfb/Pz86M6JnydLj3etm1bAkeCwpw9e1ay/u1Y0rI2/Tu1du3apk3/XqRsLnr++te/StZlbBs2bDD9brjhhkDn0+VWP/nJT0o5uuhjpg0AAAAAAEAIcdMGAAAAAAAghEJbHtWgQQNzPHLkSMm6xOXEiROm38cffyxZT7f/7LPPTL+///3vheaSqlixojl+5JFHJPft27fU54+3Ll26mGP/vy+Z6NKuunXrRuz34YcfxmM4ZV6NGjXM8X333SdZTyF1zk7v19MWUTz+bk96ZwM9LXjWrFmmny4dLaokStMr+BflwQcfNMe6FBXB6F0ohgwZItmfGrx//37Jx44dK9FzpUqJbLLS13BR5VFIPb179zbH+roP+t1szJgxUR1TWaZL4U6dOiXZL7+vV69e3MaES/O/B+mSmT179kguzm5O3/zmNyXrcmN/9z9dHqdLelA6+vd1kyZNJDdu3LhE5yvJTmHxxEwbAAAAAACAEOKmDQAAAAAAQAhx0wYAAAAAACCEQrWmTXp6uuSpU6eaNr3GyunTpyUPGDDA9Nu5c6fkRK7Dcv311yfsuaPhu9/9bsS2d955J44jKT39t+Svy/Dvf/9bsv67QnTVqVNH8ooVKwI/bsaMGZI3bdoUzSGlPL2GgV7Dxjnnzp07J3n9+vWS/S2gP//880LP7W9Zqbf19t/70tLSJOt1idasWRNx7AhGbwsd63VOWrVqFdPzI7hy5b76/9v8dcCQnPy1Dx977DHJmZmZpq1ChQqBzrlr1y7J58+fL8XooOm19l599VXJXbt2TcRwUIRvf/vbkvVaUM7ZtYkeeOABycVZX2/69OmSe/bsKVl/Njvn3C
233BL4nLAaNmwoedWqVaZNvzeWL1/6WxovvPBCqc8RS8y0AQAAAAAACCFu2gAAAAAAAIRQqMqjmjVrJtnfclrr3r275M2bN8d0TPi6HTt2JHoIzjnnMjIyJHfu3Nm09evXT7Iu3fDpLQD1lFdEl359brzxxoj9XnrpJXOck5MTszGlmqpVq5rj4cOHS9bbejtnS6J69OgR6Px6GurSpUtNW/PmzSM+Tm9vOWXKlEDPhdjRW63r7UovRW+Pqr3xxhvmeOvWrSUbGALTJVH+tY3E0CXA/fv3l9yhQ4dAj2/Tpo05Dvq65uXlSdYlVc45t27dOsmRSl2BVKO3e9blNDVq1DD9dPl90N+SI0aMMMeDBg0qtN/EiRMDnQ+X1qhRI8l169Y1bdEoidIefvhhydnZ2VE9dzQw0wYAAAAAACCEuGkDAAAAAAAQQqEqj9KrcOsdR5yzU9fCUhJVVndwuPLKK0v0uCZNmkjWr68/ffi6666TfNlll0n2d1fQ//7+1N/t27dLzs/Pl+xPpXvzzTcDjR3Fp8tunn766Yj9XnvtNckDBw40badOnYr+wFKUvlac+/pUYE2XyHzrW9+SPHjwYNOvW7dukvWU48qVK5t+eiq/P61/yZIlks+cORNxTCidSpUqSc7KyjJtY8eOlVxU6XHQzzS9M4b/N3Px4sVLDxZIcvr90Dm760g8dw/VuxfNnTs3bs+LS6tevXqih5Cy9Hd5vRyCc87Nnz9fclGfaXpHxF/96leS9W9R5+xvHr1DlHP2t8zixYslz5kzp+j/AASmS9weffRR0zZ58mTJ/q6mJVGrVq1SnyOWmGkDAAAAAAAQQty0AQAAAAAACCFu2gAAAAAAAIRQQte06dq1qzlu2rSpZH9dBF0vHBZFbbu5a9eueA8nqvw1YvR/3+zZsyU//vjjgc+pt3rWdaAXLlww/c6ePSt59+7dkhcsWGD67dy5U7K/ztHRo0clHz58WHLFihVNv7179wYaOy5Nb3nqnHMrVqwI9Lj//Oc/kvXrhuI5d+6cOT5+/LjkmjVrmrb33ntPctCtZfU6JnqbWedsHfCJEydM29q1awOdH5dWoUIFc9ysWTPJ+nrz67L1+7l+Hf3tuTt37ixZr5Hj0+sJ3HXXXaYtJydHsv83CaQq/Z3GX5MxCL32hnPB10nU36PvuOMO0/biiy8WexyIHr0mHKKrd+/ekufNm2fa9HcafR3t37/f9GvRokWhuXv37qbftddeK9n/bNXfs+67775AY0fJPfPMM+Z43759kqtWrRrxcfo7y8yZM01bRkZGlEYXe8y0AQAAAAAACCFu2gAAAAAAAIRQQsuj/FIVvWXtsWPHTNvzzz8flzH50tPTJY8bNy5iv5dfftkc6+3jktHw4cPN8cGDByW3bt26ROc8dOiQ5NWrV0ves2eP6bdt27YSnV8bMmSIZF0aoktxEF2jRo0yx0Gndxe1HTiC+/TTT82x3nL9b3/7m2nTW1geOHBA8po1a0y/RYsWSf7kk08kL1u2zPTTU4b9NpSO/lzU5UvOObdy5cpCH/Pkk0+aY/359Prrr0vWfwd+P39LY02/p06aNMm0RXqfd865/Pz8iOdEcEG3Zm/btq059qeFo+Ryc3PN8a233ipZb0G8fv160++LL74o9nP99Kc/NcfZ2dnFPgdiY9OmTZL9JR8QPffcc485XrhwoeTz58+bNv1d6N5775V88uRJ02/atGmS27VrJ1mXSjlnyx39cvIaNWpI/uCDDyTr9wPn7PcsRE/QElD9GmZmZpq2MWPGSNbLtNSuXdv007+DE4WZNgAAAAAAACHETRsAAAAAAIAQ4qYNAAAAAABACCV0TZui+LXvH3/8cdyeW69jM3r0aMkjR440/fRW0ro20jnnPvvssxiNLjEmT56c6CEUy2233Vbo/x50G2oEo+s/b7/99kCP8ddNeffdd6M6Jnxp+/btkv0tv0tCr4+h67+ds+tqsG5U6fjbeuv1afzPIE
3Xds+YMcO06Rp//bewbt060++GG26Q7G/XPWXKFMl6vRt/e9SlS5dK/uc//2na9OeIv76AtmvXrohtsNebv8aC5m/HnpWVJXn37t3RH1gZptc7mDhxYlTP7a+nyJo24aHX8PLp9/Iwro+RTIYOHWqO9b/7hAkTTJte76Yo+jqaM2eO5FatWgUel14rRa9vxBo24aLXBtRr2Pj0+kgXL16M6ZhKgpk2AAAAAAAAIcRNGwAAAAAAgBAKbXnUCy+8ELfn0iUeztkp6HqbOb+s4+67747twBB1q1atSvQQUsqGDRskV6tWLWI/vY37oEGDYjkkxEjFihUl+9sM6xINtvwuvm984xuSx48fb9pGjBgh+cyZM6btsccek6z/3f3t3/UWpnrb52bNmpl++/btkzxs2DDTpqd+Z2RkSG7durXp17dvX8ndunUzbRs3bnSF0VulOudc3bp1C+2HL82ePVuyXzZQlCFDhkh+6KGHojomxE6nTp0SPQREcOHChYhtunRGL7uA4vN/f61cuVKy//kRlN6uW5f8+vr06SM5Nzc3Yj+9ZAbCxS+hi2T+/PmSw/h6MtMGAAAAAAAghLhpAwAAAAAAEEIJLY/SUwf94x49epi2X/ziF1F97ocffljyr3/9a9N2xRVXSNY7YQwYMCCqYwCSXfXq1SX7JTParFmzJKfazmplxfr16xM9hJSly1Z0OZRzzp09e1ayXwqjyxNvvvlmyYMHDzb97rjjDsm6zO2pp54y/fSuG0VNOc/Ly5P8j3/8w7TpYz2t3Dnn7r333kLPpz+PcWl79+5N9BDKBH8nN71D4ssvv2zaPv/886g+t76Gc3JyonpuRI8u2/Gvy4YNG0r2yxGHDx8e24GlmGhcA/q3nXPO9ezZU7Iu+fV3flq+fHmpnxv294Jz9vvGn//8Z9PmHxdXrVq1zLH+jlUUXXYXRsy0AQAAAAAACCFu2gAAAAAAAIQQN20AAAAAAABCKKFr2uhtYv3jq6++2rQ988wzkhcsWCD5v//9r+mn6/r79+8vuUmTJqbfddddJ/nQoUOmTa/doNfiQHLSayU1aNDAtOmtqBGMrkMtVy7Yfd833ngjVsNBnLDtbOyMGTMmYpveDnzkyJGmbdy4cZIzMzMDPZd+zKRJk0zbxYsXA50jqGjXqeNLM2bMkJydnW3a6tWrF/Fxem1AfQ5/DYeyrE2bNpKfeOIJ09axY0fJ/rb0Jdl2+Morr5TcpUsX0zZ9+nTJlSpVingOvZbOF198UewxIHr0GmPOOXfttddK/uUvfxnv4cDjryM0bNgwyceOHZPcvn37uI2pLNG/451z7kc/+pFk/7fZRx99JPnDDz+UvH//ftOvefPmhZ7j0UcfNf30mkW+adOmFfq8YcRMGwAAAAAAgBDipg0AAAAAAEAIJbQ8qih6Srhzdlrb3XffLVlvPeqcc/Xr1w90fl2usWnTJtNW1FR1JB9ddhe0nAdfadq0qTnu0KGDZL3N97lz50y/P/zhD5KPHj0ao9EhXr7zne8keggp68iRI5Jr1qxp2tLT0yX7Zb7aunXrJG/ZssW0rV69WvL7778vOdrlUIi/d955xxwXdZ3q92sUbubMmZIbN24csZ8//f706dPFfi5dbnXTTTeZNn/5AO2VV16R/Oyzz0r2v8sisfRr6H8/QnzUrl1b8v3332/a9Oszd+5cyYcPH479wMogXZLrnC0xbdWqlWnT73H6O8vu3btNvx/+8IeSq1SpEvG59Wu9d+9e0zZ27FjJYS8x5RcsAAAAAABACHHTBgAAAAAAIIS4aQMAAAAAABBCCV3TZuvWreZ4x44dklu2bBnxcXo78KuuuipiP70d+LJly0yb3voSZYdfN7lo0aLEDCSJVK1a1Rzr60/T2/I559yIESNiNibE36uvvirZXxuKtTJKp23btpJ79Ohh2vRaF3pbUuecW7BggeSTJ09KZv2EskOvxeCc3UYVsaO3C44Ffa2vXbvWtOnvr2Ffg6
Es09sMd+/e3bStWrUq3sMpkzZu3ChZr2/jnHNLliyRrNc1QWxs27bNHOt7AH/6059M26xZsyTXqVOn0Fwc+vtRVlZWic4RBsy0AQAAAAAACCFu2gAAAAAAAIRQQsuj/G3V7rrrLslDhw41baNHjw50zpycHMl6K8T9+/eXZIhIAWlpaYkeApD0cnNzJe/bt8+06W2G69WrZ9qOHz8e24GlAL1dsD9N2D8GNH8L1D179khu1KhRvIeT9AYNGiQ5OzvbtA0cOLDU5z9w4IDks2fPStblp87Zsjf93ovw6tWrlznOz8+XrK9LxM/ChQsljx8/3rStWbMm3sOB8sgjj0hOT083bZUrVy70Mc2aNTPHffr0KbTfqVOnzHHHjh1LMsTQYaYNAAAAAABACHHTBgAAAAAAIITSCgoKgndOSwveGVFVUFAQlRqfsvIa6inOeoeV5557zvTzy/Bi7M2CgoIW0ThRPF9Hf7eo559/XnKbNm0kv/fee6ZfZmZmbAeWIFyL9vpyzrl58+ZJ3rx5s2nTJQZ+KUcCJeW1CItrMSWE9lr0p+zr970JEyaYtmrVqklevXq1ZL17jXO2JOPIkSPRGGYocC1+fZdaXZ7YrVs303bw4MG4jKmYQnstIjiuxZRQ6LXITBsAAAAAAIAQ4qYNAAAAAABACHHTBgAAAAAAIIRY0yZJUKOYEqgXTgFci85lZGSY4+XLl0vu0KGDaVu5cqXkwYMHSz5z5kyMRhcI12IK4FpMCVyLKYBrMSVwLaYArsWUwJo2AAAAAAAAyYKbNgAAAAAAACFUPtEDAAAkl7y8PHPcq1cvyRMnTjRtw4YNkzxu3DjJIdr+GwAAAAgtZtoAAAAAAACEEDdtAAAAAAAAQoibNgAAAAAAACHElt9Jgi3cUgLbKaYArsWUwLWYArgWUwLXYgrgWkwJXIspgGsxJbDlNwAAAAAAQLLgpg0AAAAAAEAIFXfL7xPOuYOxGAiKVDuK5+I1TBxex+THa5gaeB2TH69hauB1TH68hqmB1zH58RqmhkJfx2KtaQMAAAAAAID4oDwKAAAAAAAghLhpAwAAAAAAEELctAEAAAAAAAghbtoAAAAAAACEEDdtAAAAAAAAQoibNgAAAAAAACHETRsAAAAAAIAQ4qYNAAAAAABACHHTBgAAAAAAIIT+D3Vp3aEqmZudAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 1440x288 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABG0AAABwCAYAAACkaY2RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2debhWVfXHF78GSjDmGWUeRARkSERRhlQccELEQLN4pNCeUDTDnNJITe1JTH3MocdZLHNKARUNcgQnCAgVBZlBRicyLeP3Rw/b7/py9+bcy3vvPfd9v59/XMe97zn7Pfvs4RzWd61a27dvNyGEEEIIIYQQQgiRL/6vuhsghBBCCCGEEEIIIXZGH22EEEIIIYQQQgghcog+2gghhBBCCCGEEELkEH20EUIIIYQQQgghhMgh+mgjhBBCCCGEEEIIkUP00UYIIYQQQgghhBAih3y1PJVr1aqVi/zg3/jGN4K99957u7ItW7YE+5///GewObU5Hn/zm990ZQ0aNAj2v/71r2C///77rt4XX3xRnmbvFtu3b69ViPNUZx9+9atfPm6NGjUK9ubNm129//znP7t9LexTfF4++OADV6+KU95v2r59e5NCnKgq+/HrX/+6O95zzz2DXb9+/WBzv2G/4ljE/jDz4+1b3/qWK/vvf/9b5vk2bdqUqe2VQTGMxarka1/7WrD//e9/V2NLHDVyLDI4p+K4bNLE/zQcm7im8fz3la98Jdh169Z1ZZ988kmw16xZEz1HVaKxWBQUxViMwePos88+C3bW+ZDX4Dp16gR769atu9G6wqGxWBQU9VgsFWrKWPy///vSb6Rly5auDOdN3Ptv3LixMpvk3kcaN27syj788MNgb9iwoVLbYZGxWK6PNoWgVq0vn6WKbvbatm0b7BtvvNGVPfjgg8GeN29esD///HNXDxfL7t27u7ITTjgh2EuXLg32tdde6+rxBwCRpmHDhsE+/fTTg3333Xe7euvXr9/ta3Xp0iXYXbt2DfZDDz3k6lXxS+SKqrxYoeDJdNCgQcE+7rjjgs0f3+69995gv/HGG8HG/jAzGzFiRLCHDh3qyvBjD57v1ltvzdJ0kQPwA8LatWursSWOah+LhVgLcU4dMmRIsM844wxXD9eqN998M9i8LuJH2AEDBriyOXPmBPvCCy8M9qeffpq5vYX4zaLoqPaxiOAzylTkme3Tp487xj3l6tWrM52D1+B+/foFG/e8QuwmuRqLorjBf1w/99xzXRnuP+66665g33zzzZXapu985zvB5n3UjBkzgj1lypRKbYdFxmKt8ixCWb+68aKX9Rq9evUK9imnnOLK8MUOPVzwXxzM/EOA3hzlYcmSJcHGf+nHDwFm3vPmqaeecmW/+c1vgr1o0aIKtQOpKV9OEf4XJuzTs88+O9j84oBeFFjG9fBflmvXru3KWrduHezHHnss2C+//LKrV8Ubnte3b9/etxAnKnQ/Hnnkke544sSJweaXMvxXP/xXe+wPM/8xtFmzZsFevny5q4deAOvWrXNl+GUb+7hVq1au3rPPPhvsCRMmWGVSE8ci3h8z/68J+LFt3Lhxrh73VQx8qZg1a5Yrwzl5xYov16Fhw4a5etu2bct0rQJR5WMx67rI/7qDcyVuKMz8mMD7x/MhfijlcYrgR2x+ocSxiX2K3q1mZs8991ywb7jhBldWaK+AmjgWxU7kal3Ef/3F/R+Dewwzs7Fjxwb7vPPOCzZ7jxYC3APj+jlp0iRX7/rrr890vqy/OYXGYlGQq7EoKkZex+Lvf/97d3zIIYcEG718zfz7dbdu3YLNXvarVq0KNr67f/TRR64e/gMX/4MUvtPgfM3/yIjvtHhdM7Mf/vCHwV62bJkVgDLHomLaCCGEEEIIIYQQQuQQfbQRQgghhBBCCCGEyCH6aCOEEEIIIYQQQgiRQyolpk0K1ItxANoePXoEGzW2ZmYff/xxsD
GOBgeSRa0vZi2pV6+eq4f6f9bwZr0nmAWHM1ChRu75558P9mmnnZbp3ExeNYrlYeTIkcHGOCkXXXSRq4fxMTAWCsdpwPgImNnEzGzmzJnBnjp1arA5zs6jjz6aqe0FIld64Q4dOgT7sssuc2WoJ91jjz1cWUz/ztmj9tprrzKvy+MNjzGGDZ8TxzrH0cAYNxwg/Kc//WmZ7agoNXEszp492x1j3+O44nkM510M4n3qqae6eqhHxvnZzPcHjvuePXtmaXplkauYNtgfjz/+uKuHY5HvLY4JXPswQ42ZHy84B3IGRPw7zliDAaUxaxXXw2MMJG7mNe2PPPKI7S41cSyKnaj2dTFrTBcMqN+pUydXhvtBfO45VhfWwz0Mr1stWrQINq/BeH6cs3l/g+P+mWeecWVjxoyxsuC9d9YYN6U0FnEuTz07qXeJWMDrigZoxzgdL730kivDeJwY96OM61X7WKwuCh2AvDzcc889wb7uuuuCjfONmd+r8RqP5GksDh48ONgXXHCBK8N4ihxrD8cVznGcGRPnRkxi8/rrr7t6fft++VhzBlt878D9VtOmTV09nE8xaYOZ3ytjMqPdQDFthBBCCCGEEEIIIWoK+mgjhBBCCCGEEEIIkUO+uusqheXhhx8Odps2bVzZhg0bgs1uhuiOjZIJdmnDeljGacI4vRjC7qEx0NWf3dbRnQ7TmmHqVTOzt956K9O1igF0m0dX4BtvvNHVw5TN6ALI8ig8B7vC3XHHHcFu165dsDdu3FjeZhctmJY0dV94PKBrIY5Flke99957wUb3Q3ZNxLHOfYyglAPHuZlPJY2pxs3Mjj766GBPmzYtev5iBt1QzfyYwDJMi2hm1rx582D/5Cc/CTZLm1Daymmdsa+4HaVEysX6qquuCja6+Jp5l1yU/PI5U+siyiZwTuV1C8dfnTp1XBlKsfBafA6cL1g69eMf/zjYKGFleasQlQmPj5gE6OWXX3bH++23X7B5nOLYwXHJYwDXMZxfURZu5iVQn3/+uStDuQDuQ9E28/PF6NGjXRmO7+OPPz7YfC/wXlW2TKSmU577U5F7OWjQIHeMzyPK9a688kpXD/vw8MMPd2UpmU1NISU9zloPba6XdQzgeOPQHbgvRam5mVnnzp2DjTIhHJe7unZewedt+fLlrgznTH5/wH0jvr9zPewbfK/HNOFmfp/CklWUNmGoBZZ3495mzZo1rgxDvxx00EHBfvHFF62QyNNGCCGEEEIIIYQQIofoo40QQgghhBBCCCFEDtFHGyGEEEIIIYQQQogcUiUxbfr06RNsjGPDcWZQw8YxZzAOBmrOUumIUVPIMTBQV8z6RdQlon4OdW9mZqtXry6zHoPXOuOMM1xZodMR5xmMW9C4ceNgYzwSM7Nzzz032K1btw42p3rDmCkcKwPPH4tzVOrceeedwZ44caIrwxg3mALPzGtuWbeLoA4f+4P56KOPgs2a/CznNjOrV69esFetWuXKSjWODbJs2TJ33L9//2Dj3MX69th4YW3ywIEDg81aX4y/wPN1KYMpfTG2Bae9x5gYvM7g/cQYFam0vbgeccpvXGc5pg3WxXbwOXCe53g3eM7hw4cHe+rUqSZEVZGKDYHpWg844ABXhnu+1L4RxxtfC49xT8nnwzHMZTjmcH7leDQ4TleuXOnKMNbEkUceGewZM2ZE21vsZI1dgmU8/8X43ve+547nzJkTbFw/MaajmdnatWuDjbHjzMzeeeedYGN66HPOOcfVmz9/fqY21lS4r7KmU4/FNuX1E98heI8ae+fEWKZmPp4r75sxtinGfWNS++28grG6cK9v5mPa8G/DvkmlOsf4NDgHp+ZMjD9j5vdRGMeG3/nx+eFnBMtwPCumjRBCCCGEEEIIIUQJoI82QgghhBBCCCGEEDmkSuRRgwcPDja6OXF6X3TtZLc1dImaNGlSsNF10My7r6Jb1rp161w9dG1iqQW2C1Ol9u7d29XDFLgpqRf+rpNOOs
nVKyV5VExClpLO4H3lFJvo0oaSOTPvChdL61fqvPLKK8Hm1KbHHntssOfOnevK8NnGPmCJGo4r7EeWTOA5WMaI7pQsj4ud44ILLojWK1UWL17sjmNuwZwKEfuQXbMRdBlmt1TsU3aPLWUaNGgQbJRHsbs9yqNYsoRzKq5bqbS9KYkoPheptMhYxu3FccrrIv6Www47LNiSR4nKBp/tlKQFZQz8/KI0+IMPPnBlMTl+SpKB+9CKpouO7XXM/DhF6YCZl2FOnz492CjbNPP7Ll6fU2EBSpGuXbu6Y7xfnK67b9++wca1AGXrZmbPPfdcsFECZeZDT/Tr1y/Y/E7TsWPHYL/77rux5hcNWcdSbB7g/5+SJeG6uNdeewWbZfkoG+b9F4aEQHl51lTmeQPnNZQisfQbj1GazeA44jkIwTmOxwCWpeRvWMZzZip8Az4HmMK90MjTRgghhBBCCCGEECKH6KONEEIIIYQQQgghRA6pEnkUSoLQnZJdxNAljV2l0I3qtttuCzZGwDfzEqY77rgj2D/60Y9cvUWLFgW7YcOGrgzbhZlzrrvuOlfvrLPOCja7bGH7MRo1u0+iG9WSJUusmIm5ArMrIt7/+vXrV+hasSwAKde6UuZ3v/udOz777LODzVknMLMUymnwOTfbOfL6Dnjc4zm4f9A9Ec+H2aLMfMYLSXB2hjM6obtvyh0UZaXoms19i+fn/sWxyO6xpQzKzfCeoVTKzPcPu/Wi1BClwkuXLnX1MNsXjjeWKmIZu4SjtAnbfswxx0TbxPM3yo1Z6iVEZZKSRD322GPBRtkTShrMfPZTlkehe3xKNsRjeHdJyb/xN6fWXXT7ZxnPAw88UOb5ipGs8hOUYw8YMCDYLOHHvcgf/vAHV4YZO3Hu5veMpk2bRtv39ttvBxulUig9NfNzcinIo3CMsVQ4RrNmzYKNcjUzs0aNGgUbZW38d7h/3bp1q6uHzwbvX19//fVMbawptGvXLtjYF5jpzszvB/l+4b3E+89zK8rCca/J3xCwjPc2sQzD/OzgMb/vIByuo5DI00YIIYQQQgghhBAih+ijjRBCCCGEEEIIIUQO0UcbIYQQQgghhBBCiBxSJQE+evbsGexVq1YFm7W9nAIcwbRhyJNPPumOUafbrVu3YHNq7UceeSTYw4cPd2Wob8M4DqgZNfPaOtbno/YXdXAcH+TAAw8MdrHHtMF4BtjXHFcBtdepNPCp1LWxOBCptHKlBj7nrBM9+OCDg33FFVdEz4G6Tj4H6ldRM89xa/D4s88+c2Ux/T///8cffzzaRuE182Ze05vS8OLYxLThHPsG+4Pj1sQ0x6UOxop4/vnngz1mzBhXr3v37sG+8sorXdlbb72V6VoYgwHHJWvMcR3juRLXVkzR/fOf/9zVe/XVV4ONen8zP1+0b98+U9uFqGxwH4ZgHCezdKp7JBVnJna+ipK6Vqq9OIfjWOeYHThP1ZSUwxUltvfk3417WVwjca428/GBOK7msGHDgv3UU09F27Rhw4ZoGca72bJlS7A5psbYsWOD/eKLL7oyjO9ZLMT6sUOHDq7elClTgo3x1zhm37777htsjg+IZbNnz47Ww7mE97mFiLW54zfnIe4UxuXD38r7S5yfVqxY4cqwDzG2GM+ZuGfBdxC+Fu55+X7jvgT/jvsJ4xLhnsrMbM899wz25s2bg92kSRNXD2OCVgR52gghhBBCCCGEEELkEH20EUIIIYQQQgghhMghlSKPYhdBdAdKpfxGtyd220Z3o9S10J2pRYsWwWaJRyr9F5bF3GbNvOSA3RFj8iiUiZiZDRw4MNh33XVX9FrFQCytGru7odSiIvXM/HOG9fiZK2VSaUkx1TOnD8Z0fugazC6l+NxjPZY2oesjuxLG+pFdKUWaTZs2ueO2bdsGGyU2LFXEcZVy4f3888/L/BszPxfyXFvKXHPNNcHGsTJr1ixXb968ecFmmTD2Hd53TnuP6yemKub+QBkA9yOmKUWXcJ
4fUN7FKZOxHex6XKrEJDIsycgq3UjJXmPwnJw1TS7Ckkm8dt5lNbgvQxlDSmrA/YZjCe9FKr0s3he+Vkzivau/Q7AdPN7wd6L0keWZHFqgmEmNKwSfF+ybIUOGuHr33ntvsMePH1+IJjowFTKuDa+99pqrh33PYSjwHLH3rJpGbJ/Ba9X3v//9YBfit+O7LsuLFy5cGOw//elPrgzfJWPzPJel3nmqm8aNGwcb3yU41Tm+/953332uDO8Jvsvz84tjEfudxy/Ok7hfNYvP1yxN7N+/f7C5b958881g41js0qWLqyd5lBBCCCGEEEIIIUQRoo82QgghhBBCCCGEEDlEH22EEEIIIYQQQgghckilxLSZNGmSO8b4NKhxZy0u1uPYCqjXw5SEqMc0M2vYsGGwUafGqUdRt8bXQq0vpoEbNWqUq9egQYNgc6wa1O7F9NJmO6dXLGZQ+4sp1jjOTCxWTdYUm4xiJ+werKfH1Hao62StKcbVwOeexxvrS5GYTjeVBlPsDKYqZLB/U6m8ER5vWWM4bN26ddeNLREwzevQoUODPWLECFfv8MMPDzbHPTvzzDODjWtVx44dXT1MUYt9x3MvjlMelzjWMVYDx7LC9Z/Pgf1/4oknBnvAgAGuHqavLXayxnvBtTD1N1ljG+Czc/HFF7syjtGXhZoUr6pnz57uGGMw4LrFcSnweeayWNw2jn2Ax6kYKrF6KXi+xj7hGBi4f8XflafYGFVN1rGIc95zzz1Xps1wnE58XrKmhed6GOsD50yek2fMmBHsli1burI2bdoEu1hi2mQFf28q9mXWuQ3j0eH6ZubXvkMPPdSVXX311cFOveekyna8C2PMuuoCY1Pi3mPw4MGuHs67/C6MY6lHjx7B5t+HcyP2IfcZ7m24f3Eux28IK1eudPXwvfWAAw6InmPVqlXB7tWrl6v3wgsv2O4gTxshhBBCCCGEEEKIHKKPNkIIIYQQQgghhBA5pFLkUS+99JI7bt68ebDRbZvTl9apUyfY77zzjitDt7A5c+YEO+V6in/D7lCx9NP8d+huxS6HS5YsCfYee+zhyvB6eA5MY2Zm9uijj1qpEJNacN9gH8bu467A/kV5VNOmTTOfo5RIpXxdvXq1K0NXRfw7lqGhKy/KZ9jFE90KWWaILsToSrlmzZoyfsX/4NTUpezuHSMmGUy5aWMZz7vYp9y/qVTUpcyvf/3rYKMrL68RmEpy+PDhruzSSy8t89zsGoz9jf3D/Y1jhedlHMPo8sySt1deeSXYLMtD93Fc40tJDpUiJYXIOo9997vfDfb+++/vykaOHBlsnGs3bdrk6k2dOrXM86Vg6ffPfvazYP/qV7/KdI6qgtcIfNbxvuOe1MyPHd434vjAMl5bY2U8p8Zk4tyO2N+YpcczluH5WrduXea5RdmkUjSn9qxYlpK9pEAZCoae4OcF24hzt1lp749i821KDpXaX959993BxrnWzPc3y5dROsd7YKRbt27Bvummm1zZjn16bE9Qldx+++3BnjlzZrBRkmlmNmHChGCPHTvWlXXt2jXY+B7Akmtcd3D8sdQf+5rPgbInDP/Qr18/V+/kk08O9sSJE10Zzpvjx48PdqHDc8jTRgghhBBCCCGEECKH6KONEEIIIYQQQgghRA6pFHnUzTffHD1G96hOnTq5epjNgKNro/v0okWLgs2RpNElit1BsxJzX+WsN5ghasGCBa5szJgxFbp2McGucNgfKTfw8sigdsBuqejCiP3G7s4ozeH+Ff9j+fLl7hj7B10Tub/x79CFlDO+obyCXXXRtRCvW8ouvYUgazYSHJs4Ztn9GuEyPMe2bduyNrHoefjhh4ON2aM4iwJm/vjLX/7iylDuiZkOUtImnPPY1RvhMYYuxOhezDJnzEZyzjnnRMsGDRoU7Hnz5rl68+fPj7arppMaHyl5IrrUo+s9Z97CbGNLly51ZSh1Rali27ZtXb2jjjoq2o4Yp5xyijvm7Bp5onfv3u
4Yxwf2Ae9F8LlnGQPKTlIZEVMyUyQmE2ewLFWPfwtKMlD6jzIbM9+Pc+fOjZ6/VMma7Yefl1hfpeYHBvezp59+erCfeOIJV+/+++8PNvcvzuulRtZsYUhqzOJ9Z8kvvi9++OGHrmzIkCHBxjka9wgM77dHjx5tZjtLXaubFStWBJszaiELFy50xwMHDgw23pPU+MgqS2UpL66F+H7CYxT79JJLLinjV1Q+8rQRQgghhBBCCCGEyCH6aCOEEEIIIYQQQgiRQ/TRRgghhBBCCCGEECKHVEpMmxQYvwJTg5r5+BWo8TPzujXUo3GMklT6PSQVnwH/rnbt2sFmnTLGBuA052LnVGd4nFVLmqqX0i8i+EywllRxbHYNa7Fj44r/P953HCtcD+cETOtt5tPvIZzOT5SPrHGjcIxljZfAYxZ1/RiDpdTB9J04xjhN9pw5c4J90EEHubLu3bsHG+97qq9w/HFfpdbF2NrK7cX4CRybZtmyZcFetWpVsJcsWRJtb57gcYP3AfclWWOaMPXr1w/2FVdc4cpGjRoVbIxDsW7dOlcP91U8T2Ick7feeivYnOZ58uTJ0TbiGMY2/fa3v3X1MGVrnz59XNnrr78ePX9VwM92LPV2KvVv6pwYDwr3kGZ+PsSYUuVJF43g88TXwv1Oaq+cai/Gpcqa/j3PpOIpViX4HKTm61TMHIxfgnHBOC7aLbfcEuwOHTq4slJ6d8kaL4jrVeSZwTgsZn4v27BhQ1eGsXDw/Bs2bHD1cD6aPXu2K+N1oDqJvZvxnIa/h2PaYOylVJwxXONwHkvNpzze8Py4tvK6mCI2hlPjtyLI00YIIYQQQgghhBAih+ijjRBCCCGEEEIIIUQOqRJ5FLpKoSsTuxCjixKm4DLzrkfobpRVPlMIN8iUCyOnHo/9HbtsVad7ZmXDv62iKdh399rs7it2JiUl5NS/GzduDDaOYZQ5MVjG4x5d9tkdtEmTJsHmVJWi4sRSdqfcglMp17Eep5HGupxauJRp3759sPGesUsuyo84PSveW0zbyy7EWC+2lu4KlFegWzOOUW4jyxvxt6EUqHnz5q4eyqiqm6yp7lOSKATTu5uZjRgxItg70raamW3evNnVW7x4cbCxPznlOqYsZWkr9g1KKFjihu04//zzXRmeE13aeZ1FSSw+m3kg1R4cO9yn+NynJFZI1noVBdvEc29W6RS2iWXt2I/FQB733Fnn4V69ernjv//978F+4IEHgn3MMce4ekcccUSwOd0xylSLnYr2fWp/HKNnz57ueMGCBcFu2bKlKzvllFOCjfP55Zdf7urhGjxz5sxyt6mqwPuMz3bqPm7bti1ahvMwz0exvU1K+s3twHNiO8ojj03JzguJPG2EEEIIIYQQQgghcog+2gghhBBCCCGEEELkkCqRR6GrUMrdaOnSpcFmeRS6fWbNzJBVHpXV5TmVsYbbi6DraaEjSeeZrBlMsroLVzS7Qur+x7JGlBqprCgscWjQoEGw0d2eI+IjmOVgjz32cGX16tULdmps4zht06ZNtB5Ld8TOxOY8fg6yyqgQHvc45iSP+hK815jFjucolHLw2MFxived+yBrNoeUCzHWRRd7vhaOdQbnCFzT2V08T/KomKt3igkTJrjj8ePHB7tZs2auDLOMoNyIr8V/t4OU5Do1r6PMlSVWCGeXOeGEE8qsd/HFF7vjs846K9grV650Zaeeemqw33333ei1K4sLL7zQHeO+NJVJCZ9ffs5Tc2KhwTGHayY/C9h+3r/ivIISZZbUHX/88cHOmn1H7JqsMtVJkyYFm/dYN998c7BPO+20YLO0cvr06cHmvVNWWWexk3pfxLWK+yr2zskyQ3xHzDpXXHTRRe4Yn5kHH3ww0znyBO8VcK7l+QnLcH7mLHhYhvMdvwek3gNx/sN+K09Wy6rKSidPGyGEEEIIIYQQQogcoo82QgghhBBCCCGEEDlEH22EEEIIIYQQQgghckiVxLRBUroy1NKyzj
KmVeMUhzFdWdZUtvx3qG/jeAJ4DsXR2BlOzYb3NdU3+Fxg35QnZXjsOWCtIcZmwLgSpUYqng/GPjAzW7RoUbAxXSSPD7yfGI+Bx/by5cvL/BszH+9m3bp1weYYGCJN586d3TE+99j3PJ8iOBZT8ymX4dzYuHHjjC0ufmL3k8fili1bgo3aa66L50tpqlPzIbaD48/hGozPCa+fmD6axzPO7Tifc9ys6qR3797u+LDDDgt2ly5dXBmucTgn1a1b19X74IMPgr1mzRpXhnMcni+1fmIsMY4FgH3IeyzsN3x2OI4J9tu3v/1tV7Z27dpg4+/E2DxmZu+8806weW0YN25csDFmR1XRvn17d4z7PHzOOabNihUrgs1jsapiGjB4XV5bsX9S6cBxLHI9XJ8Vw6ZwxGK9XXbZZa4e9g3vxU466aRg43jjPsS5qTxpjKub1HtaKi4Mzm2FiFWZNZ3zq6++GuxZs2a5Mky7niIVLw7nn1TsuJoI7+nxOeW1EMEYN6lnG/uQ10z8u9Q3itatWweb17uqimkmTxshhBBCCCGEEEKIHKKPNkIIIYQQQgghhBA5pMrlUSnXMnRfSqVVS6W0jJ0vJa1ht6aYm3kqBWpWd/RSIquEoqLp2CvaDiRr2vBSZuDAge4Y0/GiuyZLITDFIaaURTmAWVoW2aJFizLb1Lx5c3fctGnTYG/YsMGVKa272T777OOO0bUTXUPZbRTBObQ8YwqlByiTGzBggKvHqYVLCby3/Iy+//77wWZJRgzun5gEjvsqJYGLSZuYVArZmOtxeaSvlUGTJk1s1KhRZmZ24oknujK853xP8Lfi2EH5Ev8dS6ewb7Zt2xZslFSZxaVN7DqO12J5D95n/F18DvwtOI+bebnj1q1by/z/fP48yN9atWoVbJZrodQAy/hZTu0pY1JFHs9ZxyLC++GYhJzTDONay9IBXK9xfeZ+3GuvvaLtyhNZU2hX5nV5fkCpC88JXbt2Dfa1114bbJQ5mfn7f95557my2N65V69e7hjlgC+//HKZf1OZpFLFp8rQrso+ZVL7xoceeijYCxcuDPYPfvCD6N/wWI/NCTxPzZs3b9eNzTGpd70DDzzQHeN8lZKM4ZwXS91tlpZH4diMraUNeQIAABPeSURBVJFm/j2D5VEpWVUh0RurEEIIIYQQQgghRA7RRxshhBBCCCGEEEKIHKKPNkIIIYQQQgghhBA5pMpj2mQF9cdmXjuNmjPWyKU0+RUBz8eaYDx/dWvy80gh7klK+4qkdLHYDm5TKsVxsZPS3aOOulu3bq4MY9rUr18/2JzO+d133w02puVr166dq4exG1Bbn+KTTz5xx6NHjw72lClTXFmpxrFBhg4d6o5jccFS4yjL/zfbeYxh3aVLlwb7zDPPdPVKLaZN7B5yH+Dax1rsWMy1VNr1VCy2VL/GzsHXQh04x2WJpe5MpfSsCrZs2WL33HOPmfm0rWY+9lL37t1dWZs2bYKNcVsaNGjg6uE6w3p3vJdNmjQp0zaLx1NBvT9fKxUnBedQjKVj5mO5cIwTvB7GReF24Dk5vsC0adOi7aosODYbgn2Cv4Nj2uDvbdiwoSvD/WEqFmJF5tSscHsxVgO3A59XfGY4Nl1N2dvG4kik9o2FuOep2Fx4//mdBuPT/PWvfw12//79Xb2RI0eWu038u7BdHFunKuD2ZI1pmQJjAo0dO9aVYYwgTpOOxPbAvB7hmJg8ebIrwzgnI0aM2FWzd7pWqoyfJ9w/MTvuaZ5jqKZ+d8eOHd0xrjsY24f3QLi24DzG7+upa2N/4/jgNa1Lly7BfuONN1xZVd13edoIIYQQQgghhBBC5BB9tBFCCCGEEEIIIYTIIblK+Y2wSy6CLkvsEhlLK12elHOx1GDs4ovnSKXKzbO7WmWSShmbuv8xl+7ypFWPnYOvhSkxObVpsZNyFzziiCOCvXjxYleGroR4z9q2bevqrVmzJtjoysrXxdR5PXr0cGWY7rhRo0bBRs
mImXc9ZjdLlGmVKuxyja6jqZSlOK6ySgl57OHzgm7GnOJR7Bp2245JorgPYhKB8sypeIwyDL4WyqN47GEqWjxHIaTMu8uONixatMj9/7lz50b/BlNqo+yT5yCcG1u2bOnKsE9TfYh9jSmqWSq6efPmYLM8DY/R/vTTT129lIQC91+pfsM2svyqOvZE7C6P4N4uJf1DOTD3D54/JT3GMrRZ2pSStsUkSyk5F5ehvAvPl9p710Qq41mLyXtSqX4vu+wyd7x27dpg9+zZM9ijRo3a7fZxO1C6zs9BZVGrVq2wZ0i9Y/HzhvKjcePGBXv9+vXRa7Hk/rjjjgs2SlqY2PrJEkEMF3DyySe7sqOOOqrMc3O6aJxjU3M7Smt57njhhRfKvJZZfuVRqbkQ95QoMzPzfZA1TAaux/yc45jIOnfzmpH1WapM5GkjhBBCCCGEEEIIkUP00UYIIYQQQgghhBAih+Q2dQ5LkWLum+wmii5KsQxCZt51it3J0GULy1Iuw+g2K/4HS8ayZvYqRGR5JCbLMvPudOJLUKa0YMECVxbLXJK6l6kMFDhm2cUQXSTRRZWlbCmZluRRO98TlJelsgkhqax9KfDvMAtA8+bNXT18fnj+L0Y+/vjjYGN2tZQsgl2uY+tYylU35WqcyoiIf4duwyl568qVK11Z3759g419XN0Zar744osgF8K+MDNr0aJFsFPr1pYtW4I9e/ZsV4YSqJRMJ2tmTDwf3zuck1nSiH9Xt27dYHOmKszix+s4th/Pj2PbzD/f/JtXrFgRbJajVRZ/+9vfomWxsZPK9MWyjtjznNpfYr2U1J/LUhkxY+3lZwGP8bfkTV6Rldi+kffmzZo1CzaObbOdx22MrPfo8ssvDzY/L7jHOuGEEzKdLyVRxvNzPc7sWRVs3749OdfF6N27d7Cxr1Lz4YYNG1wZzmfDhw8P9uOPP55sb4z7778/2E8++aQri2V0YslpVvA3s6y0JmbYTK2ZuM6grNfM9yGuJZj1ziwubWJSYQBi8zWPow4dOkTPH5PaFXo+laeNEEIIIYQQQgghRA7RRxshhBBCCCGEEEKIHKKPNkIIIYQQQgghhBA5JLcxbbKmz0qlqEWypkDlc6S0aagh5VgDWdpU7LAeMBYvoTLuTyxtJWtsUxrIUoJjnqxbty7YnGYYU8zGdPFm8THB9XCsp+LiYEwp1P2a+fTiHJ+hVMHUkaxpx1Tq2L88FrOmNk3Fq8IYG08//XSwR44c6er16dMn2DVRu70r8D6Y+fuJ94/jNSGp+CKxc/O1se9S6yDP3/h3sdhx/HfLly93Zdh+PB//ruqE4wjwcQyc7/j34G/FWDJmfs5L3QdcM1OxVWJ/w2CcAEw/bOafC34OsI2pOBpYxvEA+XpVwdFHHx0tw9hQaPNagvNmKkU33gvey+J9wfuc2svyvY3FIePnB2PC8bMQW7tTaavzTGwf2a1bN3ecio2HcZlSMSxjtGrVyh0PGDAg2LyPGjhwYLnPz78x9p7E9fbee+9yX2t3qVu3bohPw9f/85//HGxOr92yZcsyz/fhhx+6Y4whxvFjcM6eMmVKsFMxbZDHHnvMHXfv3j3Yxx9/fKZzVBSMwVSeuDiptbw6SbULxyLHqsFnGNfI1D4qFWMzlkLczK/duC7y2or7rdReDOfnQs+nemMVQgghhBBCCCGEyCH6aCOEEEIIIYQQQgiRQ3Irj8oqW8kqramoPCqVDhddpzjdpdjZjQ1JpacttGQplqrWTP22A3ZfxT5h12zsV3T5ZTfAWHpKlO2YpV3s8fi9994LdqdOnVw9dFuvV6+eK2vYsGGw0aW22OnVq1eweb6LSWSyphnmsZ2Sy2D/dunSJdjc1/vss0+wi1EexfclJn9AqR+TSsOdkhTHZBj8XKTSHcckrVwP3ZyXLFniymKykby6dpcHdGVPubVv3bq1KpojiGHDhkXLcF+AqbvZZf/MM88M9r333uvKcE5EF3selyirwrGTGt
spWQzKAFiCg2shpzxv06ZNsHeku98VLEvGdbeQ7JgPyiOdj61jVbmW3Hrrre64c+fOwU7J87KSSgufqte1a9fdvnZ5qV27trVv397MzG655RZXNnny5GCj3N7My6OwjPfuKK1p3bq1K4uNq2uuucbVu/3224N99dVXB3vw4MGu3syZM4PNqakLDaahT0mlmZoYhgOfS0z/bebXSXxnYFkq7inQ5vAMKI/ic6AkDctYuofzK79nbNq0KdiVuZ+Rp40QQgghhBBCCCFEDtFHGyGEEEIIIYQQQogcoo82QgghhBBCCCGEEDmkymPaVFR3l0pdGTt/SleWOl/WtOEpPbJIp2aLpb0sFLGUa6yL7dixY7Dnz59f8HbUFPj5xfvHqS8xDhCmvWOdaCzOCae8xWcB4wmY+RSar732WrAPOeQQVw9TlHOsFNTDllJMm+HDhwcb9bZmfhxgP3H8BewrHKec7hC1vqzDxms1b9482JxOcb/99ivjVxQvsdhpqZg2vAbF0v1yPRzfWWPf8DqYNVU4ar3/8Y9/uDJsVypNvBCFJhZzxsysTp06wU6Nj0ceeSTYN9xwgysbPXp0sDEWTqNGjVw9THfOaWmRVJwwXGsbN24cbI5lMnfu3GBff/31ruzQQw8t81qp33/ssce649tuuy1ad3eoyHtC7G94bpk+fXqwOUX3VVddFeypU6dmuu6ll14abI6bhPd80aJFmc5XCFJ7oKpi8+bNduedd5qZ2bhx41zZvvvuG2xuGz7D69evDzaOUTMfh4T3NxzbaQfnn39+9Hjjxo3B5phkv/jFL8o8n5lfx1JjJyv4u7LGmirUtasajDfJfYb7RtxTcEwhfNZjeyozvz/iOEp4flwb+L0Ij3Eva7bzM1hZyNNGCCGEEEIIIYQQIofoo40QQgghhBBCCCFEDqlyeVTK/RphqUXW1MzoIoauTOyKn7UdKbLKo2piKrZCgKn7mFQq9Vgfpu4ju8LF0snyc1BVLm15B12szbwrObqNmpl179492ClZDJ4D7zunUcV6nGKvR48ewZ42bVqw2W0Uz8HutrHU48VOhw4dgs33HF07ceywfAzrodzqiSeecPXQnZjnapYi7IDdndFluhSIufKuXLky+jcsH8SxifeZ5zkkJXNKSZbwOJVmGPuVpV54DpyjS3WMiqoDxxvPh+WRIezgggsuSB7HwPGC7eDxlkr5jfvj8qQFjoHX5rGIczuuAWaVI4+qW7eu9e3b18x2fg/A34opgc3Mtm3bFmycJ3lPgce4RpqZnXfeecF+9tlng71hwwZX7/DDDw/2hAkTgs1p1bM+ExUlaygHvgdVzfLly91x//79g71q1SpXhjIWTDHP4wP7m2WGsTAZvL/h9XQHnMo+JW2ryPsdtxfHGEp1uB0Ir7vV3ccxUtLndu3aBZvHOv4d7imWLVvm6sUkpqkU4nwtnIcxVTg/H9gmDvMQq1do5GkjhBBCCCGEEEIIkUP00UYIIYQQQgghhBAih9QYn+RYNqCUe3fMNovLZ5hUNGpE2aN2ht310O0R7yvfO+yPrBI0zgoVy5bCLm0rVqyInrOUYHkUPuscrR3dN9GVGjM4mXnJEromolsrXysFRnxn12jsYz5/ixYtgv32229nulYxgBKmQYMGRevhvUPXUIYj7iMox2HXUwTHM88PCxcujP5dMZCSGyEpuQO7AuMxzoGYlcHM3/esmftSayu2kWVuKIvlPsY5AecOzjQoRKE544wzgj1ixAhXhpLO2F6zUOCYqE5Jw3vvvRfsJk2aBJulYijDePHFFyu9XbVr17a2bduamYX/7gDbyfIHnP9QBsNZdVCOc99997myBQsWBHvo0KHBHjBggKuHsm28JyivMvNrIc/dMWlOIeCMn08//XSlXSsLmJXLzGdaa926tSvDdQf3HCyzxnvLfYzvGlmzKuK7wZgxY8r4FWWfoyJZm1LrLo43luWl2lETwfmVxwPuRbGv+V0P9w64F+E9EM
53qf1Gav7nrKlZzlFoan6vCyGEEEIIIYQQQhQh+mgjhBBCCCGEEEIIkUP00UYIIYQQQgghhBAih1R5TJus6dHWrl3rjjt37hxs1OSznhCPUX+WqsdtQh1bKhVpKi5LrF4p8corr7hj7MP69esHG9PdMal03VnvK8Y0YY3ikiVLMp2j2OFYP6iJ5hTaCOpvOZYJjh3UonMKcdShYj0zH2sH03PyeEYNKZdxetdSAdOx3nrrra4MxxWmvU/ps1NleA6MeWTmNcjYFxyT4Prrr4+evxjgNQLHC85tKT30Qw895I7xHqL+ndetWApwrofPBevusf/xfB9++KGr99prr5V5Lf67rL9ZiEKAsVratGnjyjAuCc5fU6dO3e3r8rMdi7WY2s+kynBc8hyNY5jP8dRTTwUb4/3wejlt2rRgX3311dF2FIrNmzfbnXfeWe6/a9SoUbAxTgrHtsAynuPwucA4NnxPpk+fHuz7778/2Jy+GqnMGDYMx0qaOHFisCdPnlxl7dgBp8zG+z5s2DBX9stf/jLY/fr1CzbvFwrN888/H+xZs2ZV6rVSeyl87vg9GCmG90rcA6Xix+Dehu8djiv8Gz4fxrnCGGZmPnZS6l0CScUjq0ico6xopySEEEIIIYQQQgiRQ/TRRgghhBBCCCGEECKH5DblN8pnzLyEAl26U6mK0c6aqsssnmaaXR/RxQqlG0xWd6tig9MO3n333cEePHhwsLkPsa/x/sdc/M12dkHGPsRUb+z2yG0sVTp16uSO8Z6hBIrB+84uh+g++NJLLwUb0z2a+fH87LPPRs+PNs8PmOYb225W+a6uNYH99tvPHcfSa6dcuJs2bRota9asWbA5bTj2L7qZH3HEEa7eihUroucvBvi+oIt46tlGOHVqTQNdurP+ZiEKzcqVK90xpmPGOYrTESOc6h7XICQlza9sUvun+fPnBxslrCyVvummmyqpdYVl8+bNZdqlyPLly91xnvvwySefTB7vAMMrmJn16dMn2JiC3cysVatWwU7J+9esWRPs8ePHR+vhWl2I8ZvaZ11zzTXBfvvtt6P1OBxBTSQWMsPMz0lYxv2J6bvxXZLldPiOw3vZ/fffP9j4rsKySHwOquv+y9NGCCGEEEIIIYQQIofoo40QQgghhBBCCCFEDtFHGyGEEEIIIYQQQogcUuUxbVIpCJF58+a548WLFwcbUzemYtWgZh5TevG1Oe1fLKU4a9hQW8fprZFSimOD8H3FGCczZsyI/h2maGzevHmwUyn/1q9fHz1OpWbL+jwWO2eddZY7TqXj/eMf/xhsjOXEMUkwHgBqrFMpgRlOcbyDBx98MPM5RDrl5sEHHxzsbt26uXpDhgwJNqbFZVAzz3rhBx54INipcV/sYMpJM7MlS5YEe/Xq1cGeO3du9Bw8pyI1Yf667777gt2+fftgv/HGG9XRHFGi8Dg6//zzg43jdN26ddFzVGUK54qSmhMwje6nn34abN7nlur+tZi45JJLqrsJuw2ul3w8derUSr12odfW1PmeeeaZTOfglNZ5JTV/4LsAxzbF+Qnf4TZt2uTq4bsKxjJq0aKFq4d7DIyDY2bWtm3bYGPfcMzTXr16BZvfORGl/BZCCCGEEEIIIYQoMfTRRgghhBBCCCGEECKH1CqP21etWrU2mllx52XNJ222b9/epBAnUh9WK+rHmo/6sDhQP9Z81IfFgfqx5qM+LA7UjzUf9WFxUGY/luujjRBCCCGEEEIIIYSoGiSPEkIIIYQQQgghhMgh+mgjhBBCCCGEEEIIkUP00UYIIYQQQgghhBAih+ijjRBCCCGEEEIIIUQO0UcbIYQQQgghhBBCiByijzZCCCGEEEIIIYQQOUQfbYQQQgghhBBCCCFyiD7aCCGEEEIIIYQQQuQQfbQRQgghhBBCCCGEyCH/DyHQRX5I8r1UAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 1440x288 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# show example training images of both datasets\n",
"# MNIST\n",
"x_train, x_test, y_train, y_test = load_data()\n",
"plot_images(x_train)\n",
"# fashion MNIST (selected by passing True as the first argument of load_data)\n",
"x_train, x_test, y_train, y_test = load_data(True)\n",
"plot_images(x_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a \"regular\" linear model on the original MNIST dataset\n",
"\n",
"As you see below, the simple logistic regression classifier is already very good on this easy task, with a test accuracy of over 92.6%."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:20:11.767300Z",
"start_time": "2019-06-28T18:18:31.492075Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/franzi/opt/anaconda3/lib/python3.8/site-packages/sklearn/linear_model/_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"data": {
"text/plain": [
"0.9262"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load data and reshape images to regular feature vectors\n",
"x_train, x_test, y_train, y_test = load_data(reshape=True)\n",
"# train LogReg classifier\n",
"# NOTE(review): lbfgs stops at its default iteration limit here (see the ConvergenceWarning\n",
"# in the output); raising max_iter would remove the warning but likely also change the\n",
"# accuracy quoted in the text above\n",
"clf = LogisticRegression(class_weight='balanced', random_state=1, fit_intercept=True)\n",
"clf.fit(x_train, y_train)\n",
"# accuracy on test dataset\n",
"clf.score(x_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train different NN architectures on MNIST\n",
"\n",
"In the code below we define 3 different neural network architectures: a linear FFNN, a FFNN with multiple hidden layers, and a CNN, which is an architecture particularly well suited for image classification tasks.\n",
"\n",
"You will see that the more complex architectures use an additional operation between layers called `Dropout`. This is a regularization technique used for training neural networks, where a certain percentage of the values in the hidden layer representation of a data point are randomly set to zero during training. You can think of this as the network suffering from a temporary stroke, which forces the neurons to learn redundant representations (i.e., such that one neuron can take over for another neuron that was knocked out), which improves generalization.\n",
"\n",
"The linear FFNN has almost the same accuracy (90.5%) as the LogReg model (please note that the NNs were only trained for a single epoch!), the multi-layer FFNN (95.8%) is already better than the LogReg model, and the CNN beats them all (98.3%), which is expected since this architecture is designed for image classification tasks."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:20:11.798437Z",
"start_time": "2019-06-28T18:20:11.769464Z"
}
},
"outputs": [],
"source": [
"def _compile_and_fit(model, x_train, y_train_cat, epochs, batch_size):\n",
"    \"\"\"Compile `model` with the training setup shared by all networks and fit it.\n",
"\n",
"    Loss: categorical cross-entropy, optimizer: Adadelta with learning rate 1.0,\n",
"    tracked metric: accuracy. Returns the fitted model.\n",
"    \"\"\"\n",
"    model.compile(loss=keras.losses.categorical_crossentropy,\n",
"                  optimizer=keras.optimizers.Adadelta(learning_rate=1.),\n",
"                  metrics=['accuracy'])\n",
"    model.fit(x_train, y_train_cat, epochs=epochs, batch_size=batch_size)\n",
"    return model\n",
"\n",
"\n",
"def train_linnn(x_train, y_train_cat, num_classes=10, epochs=1, input_dim=784, batch_size=128):\n",
"    \"\"\"Train a linear FFNN: a single softmax layer mapping the inputs directly to the classes.\n",
"\n",
"    Without hidden layers, this network has the same number of trainable parameters\n",
"    as a (multinomial) logistic regression model.\n",
"\n",
"    Parameters:\n",
"        x_train: training inputs with `input_dim` features per sample (flattened images)\n",
"        y_train_cat: categorical (one-hot) labels with `num_classes` columns, as expected\n",
"            by the categorical cross-entropy loss\n",
"        num_classes: number of output classes\n",
"        epochs: number of passes over the training data\n",
"        input_dim: number of input features (784 = 28 * 28 pixels for (F)MNIST)\n",
"        batch_size: mini-batch size used by `model.fit`\n",
"\n",
"    Returns:\n",
"        the trained keras model\n",
"    \"\"\"\n",
"    model = Sequential()\n",
"    model.add(Dense(num_classes, activation='softmax', input_shape=(input_dim,)))\n",
"    return _compile_and_fit(model, x_train, y_train_cat, epochs, batch_size)\n",
"\n",
"\n",
"def train_ffnn(x_train, y_train_cat, num_classes=10, epochs=1, input_dim=784, batch_size=128):\n",
"    \"\"\"Train a FFNN with two hidden ReLU layers of 512 units, each followed by 20% dropout.\n",
"\n",
"    Parameters and return value are the same as for `train_linnn`.\n",
"    \"\"\"\n",
"    model = Sequential()\n",
"    model.add(Dense(512, activation='relu', input_shape=(input_dim,)))\n",
"    model.add(Dropout(0.2))\n",
"    model.add(Dense(512, activation='relu'))\n",
"    model.add(Dropout(0.2))\n",
"    model.add(Dense(num_classes, activation='softmax'))\n",
"    return _compile_and_fit(model, x_train, y_train_cat, epochs, batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:20:24.912378Z",
"start_time": "2019-06-28T18:20:11.799908Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"469/469 [==============================] - 0s 810us/step - loss: 0.6451 - accuracy: 0.8401\n",
"Test accuracy: 0.9046000242233276\n",
"0.9046\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2850 - accuracy: 0.9138\n",
"Test accuracy: 0.9577999711036682\n"
]
}
],
"source": [
"# load data and reshape images to flat feature vectors (the FFNNs expect 784 inputs)\n",
"x_train, x_test, y_train, y_test = load_data(reshape=True)\n",
"# convert the labels to the categorical format required by the categorical_crossentropy loss\n",
"y_train_cat, y_test_cat = convert_cat(y_train, y_test)\n",
"# train simple linear model\n",
"model = train_linnn(x_train, y_train_cat)\n",
"# evaluate returns [loss, accuracy] since 'accuracy' is the only compiled metric\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy:', score[1])\n",
"# equivalent evaluation:\n",
"# get predictions (as probabilities)\n",
"y_pred = model.predict(x_test)\n",
"# convert predictions to classes (index of the highest probability per sample)\n",
"y_pred_classes = np.argmax(y_pred, axis=1)\n",
"print(accuracy_score(y_test, y_pred_classes), \"\\n\")\n",
"# train multi-layer FFNN\n",
"model = train_ffnn(x_train, y_train_cat)\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy:', score[1])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:20:24.943000Z",
"start_time": "2019-06-28T18:20:24.913859Z"
}
},
"outputs": [],
"source": [
"# CNN classifier\n",
"def train_cnn(x_train, y_train_cat, input_shape=(28, 28, 1), num_classes=10, epochs=1):\n",
"    \"\"\"Build, compile and train a small convolutional neural network.\n",
"\n",
"    Architecture: two 3x3 ReLU convolutions (32 and 64 filters), 2x2 max pooling,\n",
"    25% dropout, a dense ReLU layer with 128 units, 50% dropout and a softmax output.\n",
"\n",
"    Parameters:\n",
"        x_train: training images, each of shape `input_shape`\n",
"        y_train_cat: categorical (one-hot) labels with `num_classes` columns\n",
"        input_shape: shape of a single image (height, width, channels)\n",
"        num_classes: number of output classes\n",
"        epochs: number of passes over the training data (batch size is fixed at 128)\n",
"\n",
"    Returns:\n",
"        the trained keras model\n",
"    \"\"\"\n",
"    stack = [\n",
"        # convolutional feature extractor\n",
"        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),\n",
"        Conv2D(64, (3, 3), activation='relu'),\n",
"        MaxPooling2D(pool_size=(2, 2)),\n",
"        Dropout(0.25),\n",
"        # fully connected classifier on top of the flattened feature maps\n",
"        Flatten(),\n",
"        Dense(128, activation='relu'),\n",
"        Dropout(0.5),\n",
"        Dense(num_classes, activation='softmax'),\n",
"    ]\n",
"    cnn = Sequential(stack)\n",
"    cnn.compile(\n",
"        loss=keras.losses.categorical_crossentropy,\n",
"        optimizer=keras.optimizers.Adadelta(learning_rate=1.),\n",
"        metrics=['accuracy'],\n",
"    )\n",
"    cnn.fit(x_train, y_train_cat, epochs=epochs, batch_size=128)\n",
"    return cnn"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:21:21.387931Z",
"start_time": "2019-06-28T18:20:24.944280Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"469/469 [==============================] - 32s 68ms/step - loss: 0.2628 - accuracy: 0.9199\n",
"Test accuracy: 0.9827\n"
]
}
],
"source": [
"# load the images without reshaping them: train_cnn's default input_shape is (28, 28, 1)\n",
"x_train, x_test, y_train, y_test = load_data()\n",
"y_train_cat, y_test_cat = convert_cat(y_train, y_test)\n",
"# train cnn\n",
"model = train_cnn(x_train, y_train_cat)\n",
"# score = [loss, accuracy]\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy:', score[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test all models on the Fashion MNIST dataset\n",
"\n",
"On the more difficult FMNIST task, the LogReg model has a much lower accuracy of 84.4% compared to the 92.6% achieved on the original MNIST dataset. When trained for only a single epoch, both the linear and the multi-layer FFNN have a lower accuracy than the LogReg model (80.5% and 81.9%, respectively), and only the CNN does a bit better (86.5%)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:25:38.101988Z",
"start_time": "2019-06-28T18:21:21.389691Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/franzi/opt/anaconda3/lib/python3.8/site-packages/sklearn/linear_model/_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy LogReg: 0.8438\n",
"469/469 [==============================] - 0s 781us/step - loss: 0.7530 - accuracy: 0.7527\n",
"Test accuracy Linear NN: 0.8045\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.5820 - accuracy: 0.7895\n",
"Test accuracy FFNN: 0.8194\n",
"469/469 [==============================] - 34s 73ms/step - loss: 0.5575 - accuracy: 0.8024\n",
"Test accuracy CNN: 0.8652\n"
]
}
],
"source": [
"# Fashion MNIST (load_data(True)): compare all models, training each NN for a single epoch\n",
"# load data and reshape images to regular feature vectors (input for LogReg and the FFNNs)\n",
"x_train, x_test, y_train, y_test = load_data(True, reshape=True)\n",
"y_train_cat, y_test_cat = convert_cat(y_train, y_test)\n",
"# train LogReg classifier\n",
"# NOTE(review): lbfgs hits its iteration limit here as well (see the ConvergenceWarning in the output)\n",
"clf = LogisticRegression(class_weight='balanced', random_state=1, fit_intercept=True)\n",
"clf.fit(x_train, y_train)\n",
"print('Test accuracy LogReg:', clf.score(x_test, y_test), \"\\n\")\n",
"# train simple linear model\n",
"model = train_linnn(x_train, y_train_cat)\n",
"# evaluate returns [loss, accuracy], so score[1] is the test accuracy\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy Linear NN:', score[1], \"\\n\")\n",
"# train multi-layer FFNN\n",
"model = train_ffnn(x_train, y_train_cat)\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy FFNN:', score[1], \"\\n\")\n",
"# load data again (not reshaped) for the CNN, whose default input_shape is (28, 28, 1)\n",
"x_train, x_test, y_train, y_test = load_data(True)\n",
"y_train_cat, y_test_cat = convert_cat(y_train, y_test)\n",
"# train cnn\n",
"model = train_cnn(x_train, y_train_cat)\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy CNN:', score[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, when trained for more epochs (15 instead of 1), the performance of all neural networks improves: the accuracy of the linear FFNN (84.3%) is now very close to that of the LogReg model (84.4%), while the multi-layer FFNN is better (89.3%) and the CNN can now solve the task quite well with an accuracy of 92.4%.\n",
"\n",
"(See how the loss decreases over time - observing how this metric develops can help you judge whether you've set your learning rate correctly.)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2019-06-28T18:42:43.369976Z",
"start_time": "2019-06-28T18:25:38.103589Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"469/469 [==============================] - 0s 785us/step - loss: 0.7222 - accuracy: 0.7677\n",
"Epoch 2/15\n",
"469/469 [==============================] - 0s 766us/step - loss: 0.5074 - accuracy: 0.8297\n",
"Epoch 3/15\n",
"469/469 [==============================] - 0s 759us/step - loss: 0.4694 - accuracy: 0.8403\n",
"Epoch 4/15\n",
"469/469 [==============================] - 0s 832us/step - loss: 0.4507 - accuracy: 0.8480\n",
"Epoch 5/15\n",
"469/469 [==============================] - 0s 776us/step - loss: 0.4388 - accuracy: 0.8495\n",
"Epoch 6/15\n",
"469/469 [==============================] - 0s 755us/step - loss: 0.4306 - accuracy: 0.8535\n",
"Epoch 7/15\n",
"469/469 [==============================] - 0s 772us/step - loss: 0.4239 - accuracy: 0.8544\n",
"Epoch 8/15\n",
"469/469 [==============================] - 0s 757us/step - loss: 0.4183 - accuracy: 0.8567\n",
"Epoch 9/15\n",
"469/469 [==============================] - 0s 767us/step - loss: 0.4139 - accuracy: 0.8577\n",
"Epoch 10/15\n",
"469/469 [==============================] - 0s 750us/step - loss: 0.4109 - accuracy: 0.8590\n",
"Epoch 11/15\n",
"469/469 [==============================] - 0s 765us/step - loss: 0.4079 - accuracy: 0.8602\n",
"Epoch 12/15\n",
"469/469 [==============================] - 0s 809us/step - loss: 0.4045 - accuracy: 0.8616\n",
"Epoch 13/15\n",
"469/469 [==============================] - 0s 778us/step - loss: 0.4021 - accuracy: 0.8621\n",
"Epoch 14/15\n",
"469/469 [==============================] - 0s 759us/step - loss: 0.4000 - accuracy: 0.8633\n",
"Epoch 15/15\n",
"469/469 [==============================] - 0s 766us/step - loss: 0.3988 - accuracy: 0.8627\n",
"Test accuracy Linear NN: 0.8432\n",
"Epoch 1/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.5830 - accuracy: 0.7891\n",
"Epoch 2/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.4076 - accuracy: 0.8494\n",
"Epoch 3/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.3657 - accuracy: 0.8644\n",
"Epoch 4/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.3385 - accuracy: 0.8757\n",
"Epoch 5/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.3202 - accuracy: 0.8808\n",
"Epoch 6/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.3069 - accuracy: 0.8859\n",
"Epoch 7/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2943 - accuracy: 0.8898\n",
"Epoch 8/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2826 - accuracy: 0.8946\n",
"Epoch 9/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2743 - accuracy: 0.8973\n",
"Epoch 10/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2646 - accuracy: 0.8999\n",
"Epoch 11/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2590 - accuracy: 0.9021\n",
"Epoch 12/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2510 - accuracy: 0.9057\n",
"Epoch 13/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2423 - accuracy: 0.9092\n",
"Epoch 14/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2381 - accuracy: 0.9096\n",
"Epoch 15/15\n",
"469/469 [==============================] - 2s 4ms/step - loss: 0.2322 - accuracy: 0.9119\n",
"Test accuracy FFNN: 0.8934\n",
"Epoch 1/15\n",
"469/469 [==============================] - 34s 73ms/step - loss: 0.5775 - accuracy: 0.7941\n",
"Epoch 2/15\n",
"469/469 [==============================] - 34s 72ms/step - loss: 0.3679 - accuracy: 0.8689\n",
"Epoch 3/15\n",
"469/469 [==============================] - 34s 73ms/step - loss: 0.3161 - accuracy: 0.8875\n",
"Epoch 4/15\n",
"469/469 [==============================] - 35s 74ms/step - loss: 0.2846 - accuracy: 0.8981\n",
"Epoch 5/15\n",
"469/469 [==============================] - 36s 78ms/step - loss: 0.2604 - accuracy: 0.9073\n",
"Epoch 6/15\n",
"469/469 [==============================] - 36s 76ms/step - loss: 0.2418 - accuracy: 0.9125\n",
"Epoch 7/15\n",
"469/469 [==============================] - 36s 76ms/step - loss: 0.2272 - accuracy: 0.9189\n",
"Epoch 8/15\n",
"469/469 [==============================] - 36s 77ms/step - loss: 0.2101 - accuracy: 0.9240\n",
"Epoch 9/15\n",
"469/469 [==============================] - 36s 76ms/step - loss: 0.2001 - accuracy: 0.9274\n",
"Epoch 10/15\n",
"469/469 [==============================] - 37s 79ms/step - loss: 0.1880 - accuracy: 0.9316\n",
"Epoch 11/15\n",
"469/469 [==============================] - 37s 78ms/step - loss: 0.1810 - accuracy: 0.9345\n",
"Epoch 12/15\n",
"469/469 [==============================] - 37s 78ms/step - loss: 0.1714 - accuracy: 0.9383\n",
"Epoch 13/15\n",
"469/469 [==============================] - 36s 77ms/step - loss: 0.1633 - accuracy: 0.9411\n",
"Epoch 14/15\n",
"469/469 [==============================] - 40s 84ms/step - loss: 0.1547 - accuracy: 0.9439\n",
"Epoch 15/15\n",
"469/469 [==============================] - 36s 76ms/step - loss: 0.1484 - accuracy: 0.9453\n",
"Test accuracy CNN: 0.9239\n"
]
}
],
"source": [
"# same Fashion MNIST comparison as above (without LogReg), but training each network for 15 epochs\n",
"x_train, x_test, y_train, y_test = load_data(True, reshape=True)\n",
"y_train_cat, y_test_cat = convert_cat(y_train, y_test)\n",
"# train simple linear model\n",
"model = train_linnn(x_train, y_train_cat, epochs=15)\n",
"# evaluate returns [loss, accuracy], so score[1] is the test accuracy\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy Linear NN:', score[1], \"\\n\")\n",
"# train multi-layer FFNN\n",
"model = train_ffnn(x_train, y_train_cat, epochs=15)\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy FFNN:', score[1], \"\\n\")\n",
"# load data again (not reshaped) for the CNN, whose default input_shape is (28, 28, 1)\n",
"x_train, x_test, y_train, y_test = load_data(True)\n",
"y_train_cat, y_test_cat = convert_cat(y_train, y_test)\n",
"# train cnn\n",
"model = train_cnn(x_train, y_train_cat, epochs=15)\n",
"score = model.evaluate(x_test, y_test_cat, verbose=0)\n",
"print('Test accuracy CNN:', score[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}