Large change: replace os.path with pathlib, move to Python 3.7
@@ -4,8 +4,13 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "**Chapter 3 – Classification**\n",
-    "\n",
+    "**Chapter 3 – Classification**"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "_This notebook contains all the sample code and solutions to the exercises in chapter 3._"
    ]
   },
@@ -34,7 +39,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20."
+    "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
    ]
   },
   {
@@ -43,24 +48,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Python ≥3.5 is required\n",
+    "# Python ≥3.7 is required\n",
     "import sys\n",
-    "assert sys.version_info >= (3, 5)\n",
+    "assert sys.version_info >= (3, 7)\n",
     "\n",
-    "# Is this notebook running on Colab or Kaggle?\n",
-    "IS_COLAB = \"google.colab\" in sys.modules\n",
-    "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
-    "\n",
-    "# Scikit-Learn ≥0.20 is required\n",
+    "# Scikit-Learn ≥1.0 is required\n",
     "import sklearn\n",
-    "assert sklearn.__version__ >= \"0.20\"\n",
+    "assert sklearn.__version__ >= \"1.0\"\n",
     "\n",
-    "# Common imports\n",
     "import numpy as np\n",
-    "import os\n",
     "\n",
-    "# to make this notebook's output stable across runs\n",
-    "np.random.seed(42)\n",
+    "from pathlib import Path\n",
     "\n",
     "# To plot pretty figures\n",
     "%matplotlib inline\n",
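Note that `assert sklearn.__version__ >= "1.0"` is a plain string comparison; it holds for the versions targeted here, but it is not version-aware in general (the string "0.9" compares greater than "0.20"). A hedged alternative using the third-party `packaging` library, which is not part of this commit:

    import sklearn
    from packaging.version import Version  # third-party: pip install packaging

    # Version-aware comparison: Version("0.9") < Version("0.20"),
    # whereas the strings compare the other way around
    assert Version(sklearn.__version__) >= Version("1.0")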
@@ -71,14 +69,11 @@
     "mpl.rc('ytick', labelsize=12)\n",
     "\n",
     "# Where to save the figures\n",
-    "PROJECT_ROOT_DIR = \".\"\n",
-    "CHAPTER_ID = \"classification\"\n",
-    "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
-    "os.makedirs(IMAGES_PATH, exist_ok=True)\n",
+    "IMAGES_PATH = Path() / \"images\" / \"classification\"\n",
+    "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
     "\n",
     "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
-    "    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
-    "    print(\"Saving figure\", fig_id)\n",
+    "    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
     "    if tight_layout:\n",
     "        plt.tight_layout()\n",
     "    plt.savefig(path, format=fig_extension, dpi=resolution)"
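The substitution above is the template for the whole commit: `os.path.join` becomes the `/` operator on a `Path`, and `os.makedirs(path, exist_ok=True)` becomes `path.mkdir(parents=True, exist_ok=True)`. A minimal sketch of the equivalences, with illustrative names that are not part of the notebook:

    import os.path
    from pathlib import Path

    # Old style: build the path with os.path.join, create it with os.makedirs
    old_images = os.path.join(".", "images", "classification")

    # New style: the / operator on Path objects, then Path.mkdir
    new_images = Path() / "images" / "classification"
    new_images.mkdir(parents=True, exist_ok=True)

    # Both spell the same location once normalized
    assert str(new_images) == os.path.normpath(old_images)

    # Filename building: string concatenation gives way to an f-string
    fig_id, fig_extension = "my_figure", "png"
    figure_path = new_images / f"{fig_id}.{fig_extension}"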
@@ -1466,49 +1461,36 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 100,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
+    "from pathlib import Path\n",
+    "import pandas as pd\n",
     "import urllib.request\n",
     "\n",
-    "TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")\n",
-    "DOWNLOAD_URL = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/datasets/titanic/\"\n",
-    "\n",
-    "def fetch_titanic_data(url=DOWNLOAD_URL, path=TITANIC_PATH):\n",
-    "    if not os.path.isdir(path):\n",
-    "        os.makedirs(path)\n",
-    "    for filename in (\"train.csv\", \"test.csv\"):\n",
-    "        filepath = os.path.join(path, filename)\n",
-    "        if not os.path.isfile(filepath):\n",
-    "            print(\"Downloading\", filename)\n",
-    "            urllib.request.urlretrieve(url + filename, filepath)\n",
-    "\n",
-    "fetch_titanic_data() "
+    "def load_titanic_data():\n",
+    "    titanic_path = Path() / \"datasets\" / \"titanic\"\n",
+    "    titanic_path.mkdir(parents=True, exist_ok=True)\n",
+    "    filenames = (\"train.csv\", \"test.csv\")\n",
+    "    for filename in filenames:\n",
+    "        filepath = titanic_path / filename\n",
+    "        if filepath.is_file():\n",
+    "            continue\n",
+    "        root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
+    "        url = root + \"/datasets/titanic/\" + filename\n",
+    "        print(\"Downloading\", filename)\n",
+    "        urllib.request.urlretrieve(url, filepath)\n",
+    "    return [pd.read_csv(titanic_path / filename) for filename in filenames]"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 101,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
-    "import pandas as pd\n",
-    "\n",
-    "def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n",
-    "    csv_path = os.path.join(titanic_path, filename)\n",
-    "    return pd.read_csv(csv_path)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 102,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "train_data = load_titanic_data(\"train.csv\")\n",
-    "test_data = load_titanic_data(\"test.csv\")"
+    "train_data, test_data = load_titanic_data()"
    ]
   },
   {
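A quick way to exercise the new loader (a sketch; it assumes the cell above has been run, network access to the handson-ml2 raw URLs, and the usual Kaggle Titanic splits):

    train_data, test_data = load_titanic_data()
    print(train_data.shape, test_data.shape)  # expected: (891, 12) (418, 11)
    print(train_data.columns[:3].tolist())    # e.g. ['PassengerId', 'Survived', 'Pclass']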
@@ -1966,38 +1948,40 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 125,
+   "execution_count": 38,
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
+    "from pathlib import Path\n",
     "import tarfile\n",
     "import urllib.request\n",
     "\n",
-    "DOWNLOAD_ROOT = \"http://spamassassin.apache.org/old/publiccorpus/\"\n",
-    "HAM_URL = DOWNLOAD_ROOT + \"20030228_easy_ham.tar.bz2\"\n",
-    "SPAM_URL = DOWNLOAD_ROOT + \"20030228_spam.tar.bz2\"\n",
-    "SPAM_PATH = os.path.join(\"datasets\", \"spam\")\n",
+    "def fetch_spam_data():\n",
+    "    root = \"http://spamassassin.apache.org/old/publiccorpus/\"\n",
+    "    ham_url = root + \"20030228_easy_ham.tar.bz2\"\n",
+    "    spam_url = root + \"20030228_spam.tar.bz2\"\n",
     "\n",
-    "def fetch_spam_data(ham_url=HAM_URL, spam_url=SPAM_URL, spam_path=SPAM_PATH):\n",
-    "    if not os.path.isdir(spam_path):\n",
-    "        os.makedirs(spam_path)\n",
-    "    for filename, url in ((\"ham.tar.bz2\", ham_url), (\"spam.tar.bz2\", spam_url)):\n",
-    "        path = os.path.join(spam_path, filename)\n",
-    "        if not os.path.isfile(path):\n",
+    "    spam_path = Path() / \"datasets\" / \"spam\"\n",
+    "    spam_path.mkdir(parents=True, exist_ok=True)\n",
+    "    for dir_name, tar_name, url in ((\"easy_ham\", \"ham\", ham_url),\n",
+    "                                    (\"spam\", \"spam\", spam_url)):\n",
+    "        if not (spam_path / dir_name).is_dir():\n",
+    "            path = (spam_path / tar_name).with_suffix(\".tar.bz2\")\n",
+    "            print(\"Downloading\", path)\n",
     "            urllib.request.urlretrieve(url, path)\n",
-    "        tar_bz2_file = tarfile.open(path)\n",
-    "        tar_bz2_file.extractall(path=spam_path)\n",
-    "        tar_bz2_file.close()"
+    "            tar_bz2_file = tarfile.open(path)\n",
+    "            tar_bz2_file.extractall(path=spam_path)\n",
+    "            tar_bz2_file.close()\n",
+    "    return [spam_path / dir_name for dir_name in (\"easy_ham\", \"spam\")]"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 126,
+   "execution_count": 39,
    "metadata": {},
    "outputs": [],
    "source": [
-    "fetch_spam_data()"
+    "ham_dir, spam_dir = fetch_spam_data()"
    ]
   },
   {
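One pathlib subtlety worth noting in the new `fetch_spam_data`: `Path.with_suffix` appends a suffix when the name has none, which is what turns the bare tar name into the archive filename. A self-contained sketch:

    from pathlib import Path

    spam_path = Path("datasets") / "spam"

    # "ham" has no suffix, so .with_suffix(".tar.bz2") appends it
    assert (spam_path / "ham").with_suffix(".tar.bz2") == spam_path / "ham.tar.bz2"

    # When a suffix already exists, only the last one is replaced
    assert Path("ham.tar").with_suffix(".bz2") == Path("ham.bz2")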
@@ -2009,19 +1993,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 127,
+   "execution_count": 40,
    "metadata": {},
    "outputs": [],
    "source": [
-    "HAM_DIR = os.path.join(SPAM_PATH, \"easy_ham\")\n",
-    "SPAM_DIR = os.path.join(SPAM_PATH, \"spam\")\n",
-    "ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]\n",
-    "spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]"
+    "ham_filenames = [f for f in sorted(ham_dir.iterdir()) if len(f.name) > 20]\n",
+    "spam_filenames = [f for f in sorted(spam_dir.iterdir()) if len(f.name) > 20]"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 128,
+   "execution_count": 41,
    "metadata": {},
    "outputs": [],
    "source": [
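Unlike `os.listdir`, `Path.iterdir()` yields `Path` objects rather than bare name strings, so the length filter moves from `name` to `f.name` and the results can be fed straight to `load_email` further down. A sketch in a throwaway directory (the filenames are made up, in the corpus's hash-named style):

    import tempfile
    from pathlib import Path

    tmp = Path(tempfile.mkdtemp())
    (tmp / "00001.7c53336b37003a9286aba55d2945844c").touch()  # hash-style name
    (tmp / "cmds").touch()                                    # short housekeeping file

    # iterdir() yields full Path objects; .name recovers the bare filename
    kept = [f for f in sorted(tmp.iterdir()) if len(f.name) > 20]
    print([f.name for f in kept])  # only the long hash-named file survives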
@@ -2030,7 +2012,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 129,
+   "execution_count": 42,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2046,27 +2028,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 130,
+   "execution_count": 45,
    "metadata": {},
    "outputs": [],
    "source": [
     "import email\n",
     "import email.policy\n",
     "\n",
-    "def load_email(is_spam, filename, spam_path=SPAM_PATH):\n",
-    "    directory = \"spam\" if is_spam else \"easy_ham\"\n",
-    "    with open(os.path.join(spam_path, directory, filename), \"rb\") as f:\n",
+    "def load_email(filepath):\n",
+    "    with open(filepath, \"rb\") as f:\n",
     "        return email.parser.BytesParser(policy=email.policy.default).parse(f)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 131,
+   "execution_count": 48,
    "metadata": {},
    "outputs": [],
    "source": [
-    "ham_emails = [load_email(is_spam=False, filename=name) for name in ham_filenames]\n",
-    "spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]"
+    "ham_emails = [load_email(filepath) for filepath in ham_filenames]\n",
+    "spam_emails = [load_email(filepath) for filepath in spam_filenames]"
    ]
   },
   {
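`email.parser.BytesParser` with `policy=email.policy.default` returns modern `EmailMessage` objects. A minimal sketch on an in-memory message, so no corpus files are needed (`parsebytes` is the bytes-level sibling of the `parse(f)` call above):

    import email.parser
    import email.policy

    raw = b"Subject: test\nFrom: a@example.com\n\nHello!\n"
    msg = email.parser.BytesParser(policy=email.policy.default).parsebytes(raw)
    print(msg["Subject"])             # -> test
    print(msg.get_content().strip())  # -> Hello!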
@@ -2078,7 +2059,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 132,
+   "execution_count": 49,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2087,7 +2068,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 133,
+   "execution_count": 50,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2103,7 +2084,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 134,
+   "execution_count": 51,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2122,7 +2103,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 135,
+   "execution_count": 52,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2138,7 +2119,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 136,
+   "execution_count": 53,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2147,7 +2128,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 137,
+   "execution_count": 54,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2170,7 +2151,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 138,
+   "execution_count": 55,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2187,7 +2168,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 139,
+   "execution_count": 56,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2203,7 +2184,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 140,
+   "execution_count": 57,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2225,7 +2206,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 141,
+   "execution_count": 58,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2249,7 +2230,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 142,
+   "execution_count": 59,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2268,7 +2249,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 143,
+   "execution_count": 60,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2284,7 +2265,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 144,
+   "execution_count": 61,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2308,7 +2289,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 145,
+   "execution_count": 62,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2326,7 +2307,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 146,
+   "execution_count": 63,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2352,10 +2333,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 147,
+   "execution_count": 67,
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Is this notebook running on Colab or Kaggle?\n",
+    "IS_COLAB = \"google.colab\" in sys.modules\n",
+    "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
+    "\n",
     "# if running this notebook on Colab or Kaggle, we just pip install urlextract\n",
     "if IS_COLAB or IS_KAGGLE:\n",
     "    %pip install -q -U urlextract"
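The probe works because Colab and Kaggle kernels pre-import their own helper packages, so a `sys.modules` membership test is a cheap environment check that imports nothing itself. For instance:

    import sys

    IS_COLAB = "google.colab" in sys.modules
    IS_KAGGLE = "kaggle_secrets" in sys.modules
    print(IS_COLAB, IS_KAGGLE)  # -> False False on a local machine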
@@ -2370,7 +2355,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 148,
+   "execution_count": 68,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2393,7 +2378,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 149,
+   "execution_count": 69,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2445,7 +2430,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 150,
+   "execution_count": 70,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2470,7 +2455,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 151,
+   "execution_count": 71,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2501,7 +2486,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 152,
+   "execution_count": 72,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2512,7 +2497,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 153,
+   "execution_count": 73,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2528,7 +2513,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 154,
+   "execution_count": 74,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2544,7 +2529,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 155,
+   "execution_count": 75,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2567,14 +2552,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 156,
+   "execution_count": 76,
    "metadata": {},
    "outputs": [],
    "source": [
     "from sklearn.linear_model import LogisticRegression\n",
     "from sklearn.model_selection import cross_val_score\n",
     "\n",
-    "log_clf = LogisticRegression(solver=\"lbfgs\", max_iter=1000, random_state=42)\n",
+    "log_clf = LogisticRegression(max_iter=1000, random_state=42)\n",
     "score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)\n",
     "score.mean()"
    ]
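Dropping `solver="lbfgs"` is safe here because "lbfgs" has been LogisticRegression's default solver since Scikit-Learn 0.22, and the notebook now asserts Scikit-Learn ≥1.0. A one-line check:

    from sklearn.linear_model import LogisticRegression

    # "lbfgs" is the default solver in Scikit-Learn >= 0.22
    assert LogisticRegression().get_params()["solver"] == "lbfgs"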
@@ -2590,7 +2575,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 157,
+   "execution_count": 78,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2598,7 +2583,7 @@
     "\n",
     "X_test_transformed = preprocess_pipeline.transform(X_test)\n",
     "\n",
-    "log_clf = LogisticRegression(solver=\"lbfgs\", max_iter=1000, random_state=42)\n",
+    "log_clf = LogisticRegression(max_iter=1000, random_state=42)\n",
     "log_clf.fit(X_train_transformed, y_train)\n",
     "\n",
     "y_pred = log_clf.predict(X_test_transformed)\n",
@@ -2617,7 +2602,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },