fix minor train/test copy-paste error

This commit is contained in:
franzi
2021-11-18 15:23:31 +01:00
parent c175421348
commit 0ca4957c2b

View File

@@ -81,7 +81,7 @@
" data_test = datasets.MNIST(\"../data\", train=False, transform=transforms.ToTensor())\n", " data_test = datasets.MNIST(\"../data\", train=False, transform=transforms.ToTensor())\n",
" # extract (n_samples x n_features) and (n_samples,) X and y numpy arrays from torch dataset\n", " # extract (n_samples x n_features) and (n_samples,) X and y numpy arrays from torch dataset\n",
" X_train, y_train = torch_to_X_y(data_train)\n", " X_train, y_train = torch_to_X_y(data_train)\n",
" X_test, y_test = torch_to_X_y(data_train)\n", " X_test, y_test = torch_to_X_y(data_test)\n",
" return X_train, X_test, y_train, y_test\n", " return X_train, X_test, y_train, y_test\n",
" \n", " \n",
"def plot_images(x):\n", "def plot_images(x):\n",
@@ -283,7 +283,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 13,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-11-22T19:06:31.188855Z", "end_time": "2020-11-22T19:06:31.188855Z",
@@ -385,14 +385,14 @@
"source": [ "source": [
"## Test on MNIST dataset\n", "## Test on MNIST dataset\n",
"\n", "\n",
"As you see below, the simple logistic regression classifier is already very good on this easy task, with a test accuracy of over 93.5%.\n", "As you see below, the simple logistic regression classifier is already very good on this easy task, with a test accuracy of over 92.6%.\n",
"\n", "\n",
"The linear FFNN has almost the same accuracy (90.5%) as the LogReg model (please note: the NNs were only trained for a single epoch!) and the multi-layer FFNN is already better than the LogReg model (96.4%), while the CNN beats them all (98.2%), which is expected since this architecture is designed for the image classification task." "The linear FFNN has almost the same accuracy (91.0%) as the LogReg model (please note: the NNs were only trained for a single epoch!) and the multi-layer FFNN is already better than the LogReg model (96.3%), while the CNN beats them all (98.2%), which is expected since this architecture is designed for the image classification task."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 14,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-11-22T19:07:16.662634Z", "end_time": "2020-11-22T19:07:16.662634Z",
@@ -411,36 +411,54 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/home/franzi/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "/home/franzi/miniconda3/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n", "\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n", "Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" " n_iter_i = _check_optimize_result(\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Test accuracy: 0.93535\n", "Test accuracy: 0.9263 \n",
"\n",
"### LinNN\n", "### LinNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.8914\u001b[0m \u001b[32m0.4044\u001b[0m \u001b[35m0.9052\u001b[0m \u001b[31m0.3264\u001b[0m 2.9920\n", " 1 \u001b[36m0.8914\u001b[0m \u001b[32m0.4044\u001b[0m \u001b[35m0.9048\u001b[0m \u001b[31m0.3267\u001b[0m 1.8230\n",
"Test accuracy: 0.9051833333333333\n", "Test accuracy: 0.9097 \n",
"\n",
"### FFNN\n", "### FFNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.9205\u001b[0m \u001b[32m0.2589\u001b[0m \u001b[35m0.9600\u001b[0m \u001b[31m0.1438\u001b[0m 5.8516\n", " 1 \u001b[36m0.9203\u001b[0m \u001b[32m0.2596\u001b[0m \u001b[35m0.9587\u001b[0m \u001b[31m0.1487\u001b[0m 3.6306\n",
"Test accuracy: 0.9642166666666667\n", "Test accuracy: 0.9625 \n",
"### CNN\n", "\n",
"### CNN\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/franzi/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.9311\u001b[0m \u001b[32m0.2313\u001b[0m \u001b[35m0.9797\u001b[0m \u001b[31m0.0744\u001b[0m 8.4549\n", " 1 \u001b[36m0.9251\u001b[0m \u001b[32m0.2474\u001b[0m \u001b[35m0.9772\u001b[0m \u001b[31m0.0786\u001b[0m 7.1728\n",
"Test accuracy: 0.9821833333333333\n" "Test accuracy: 0.9824 \n",
"\n"
] ]
} }
], ],
@@ -468,12 +486,12 @@
"source": [ "source": [
"## Test on FashionMNIST\n", "## Test on FashionMNIST\n",
"\n", "\n",
"On the more difficult FMNIST task, the LogReg model has a much lower accuracy of 86.6% compared to the 93.5% achieved on the original MNIST dataset. When trained for only a single epoch, both the linear and multi-layer FFNNs have a lower accuracy than the LogReg model (82.7 and 83.7% respectively) and only the CNN does a bit better (88.6%). " "On the more difficult FMNIST task, the LogReg model has a much lower accuracy of 84.4% compared to the 92.6% achieved on the original MNIST dataset. When trained for only a single epoch, both the linear and multi-layer FFNNs have a lower accuracy than the LogReg model (81.2 and 82.4% respectively) and only the CNN does a bit better (87.3%). "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 15,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-11-22T19:07:58.876986Z", "end_time": "2020-11-22T19:07:58.876986Z",
@@ -492,36 +510,40 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/home/franzi/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "/home/franzi/miniconda3/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n", "\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n", "Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" " n_iter_i = _check_optimize_result(\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Test accuracy: 0.8659833333333333\n", "Test accuracy: 0.8438 \n",
"\n",
"### LinNN\n", "### LinNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.8022\u001b[0m \u001b[32m0.5794\u001b[0m \u001b[35m0.8257\u001b[0m \u001b[31m0.5023\u001b[0m 2.7844\n", " 1 \u001b[36m0.8015\u001b[0m \u001b[32m0.5792\u001b[0m \u001b[35m0.8253\u001b[0m \u001b[31m0.5040\u001b[0m 1.9656\n",
"Test accuracy: 0.8270166666666666\n", "Test accuracy: 0.8117 \n",
"\n",
"### FFNN\n", "### FFNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.7816\u001b[0m \u001b[32m0.5942\u001b[0m \u001b[35m0.8366\u001b[0m \u001b[31m0.4465\u001b[0m 5.6541\n", " 1 \u001b[36m0.7827\u001b[0m \u001b[32m0.5929\u001b[0m \u001b[35m0.8400\u001b[0m \u001b[31m0.4453\u001b[0m 3.8960\n",
"Test accuracy: 0.8375666666666667\n", "Test accuracy: 0.8236 \n",
"\n",
"### CNN\n", "### CNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.8063\u001b[0m \u001b[32m0.5415\u001b[0m \u001b[35m0.8842\u001b[0m \u001b[31m0.3228\u001b[0m 8.7526\n", " 1 \u001b[36m0.8000\u001b[0m \u001b[32m0.5616\u001b[0m \u001b[35m0.8818\u001b[0m \u001b[31m0.3272\u001b[0m 7.0208\n",
"Test accuracy: 0.8861166666666667\n" "Test accuracy: 0.8728 \n",
"\n"
] ]
} }
], ],
@@ -544,14 +566,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"However, when trained for more epochs, the performance of all models improves, with the accuracy of the linear FFNN now being very close to that of the LogReg model (85.8%), while the multi-layer FFNN is better (89.3%) and the CNN can now solve the task quite well with an accuracy of 94.6%.\n", "However, when trained for more epochs, the performance of all models improves, with the accuracy of the linear FFNN now being very close to that of the LogReg model (84.4%), while the multi-layer FFNN is better (87.1%) and the CNN can now solve the task quite well with an accuracy of 91.7%.\n",
"\n", "\n",
"(See how the training and validation loss decrease over time - observing how these metrics develop can help you judge whether you've set your learning rate correctly and for how many epochs you should train the network.)" "(See how the training and validation loss decrease over time - observing how these metrics develop can help you judge whether you've set your learning rate correctly and for how many epochs you should train the network.)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 16,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2020-11-22T19:12:29.545463Z", "end_time": "2020-11-22T19:12:29.545463Z",
@@ -566,60 +588,63 @@
"### LinNN\n", "### LinNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.8015\u001b[0m \u001b[32m0.5794\u001b[0m \u001b[35m0.8258\u001b[0m \u001b[31m0.5047\u001b[0m 2.9007\n", " 1 \u001b[36m0.8002\u001b[0m \u001b[32m0.5800\u001b[0m \u001b[35m0.8263\u001b[0m \u001b[31m0.5020\u001b[0m 2.0403\n",
" 2 \u001b[36m0.8383\u001b[0m \u001b[32m0.4715\u001b[0m \u001b[35m0.8363\u001b[0m \u001b[31m0.4753\u001b[0m 3.1765\n", " 2 \u001b[36m0.8389\u001b[0m \u001b[32m0.4715\u001b[0m \u001b[35m0.8372\u001b[0m \u001b[31m0.4734\u001b[0m 2.1028\n",
" 3 \u001b[36m0.8461\u001b[0m \u001b[32m0.4520\u001b[0m \u001b[35m0.8401\u001b[0m \u001b[31m0.4635\u001b[0m 3.3053\n", " 3 \u001b[36m0.8458\u001b[0m \u001b[32m0.4520\u001b[0m \u001b[35m0.8409\u001b[0m \u001b[31m0.4620\u001b[0m 2.1199\n",
" 4 \u001b[36m0.8497\u001b[0m \u001b[32m0.4415\u001b[0m \u001b[35m0.8426\u001b[0m \u001b[31m0.4576\u001b[0m 2.8845\n", " 4 \u001b[36m0.8498\u001b[0m \u001b[32m0.4416\u001b[0m \u001b[35m0.8427\u001b[0m \u001b[31m0.4563\u001b[0m 2.2367\n",
" 5 \u001b[36m0.8523\u001b[0m \u001b[32m0.4347\u001b[0m \u001b[35m0.8449\u001b[0m \u001b[31m0.4544\u001b[0m 3.1842\n", " 5 \u001b[36m0.8519\u001b[0m \u001b[32m0.4347\u001b[0m \u001b[35m0.8448\u001b[0m \u001b[31m0.4533\u001b[0m 2.1635\n",
" 6 \u001b[36m0.8542\u001b[0m \u001b[32m0.4297\u001b[0m \u001b[35m0.8460\u001b[0m \u001b[31m0.4527\u001b[0m 2.9838\n", " 6 \u001b[36m0.8540\u001b[0m \u001b[32m0.4297\u001b[0m \u001b[35m0.8461\u001b[0m \u001b[31m0.4516\u001b[0m 2.1888\n",
" 7 \u001b[36m0.8555\u001b[0m \u001b[32m0.4258\u001b[0m 0.8456 \u001b[31m0.4518\u001b[0m 3.0838\n", " 7 \u001b[36m0.8553\u001b[0m \u001b[32m0.4258\u001b[0m \u001b[35m0.8465\u001b[0m \u001b[31m0.4509\u001b[0m 2.1393\n",
" 8 \u001b[36m0.8569\u001b[0m \u001b[32m0.4227\u001b[0m 0.8456 \u001b[31m0.4515\u001b[0m 3.0435\n", " 8 \u001b[36m0.8570\u001b[0m \u001b[32m0.4227\u001b[0m 0.8465 \u001b[31m0.4506\u001b[0m 2.1642\n",
" 9 \u001b[36m0.8580\u001b[0m \u001b[32m0.4201\u001b[0m 0.8459 0.4515 3.2644\n", " 9 \u001b[36m0.8578\u001b[0m \u001b[32m0.4201\u001b[0m \u001b[35m0.8468\u001b[0m 0.4507 2.2154\n",
" 10 \u001b[36m0.8592\u001b[0m \u001b[32m0.4179\u001b[0m \u001b[35m0.8462\u001b[0m 0.4517 3.0675\n", " 10 \u001b[36m0.8590\u001b[0m \u001b[32m0.4179\u001b[0m 0.8466 0.4510 2.3116\n",
" 11 \u001b[36m0.8602\u001b[0m \u001b[32m0.4159\u001b[0m \u001b[35m0.8464\u001b[0m 0.4521 3.1939\n", " 11 \u001b[36m0.8602\u001b[0m \u001b[32m0.4160\u001b[0m \u001b[35m0.8477\u001b[0m 0.4515 2.1901\n",
" 12 \u001b[36m0.8610\u001b[0m \u001b[32m0.4143\u001b[0m \u001b[35m0.8468\u001b[0m 0.4526 3.0209\n", " 12 \u001b[36m0.8608\u001b[0m \u001b[32m0.4143\u001b[0m 0.8474 0.4520 2.1735\n",
" 13 \u001b[36m0.8619\u001b[0m \u001b[32m0.4128\u001b[0m 0.8462 0.4532 3.0769\n", " 13 \u001b[36m0.8618\u001b[0m \u001b[32m0.4128\u001b[0m 0.8473 0.4526 2.2863\n",
" 14 \u001b[36m0.8627\u001b[0m \u001b[32m0.4115\u001b[0m 0.8462 0.4538 3.1085\n", " 14 \u001b[36m0.8626\u001b[0m \u001b[32m0.4115\u001b[0m 0.8472 0.4532 2.2163\n",
" 15 \u001b[36m0.8635\u001b[0m \u001b[32m0.4103\u001b[0m 0.8465 0.4544 2.8896\n", " 15 \u001b[36m0.8631\u001b[0m \u001b[32m0.4103\u001b[0m 0.8471 0.4539 2.0270\n",
"Test accuracy: 0.8585333333333334\n", "Test accuracy: 0.8363 \n",
"\n",
"### FFNN\n", "### FFNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.7825\u001b[0m \u001b[32m0.5919\u001b[0m \u001b[35m0.8488\u001b[0m \u001b[31m0.4411\u001b[0m 5.6527\n", " 1 \u001b[36m0.7826\u001b[0m \u001b[32m0.5931\u001b[0m \u001b[35m0.8327\u001b[0m \u001b[31m0.4488\u001b[0m 3.3172\n",
" 2 \u001b[36m0.8400\u001b[0m \u001b[32m0.4556\u001b[0m \u001b[35m0.8496\u001b[0m \u001b[31m0.4000\u001b[0m 6.2455\n", " 2 \u001b[36m0.8381\u001b[0m \u001b[32m0.4572\u001b[0m \u001b[35m0.8459\u001b[0m \u001b[31m0.4277\u001b[0m 3.8799\n",
" 3 \u001b[36m0.8499\u001b[0m \u001b[32m0.4279\u001b[0m \u001b[35m0.8550\u001b[0m 0.4068 5.3284\n", " 3 \u001b[36m0.8500\u001b[0m \u001b[32m0.4296\u001b[0m \u001b[35m0.8555\u001b[0m \u001b[31m0.4243\u001b[0m 3.7612\n",
" 4 \u001b[36m0.8554\u001b[0m \u001b[32m0.4144\u001b[0m \u001b[35m0.8609\u001b[0m 0.4078 5.9324\n", " 4 \u001b[36m0.8552\u001b[0m \u001b[32m0.4174\u001b[0m \u001b[35m0.8672\u001b[0m \u001b[31m0.3989\u001b[0m 3.5930\n",
" 5 \u001b[36m0.8595\u001b[0m \u001b[32m0.4088\u001b[0m \u001b[35m0.8679\u001b[0m \u001b[31m0.3969\u001b[0m 5.5903\n", " 5 \u001b[36m0.8592\u001b[0m \u001b[32m0.4071\u001b[0m \u001b[35m0.8672\u001b[0m 0.4075 4.0038\n",
" 6 \u001b[36m0.8605\u001b[0m \u001b[32m0.4040\u001b[0m \u001b[35m0.8687\u001b[0m 0.4187 5.7188\n", " 6 \u001b[36m0.8645\u001b[0m \u001b[32m0.3990\u001b[0m \u001b[35m0.8702\u001b[0m \u001b[31m0.3788\u001b[0m 3.7168\n",
" 7 \u001b[36m0.8645\u001b[0m \u001b[32m0.4004\u001b[0m 0.8608 0.4450 5.9384\n", " 7 \u001b[36m0.8661\u001b[0m \u001b[32m0.3925\u001b[0m 0.8685 0.4108 3.8341\n",
" 8 \u001b[36m0.8657\u001b[0m \u001b[32m0.3949\u001b[0m 0.8673 \u001b[31m0.3921\u001b[0m 5.8410\n", " 8 \u001b[36m0.8676\u001b[0m \u001b[32m0.3916\u001b[0m \u001b[35m0.8740\u001b[0m 0.3865 3.7558\n",
" 9 0.8649 \u001b[32m0.3934\u001b[0m \u001b[35m0.8748\u001b[0m 0.3986 6.0358\n", " 9 \u001b[36m0.8683\u001b[0m \u001b[32m0.3881\u001b[0m \u001b[35m0.8752\u001b[0m 0.4069 3.1817\n",
" 10 \u001b[36m0.8702\u001b[0m \u001b[32m0.3902\u001b[0m 0.8698 0.4123 5.6180\n", " 10 \u001b[36m0.8708\u001b[0m \u001b[32m0.3837\u001b[0m 0.8736 0.4105 3.4485\n",
" 11 \u001b[36m0.8709\u001b[0m \u001b[32m0.3887\u001b[0m \u001b[35m0.8762\u001b[0m 0.3928 5.8379\n", " 11 \u001b[36m0.8734\u001b[0m 0.3840 \u001b[35m0.8818\u001b[0m 0.3997 3.6171\n",
" 12 \u001b[36m0.8721\u001b[0m \u001b[32m0.3871\u001b[0m 0.8751 0.3933 5.9377\n", " 12 \u001b[36m0.8739\u001b[0m \u001b[32m0.3811\u001b[0m 0.8782 0.4131 3.9902\n",
" 13 \u001b[36m0.8734\u001b[0m \u001b[32m0.3826\u001b[0m \u001b[35m0.8778\u001b[0m 0.4058 5.4589\n", " 13 \u001b[36m0.8744\u001b[0m \u001b[32m0.3760\u001b[0m 0.8775 0.4049 3.5877\n",
" 14 0.8734 \u001b[32m0.3775\u001b[0m \u001b[35m0.8798\u001b[0m 0.3961 5.5648\n", " 14 \u001b[36m0.8776\u001b[0m \u001b[32m0.3731\u001b[0m 0.8742 0.4266 3.2744\n",
" 15 \u001b[36m0.8745\u001b[0m 0.3825 0.8788 0.3984 5.7210\n", " 15 0.8763 0.3737 0.8784 0.4210 3.9802\n",
"Test accuracy: 0.8931333333333333\n", "Test accuracy: 0.8713 \n",
"\n",
"### CNN\n", "### CNN\n",
" epoch train_acc train_loss valid_acc valid_loss dur\n", " epoch train_acc train_loss valid_acc valid_loss dur\n",
"------- ----------- ------------ ----------- ------------ ------\n", "------- ----------- ------------ ----------- ------------ ------\n",
" 1 \u001b[36m0.8104\u001b[0m \u001b[32m0.5317\u001b[0m \u001b[35m0.8889\u001b[0m \u001b[31m0.3057\u001b[0m 7.9180\n", " 1 \u001b[36m0.8079\u001b[0m \u001b[32m0.5324\u001b[0m \u001b[35m0.8864\u001b[0m \u001b[31m0.3138\u001b[0m 6.8582\n",
" 2 \u001b[36m0.8770\u001b[0m \u001b[32m0.3548\u001b[0m \u001b[35m0.8979\u001b[0m \u001b[31m0.2899\u001b[0m 8.4600\n", " 2 \u001b[36m0.8775\u001b[0m \u001b[32m0.3512\u001b[0m \u001b[35m0.9021\u001b[0m \u001b[31m0.2723\u001b[0m 7.1190\n",
" 3 \u001b[36m0.8930\u001b[0m \u001b[32m0.3103\u001b[0m \u001b[35m0.9120\u001b[0m \u001b[31m0.2476\u001b[0m 8.7328\n", " 3 \u001b[36m0.8934\u001b[0m \u001b[32m0.3089\u001b[0m \u001b[35m0.9119\u001b[0m \u001b[31m0.2423\u001b[0m 6.8647\n",
" 4 \u001b[36m0.9010\u001b[0m \u001b[32m0.2884\u001b[0m \u001b[35m0.9121\u001b[0m 0.2549 8.7084\n", " 4 \u001b[36m0.9007\u001b[0m \u001b[32m0.2865\u001b[0m 0.9047 0.2611 7.2493\n",
" 5 \u001b[36m0.9061\u001b[0m \u001b[32m0.2731\u001b[0m \u001b[35m0.9153\u001b[0m \u001b[31m0.2409\u001b[0m 8.2132\n", " 5 \u001b[36m0.9072\u001b[0m \u001b[32m0.2727\u001b[0m 0.9089 0.2543 7.2193\n",
" 6 \u001b[36m0.9121\u001b[0m \u001b[32m0.2610\u001b[0m 0.9117 0.2664 8.4597\n", " 6 \u001b[36m0.9120\u001b[0m \u001b[32m0.2592\u001b[0m \u001b[35m0.9185\u001b[0m \u001b[31m0.2338\u001b[0m 7.1084\n",
" 7 \u001b[36m0.9154\u001b[0m \u001b[32m0.2504\u001b[0m \u001b[35m0.9179\u001b[0m \u001b[31m0.2391\u001b[0m 8.5581\n", " 7 \u001b[36m0.9145\u001b[0m \u001b[32m0.2519\u001b[0m 0.9068 0.2671 7.1128\n",
" 8 \u001b[36m0.9185\u001b[0m \u001b[32m0.2428\u001b[0m 0.9143 0.2494 8.1528\n", " 8 \u001b[36m0.9181\u001b[0m \u001b[32m0.2409\u001b[0m \u001b[35m0.9202\u001b[0m 0.2636 7.1794\n",
" 9 \u001b[36m0.9208\u001b[0m \u001b[32m0.2348\u001b[0m \u001b[35m0.9184\u001b[0m 0.2415 8.8257\n", " 9 \u001b[36m0.9193\u001b[0m \u001b[32m0.2408\u001b[0m 0.9157 0.2507 7.1312\n",
" 10 \u001b[36m0.9234\u001b[0m \u001b[32m0.2317\u001b[0m 0.9182 0.2498 8.1160\n", " 10 \u001b[36m0.9218\u001b[0m \u001b[32m0.2326\u001b[0m 0.9188 0.2371 7.2552\n",
" 11 \u001b[36m0.9243\u001b[0m \u001b[32m0.2288\u001b[0m 0.9172 0.2472 7.9041\n", " 11 \u001b[36m0.9250\u001b[0m \u001b[32m0.2257\u001b[0m 0.9190 \u001b[31m0.2338\u001b[0m 7.1150\n",
" 12 0.9232 0.2306 0.9183 0.2617 8.7895\n", " 12 \u001b[36m0.9260\u001b[0m 0.2262 \u001b[35m0.9225\u001b[0m 0.2382 7.1878\n",
" 13 \u001b[36m0.9260\u001b[0m \u001b[32m0.2252\u001b[0m 0.9136 0.2523 8.5135\n", " 13 \u001b[36m0.9270\u001b[0m \u001b[32m0.2211\u001b[0m 0.9193 0.2566 7.1713\n",
" 14 \u001b[36m0.9290\u001b[0m \u001b[32m0.2194\u001b[0m 0.9146 0.2508 8.8612\n", " 14 \u001b[36m0.9275\u001b[0m \u001b[32m0.2185\u001b[0m \u001b[35m0.9236\u001b[0m 0.2352 6.8393\n",
" 15 0.9282 \u001b[32m0.2182\u001b[0m 0.9184 0.2503 7.9617\n", " 15 \u001b[36m0.9292\u001b[0m \u001b[32m0.2171\u001b[0m 0.9223 0.2606 6.9869\n",
"Test accuracy: 0.9464833333333333\n" "Test accuracy: 0.9173 \n",
"\n"
] ]
} }
], ],
@@ -642,7 +667,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -656,7 +681,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.5" "version": "3.9.5"
} }
}, },
"nbformat": 4, "nbformat": 4,