ArtStudies/M1/Statistical Learning/TP5_Naive_Bayes.ipynb
Arthur DANJOU f94ff07cab Refactor code for improved readability and consistency across notebooks
- Standardized spacing around operators and function arguments in TP7_Kmeans.ipynb and neural_network.ipynb.
- Enhanced the formatting of model building and training code in neural_network.ipynb for better clarity.
- Updated the pyproject.toml to remove a specific TensorFlow version and added linting configuration for Ruff.
- Improved comments and organization in the code to facilitate easier understanding and maintenance.
2025-07-01 20:46:08 +02:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TP5 Naive Bayes : Spam or not-spam\n",
"\n",
"\n",
"### Table of Contents\n",
"\n",
"* [0. Data preparation](#chapter0)\n",
"* [1. Feature engineering : Text --> Vector](#chapter1)\n",
"* [2. Naive Bayes classifier](#chapter2)\n",
"* [3. Naive Bayes on MNIST and CIFAR10](#chapter3)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Data preparation <a class=\"anchor\" id=\"chapter0\"></a>\n",
"\n",
"We want to predict whether an SMS is \"spam\" or not. We will use the dataset `spam`.\n",
"\n",
"Reference : the dataset `spam` is taken from https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First step : import the dataset "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
"columns": [
{
"name": "index",
"rawType": "int64",
"type": "integer"
},
{
"name": "v1",
"rawType": "object",
"type": "string"
},
{
"name": "v2",
"rawType": "object",
"type": "string"
},
{
"name": "Unnamed: 2",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 3",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 4",
"rawType": "object",
"type": "unknown"
}
],
"conversionMethod": "pd.DataFrame",
"ref": "37d5b76b-9fed-490f-9dd5-20a98409d9ca",
"rows": [
[
"0",
"ham",
"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",
null,
null,
null
],
[
"1",
"ham",
"Ok lar... Joking wif u oni...",
null,
null,
null
],
[
"2",
"spam",
"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
null,
null,
null
],
[
"3",
"ham",
"U dun say so early hor... U c already then say...",
null,
null,
null
],
[
"4",
"ham",
"Nah I don't think he goes to usf, he lives around here though",
null,
null,
null
]
],
"shape": {
"columns": 5,
"rows": 5
}
},
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>v1</th>\n",
" <th>v2</th>\n",
" <th>Unnamed: 2</th>\n",
" <th>Unnamed: 3</th>\n",
" <th>Unnamed: 4</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ham</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" v1 v2 Unnamed: 2 \\\n",
"0 ham Go until jurong point, crazy.. Available only ... NaN \n",
"1 ham Ok lar... Joking wif u oni... NaN \n",
"2 spam Free entry in 2 a wkly comp to win FA Cup fina... NaN \n",
"3 ham U dun say so early hor... U c already then say... NaN \n",
"4 ham Nah I don't think he goes to usf, he lives aro... NaN \n",
"\n",
" Unnamed: 3 Unnamed: 4 \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sms = pd.read_csv(\"data/spam.csv\", encoding=\"latin\")\n",
"\n",
"sms.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In column `v1`, `ham` means \"non-spam\". First we rename the columns `v1` and `v2`: `v1` $\\rightarrow$ `Label` and `v2` $\\rightarrow$ `Text`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"sms.rename(columns={\"v1\": \"Label\", \"v2\": \"Text\"}, inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us check : "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
"columns": [
{
"name": "index",
"rawType": "int64",
"type": "integer"
},
{
"name": "Label",
"rawType": "object",
"type": "string"
},
{
"name": "Text",
"rawType": "object",
"type": "string"
},
{
"name": "Unnamed: 2",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 3",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 4",
"rawType": "object",
"type": "unknown"
}
],
"conversionMethod": "pd.DataFrame",
"ref": "7dc949eb-d6ef-4d5f-969f-c852d588c859",
"rows": [
[
"0",
"ham",
"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",
null,
null,
null
],
[
"1",
"ham",
"Ok lar... Joking wif u oni...",
null,
null,
null
],
[
"2",
"spam",
"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
null,
null,
null
],
[
"3",
"ham",
"U dun say so early hor... U c already then say...",
null,
null,
null
],
[
"4",
"ham",
"Nah I don't think he goes to usf, he lives around here though",
null,
null,
null
]
],
"shape": {
"columns": 5,
"rows": 5
}
},
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" <th>Unnamed: 2</th>\n",
" <th>Unnamed: 3</th>\n",
" <th>Unnamed: 4</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ham</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Label Text Unnamed: 2 \\\n",
"0 ham Go until jurong point, crazy.. Available only ... NaN \n",
"1 ham Ok lar... Joking wif u oni... NaN \n",
"2 spam Free entry in 2 a wkly comp to win FA Cup fina... NaN \n",
"3 ham U dun say so early hor... U c already then say... NaN \n",
"4 ham Nah I don't think he goes to usf, he lives aro... NaN \n",
"\n",
" Unnamed: 3 Unnamed: 4 \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sms.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we create a new column, named `Labelnum` that contains zeros and ones : `ham`$\\rightarrow$ ` 0 ` and `spam`$\\rightarrow$ ` 1 `."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
"columns": [
{
"name": "index",
"rawType": "int64",
"type": "integer"
},
{
"name": "Label",
"rawType": "object",
"type": "string"
},
{
"name": "Text",
"rawType": "object",
"type": "string"
},
{
"name": "Unnamed: 2",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 3",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 4",
"rawType": "object",
"type": "unknown"
},
{
"name": "Labelnum",
"rawType": "int64",
"type": "integer"
}
],
"conversionMethod": "pd.DataFrame",
"ref": "6f54a3b7-bd30-4a68-b39b-4220633acbfc",
"rows": [
[
"0",
"ham",
"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",
null,
null,
null,
"0"
],
[
"1",
"ham",
"Ok lar... Joking wif u oni...",
null,
null,
null,
"0"
],
[
"2",
"spam",
"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
null,
null,
null,
"1"
],
[
"3",
"ham",
"U dun say so early hor... U c already then say...",
null,
null,
null,
"0"
],
[
"4",
"ham",
"Nah I don't think he goes to usf, he lives around here though",
null,
null,
null,
"0"
]
],
"shape": {
"columns": 6,
"rows": 5
}
},
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" <th>Unnamed: 2</th>\n",
" <th>Unnamed: 3</th>\n",
" <th>Unnamed: 4</th>\n",
" <th>Labelnum</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ham</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Label Text Unnamed: 2 \\\n",
"0 ham Go until jurong point, crazy.. Available only ... NaN \n",
"1 ham Ok lar... Joking wif u oni... NaN \n",
"2 spam Free entry in 2 a wkly comp to win FA Cup fina... NaN \n",
"3 ham U dun say so early hor... U c already then say... NaN \n",
"4 ham Nah I don't think he goes to usf, he lives aro... NaN \n",
"\n",
" Unnamed: 3 Unnamed: 4 Labelnum \n",
"0 NaN NaN 0 \n",
"1 NaN NaN 0 \n",
"2 NaN NaN 1 \n",
"3 NaN NaN 0 \n",
"4 NaN NaN 0 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sms[\"Labelnum\"] = sms[\"Label\"].map({\"ham\": 0, \"spam\": 1})\n",
"\n",
"sms.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 1** : Display the sample size, the number of hams and the number of spams. You can use the next three cells, which contain hints (or not...)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5\n",
"[0 0]\n",
"2\n",
"[1 1 1]\n",
"3\n"
]
}
],
"source": [
"# Hint 1 for Exercise 1\n",
"a = np.array([0, 1, 1, 1, 0])\n",
"print(len(a))\n",
"print(a[a == 0])\n",
"print(len(a[a == 0]))\n",
"print(a[a == 1])\n",
"print(len(a[a == 1]))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
"columns": [
{
"name": "index",
"rawType": "int64",
"type": "integer"
},
{
"name": "Label",
"rawType": "object",
"type": "string"
},
{
"name": "Text",
"rawType": "object",
"type": "string"
},
{
"name": "Unnamed: 2",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 3",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 4",
"rawType": "object",
"type": "unknown"
},
{
"name": "Labelnum",
"rawType": "int64",
"type": "integer"
}
],
"conversionMethod": "pd.DataFrame",
"ref": "dc527433-ad4a-4862-bddf-ed1bfa01897a",
"rows": [
[
"0",
"ham",
"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",
null,
null,
null,
"0"
],
[
"1",
"ham",
"Ok lar... Joking wif u oni...",
null,
null,
null,
"0"
],
[
"3",
"ham",
"U dun say so early hor... U c already then say...",
null,
null,
null,
"0"
],
[
"4",
"ham",
"Nah I don't think he goes to usf, he lives around here though",
null,
null,
null,
"0"
],
[
"6",
"ham",
"Even my brother is not like to speak with me. They treat me like aids patent.",
null,
null,
null,
"0"
]
],
"shape": {
"columns": 6,
"rows": 5
}
},
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" <th>Unnamed: 2</th>\n",
" <th>Unnamed: 3</th>\n",
" <th>Unnamed: 4</th>\n",
" <th>Labelnum</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ham</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>ham</td>\n",
" <td>Even my brother is not like to speak with me. ...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Label Text Unnamed: 2 \\\n",
"0 ham Go until jurong point, crazy.. Available only ... NaN \n",
"1 ham Ok lar... Joking wif u oni... NaN \n",
"3 ham U dun say so early hor... U c already then say... NaN \n",
"4 ham Nah I don't think he goes to usf, he lives aro... NaN \n",
"6 ham Even my brother is not like to speak with me. ... NaN \n",
"\n",
" Unnamed: 3 Unnamed: 4 Labelnum \n",
"0 NaN NaN 0 \n",
"1 NaN NaN 0 \n",
"3 NaN NaN 0 \n",
"4 NaN NaN 0 \n",
"6 NaN NaN 0 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hint 2 for Exercise 1\n",
"sms[sms.Labelnum == 0].head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
"columns": [
{
"name": "index",
"rawType": "int64",
"type": "integer"
},
{
"name": "Label",
"rawType": "object",
"type": "string"
},
{
"name": "Text",
"rawType": "object",
"type": "string"
},
{
"name": "Unnamed: 2",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 3",
"rawType": "object",
"type": "unknown"
},
{
"name": "Unnamed: 4",
"rawType": "object",
"type": "unknown"
},
{
"name": "Labelnum",
"rawType": "int64",
"type": "integer"
}
],
"conversionMethod": "pd.DataFrame",
"ref": "01a11b3a-e783-4b35-8867-c7a57df85078",
"rows": [
[
"2",
"spam",
"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
null,
null,
null,
"1"
],
[
"5",
"spam",
"FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, å£1.50 to rcv",
null,
null,
null,
"1"
],
[
"8",
"spam",
"WINNER!! As a valued network customer you have been selected to receivea å£900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.",
null,
null,
null,
"1"
],
[
"9",
"spam",
"Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030",
null,
null,
null,
"1"
],
[
"11",
"spam",
"SIX chances to win CASH! From 100 to 20,000 pounds txt> CSH11 and send to 87575. Cost 150p/day, 6days, 16+ TsandCs apply Reply HL 4 info",
null,
null,
null,
"1"
]
],
"shape": {
"columns": 6,
"rows": 5
}
},
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" <th>Unnamed: 2</th>\n",
" <th>Unnamed: 3</th>\n",
" <th>Unnamed: 4</th>\n",
" <th>Labelnum</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>spam</td>\n",
" <td>FreeMsg Hey there darling it's been 3 week's n...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>spam</td>\n",
" <td>WINNER!! As a valued network customer you have...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>spam</td>\n",
" <td>Had your mobile 11 months or more? U R entitle...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>spam</td>\n",
" <td>SIX chances to win CASH! From 100 to 20,000 po...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Label Text Unnamed: 2 \\\n",
"2 spam Free entry in 2 a wkly comp to win FA Cup fina... NaN \n",
"5 spam FreeMsg Hey there darling it's been 3 week's n... NaN \n",
"8 spam WINNER!! As a valued network customer you have... NaN \n",
"9 spam Had your mobile 11 months or more? U R entitle... NaN \n",
"11 spam SIX chances to win CASH! From 100 to 20,000 po... NaN \n",
"\n",
" Unnamed: 3 Unnamed: 4 Labelnum \n",
"2 NaN NaN 1 \n",
"5 NaN NaN 1 \n",
"8 NaN NaN 1 \n",
"9 NaN NaN 1 \n",
"11 NaN NaN 1 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hint 3 for Exercise 1\n",
"sms[sms.Labelnum == 1].head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5572\n",
"(4825, 6)\n",
"(747, 6)\n"
]
}
],
"source": [
"print(len(sms))\n",
"print(sms[sms.Label == \"ham\"].shape)\n",
"print(sms[sms.Label == \"spam\"].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 2** : (**Optional**) To get an overall view of the data, we can plot a histogram of the length of each SMS.\n",
"\n",
"Hint 1 : How to access an SMS in the data? We can use\n",
"`pandas.DataFrame.loc`. See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.loc.html\n",
"\n",
"Hint 2: How to create a histogram ? https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.pyplot.hist.html"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...\n",
"--> The length of the first sms is 111\n"
]
}
],
"source": [
"# Hint 1 for Exercise 2\n",
"print(sms.loc[0, \"Text\"])\n",
"print(\"--> The length of the first sms is\", len(sms.loc[0, \"Text\"]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 6))\n",
"plt.hist(\n",
" sms.loc[:, \"Text\"].apply(len),\n",
" bins=\"stone\",\n",
")\n",
"plt.title(\"Histogram of SMS Lengths\")\n",
"plt.xlabel(\"Length\")\n",
"plt.ylabel(\"Frequency\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"-------------------------------------\n",
"\n",
"## 1. Feature engineering : Text --> Vector <a class=\"anchor\" id=\"chapter1\"></a>\n",
"\n",
"\n",
"In this part, we will transform the text contained in an sms into a numerical vector in $\\mathbb{R}^{p}$. For this we will use `CountVectorizer`.\n",
"\n",
"Reference : https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html\n",
"\n"
]
},
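{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the idea (the notation $w_j$, $x_j$ and $p$ below is introduced here for illustration): once a vocabulary $w_1, \\dots, w_p$ has been learned from the training texts, each SMS is represented by its vector of word counts,\n",
"\n",
"$$x = (x_1, \\dots, x_p) \\in \\mathbb{R}^{p}, \\qquad x_j = \\text{number of occurrences of } w_j \\text{ in the SMS}.$$\n",
"\n",
"For instance, with the hypothetical vocabulary (free, win, you), the text \"win win free\" would be mapped to $x = (1, 2, 0)$. Word order is discarded, which is why this representation is often called a bag of words."
]
},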
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---------------------------------\n",
"\n",
"Let us first give an example of use of `CountVectorizer`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1. The vocabulary of Example is {'iphone': 3, 'gratuit': 2, 'mille': 4, 'vert': 5, 'euro': 1, 'argent': 0}\n",
"The vocabulary arranged in alphabetical order : ['argent', 'euro', 'gratuit', 'iphone', 'mille', 'vert']\n",
"2. The vectors corresponding to the sms are : \n",
" [[0 0 2 2 0 0]\n",
" [0 0 1 0 1 1]\n",
" [0 1 0 1 1 0]\n",
" [1 1 2 0 0 0]]\n",
"3. The numerical vector corresponding to (x_0=iphone gratuit) is \n",
" [[0 0 1 1 0 0]]\n"
]
}
],
"source": [
"Example = pd.DataFrame(\n",
" [\n",
" [\"iphone gratuit iphone gratuit\", 1],\n",
" [\"mille vert gratuit\", 0],\n",
" [\"iphone mille euro\", 0],\n",
" [\"argent gratuit euro gratuit\", 1],\n",
" ],\n",
" columns=[\"sms\", \"label\"],\n",
")\n",
"vec = CountVectorizer()\n",
"X = vec.fit_transform(Example.sms)\n",
"\n",
"# 1. Displaying the vocabulary\n",
"\n",
"print(\"1. The vocabulary of Example is \", vec.vocabulary_)\n",
"\n",
"# 1 bis :\n",
"\n",
"print(\n",
" \"The vocabulary arranged in alphabetical order : \",\n",
" sorted(list(vec.vocabulary_.keys())),\n",
")\n",
"\n",
"# 2. Displaying the vectors :\n",
"\n",
"print(\n",
" \"2. The vectors corresponding to the sms are : \\n\", X.toarray()\n",
") # X.toarray because\n",
"# X is a \"sparse\" matrix.\n",
"\n",
"# 3. For a new data x_0=\"iphone gratuit\",\n",
"# you must also transform x_0 into a numerical vector before predicting.\n",
"\n",
"vec_x_0 = vec.transform([\"iphone gratuit\"]).toarray() #\n",
"print(\"3. The numerical vector corresponding to (x_0=iphone gratuit) is \\n\", vec_x_0)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Compressed Sparse Row sparse matrix of dtype 'int64'\n",
"\twith 2 stored elements and shape (1, 6)>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sparse version (without calling \"toarray\")\n",
"v = vec.transform([\"iphone iphone gratuit\"])\n",
"v"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0, 0, 1, 2, 0, 0]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.toarray()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Compressed Sparse Row sparse matrix of dtype 'int64'\n",
"\twith 2 stored elements and shape (1, 6)>\n",
" Coords\tValues\n",
" (0, 2)\t1\n",
" (0, 3)\t2\n"
]
}
],
"source": [
"# \"(0,2) 1\" means : the element in row 0 and column 2 is equal to 1.\n",
"# \"(0,3) 2\" means : the element in row 0 and column 3 is equal to 2.\n",
"print(v)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 3** :\n",
"\n",
"1. Transform $x_1=$ \"iphone vert gratuit\" into a numerical vector adapted to the vocabulary created with `Example`. \n",
"\n",
"2. Do the same with $x_2=$ \"iphone rouge gratuit\". What do you observe ? "
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0 0 1 1 0 1]]\n",
"[[0 0 1 1 0 0]]\n"
]
}
],
"source": [
"vec_x_1 = vec.transform([\"iphone vert gratuit\"]).toarray()\n",
"vec_x_2 = vec.transform([\"iphone rouge gratuit\"]).toarray()\n",
"print(vec_x_1)\n",
"print(vec_x_2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"-------------------------------------\n",
"\n",
"Let us now go back to our original dataset `sms`. We will now transform the texts in `sms.Text` into vectors and assign them to `X`. We will also assign `sms.Labelnum` to `y`.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 4** : Create a text-to-vector transformation model, named `vectorizer`. Fit `vectorizer` on the `sms.Text` data.\n",
"\n",
"Note: We have already imported the `CountVectorizer` class."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"vectorizer = CountVectorizer()\n",
"X = vectorizer.fit_transform(sms[\"Text\"])\n",
"y = sms[\"Labelnum\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we split the sample into a training set and a test set. "
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size of the training set: 3900\n",
"size of the test set : 1672\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, test_size=0.30, random_state=50\n",
")\n",
"\n",
"print(\"size of the training set: \", X_train.shape[0])\n",
"print(\"size of the test set :\", X_test.shape[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"------------------------------\n",
"\n",
"## 2. Naive Bayes classification <a class=\"anchor\" id=\"chapter2\"></a>\n",
"\n",
"Now we will train a Naive Bayes classification model. The class we will use is MultinomialNB from sklearn.naive_bayes.\n",
"\n",
"Reference : https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.MultinomialNB.html"
]
},
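{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder, here is a sketch of the decision rule of the multinomial Naive Bayes classifier, written in the notation of the scikit-learn user guide (the symbols $\\theta_{yi}$ and $\\hat{y}$ are not defined elsewhere in this notebook). For a count vector $x = (x_1, \\dots, x_p)$, the features are assumed to be independent given the class $y$, and the predicted class is\n",
"\n",
"$$\\hat{y} = \\arg\\max_{y} \\Big( \\log P(y) + \\sum_{i=1}^{p} x_i \\log \\theta_{yi} \\Big),$$\n",
"\n",
"where $P(y)$ is the proportion of class $y$ in the training set and $\\theta_{yi}$ is the estimated probability of word $i$ appearing in an SMS of class $y$."
]
},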
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.naive_bayes import MultinomialNB"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 6** : \n",
"\n",
"1. Create a Naive Bayes classification model with the smoothing parameter $\\alpha$=1.0 (alpha=1.0), named `sms_bayes`.\n",
"\n",
"What is the role of the smoothing parameter $\\alpha$ ? Refer to the course or this page: https://scikit-learn.org/stable/modules/naive_bayes.html#multinomial-naive-bayes.\n",
" \n",
" \n",
"\n",
"2. Fit `sms_bayes` on (`X_train, y_train`)."
]
},
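{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked sketch of the smoothing, following the formula given in the scikit-learn user guide (the symbols $N_{yi}$, $N_y$ and $p$ are notation introduced here). The probability of word $i$ in class $y$ is estimated by\n",
"\n",
"$$\\hat{\\theta}_{yi} = \\frac{N_{yi} + \\alpha}{N_y + \\alpha p},$$\n",
"\n",
"where $N_{yi}$ is the number of occurrences of word $i$ in the training SMS of class $y$, $N_y = \\sum_{i} N_{yi}$, and $p$ is the size of the vocabulary.\n",
"\n",
"In the toy dataset `Example` above, the two spam messages contain 8 words in total and the vocabulary has $p = 6$ words. The word \"mille\" never appears in a spam, so without smoothing ($\\alpha = 0$) its estimate would be $0/8 = 0$, and any SMS containing \"mille\" would get a spam likelihood of exactly zero, whatever its other words. With $\\alpha = 1$ the estimate becomes $(0 + 1)/(8 + 6) = 1/14$, while \"gratuit\" (4 occurrences in spams) gets $(4 + 1)/(8 + 6) = 5/14$."
]
},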
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: #000;\n",
" --sklearn-color-text-muted: #666;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: flex;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
" align-items: start;\n",
" justify-content: space-between;\n",
" gap: 0.5em;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label .caption {\n",
" font-size: 0.6rem;\n",
" font-weight: lighter;\n",
" color: var(--sklearn-color-text-muted);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 0.5em;\n",
" text-align: center;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>MultinomialNB()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>MultinomialNB</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.naive_bayes.MultinomialNB.html\">?<span>Documentation for MultinomialNB</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>MultinomialNB()</pre></div> </div></div></div></div>"
],
"text/plain": [
"MultinomialNB()"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sms_bayes = MultinomialNB(alpha=1.0)\n",
"sms_bayes.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us look at the performance of `sms_bayes` on the test set :"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy score on the test set is 0.9754784688995215\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"y_pred = sms_bayes.predict(X_test)\n",
"print(\"The accuracy score on the test set is \", accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---------------------------\n",
"\n",
"\n",
"In **Exercise 1**, we observed that this SMS dataset contains significantly more non-spam messages (4825) than spam messages (747). With such a class imbalance, accuracy alone can be misleading, so it is better to also check the confusion matrix.\n",
"\n",
"Reference: Confusion Matrix\n",
" https://fr.wikipedia.org/wiki/Matrice_de_confusion"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1410, 16],\n",
" [ 25, 221]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"confusion_matrix(y_test, y_pred)"
]
},
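{
"cell_type": "markdown",
"metadata": {},
"source": [
"The confusion matrix above can be turned into per-class metrics. The following is a minimal sketch in plain Python that hard-codes the matrix printed above. It assumes the scikit-learn layout (rows are true labels, columns are predicted labels) and that label 1 denotes spam; the names `tn`, `fp`, `fn` and `tp` are introduced here only for illustration.\n",
"\n",
"```python\n",
"# Confusion matrix copied from the output above.\n",
"# Assumed layout: rows = true labels, columns = predicted labels.\n",
"cm = [[1410, 16], [25, 221]]\n",
"\n",
"# Assumption: label 0 is non-spam and label 1 is spam.\n",
"(tn, fp), (fn, tp) = cm\n",
"\n",
"accuracy = (tp + tn) / (tn + fp + fn + tp)\n",
"precision = tp / (tp + fp)  # among predicted spam, fraction that is really spam\n",
"recall = tp / (tp + fn)  # among real spam, fraction that is detected\n",
"\n",
"print(f\"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}\")\n",
"```\n",
"\n",
"Under these assumptions, the recall on the spam class is lower than the overall accuracy, which is the kind of detail that accuracy alone hides on an imbalanced dataset."
]
},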
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**(Optional)** \n",
"\n",
"Test whether your SMS will be classified as spam or not. Replace `something new` in the next cell with your SMS.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1 1 0 0]\n"
]
}
],
"source": [
"my_sms = vectorizer.transform(\n",
" [\n",
" \"free trial!\",\n",
" \"Iphone 15 is now free\",\n",
" \"I want coffee\",\n",
" \"I want to buy a new iphone\",\n",
" ]\n",
")\n",
"\n",
"pred_my_sms = sms_bayes.predict(my_sms)\n",
"print(pred_my_sms)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Naive Bayes on MNIST and CIFAR10 <a class=\"anchor\" id=\"chapter3\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will implement Exercise 6.2 from the course handout, using a Bernoulli Naive Bayes model instead of a Multinomial Naive Bayes model. For that, we will use `BernoulliNB` from `sklearn.naive_bayes`. First, let us deal with MNIST. "
]
},
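{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the difference with the multinomial model concrete: a Bernoulli Naive Bayes model treats each feature as a binary variable, so absent features (value 0) also contribute to the likelihood. The sketch below is a toy, pure-Python illustration of that idea with Laplace smoothing ($\\alpha = 1$). The helpers `fit_bernoulli_nb` and `predict_bernoulli_nb` and the toy data are hypothetical, and this is not meant to reproduce the internals of `BernoulliNB`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def fit_bernoulli_nb(X, y, alpha=1.0):\n",
"    # Estimate, for each class, its prior and the smoothed P(x_j = 1 | class).\n",
"    n_features = len(X[0])\n",
"    model = {}\n",
"    for c in sorted(set(y)):\n",
"        rows = [x for x, label in zip(X, y) if label == c]\n",
"        prior = len(rows) / len(X)\n",
"        theta = [\n",
"            (sum(r[j] for r in rows) + alpha) / (len(rows) + 2 * alpha)\n",
"            for j in range(n_features)\n",
"        ]\n",
"        model[c] = (prior, theta)\n",
"    return model\n",
"\n",
"\n",
"def predict_bernoulli_nb(model, x):\n",
"    # Return the class with the highest log-posterior for a binary vector x.\n",
"    def log_posterior(c):\n",
"        prior, theta = model[c]\n",
"        # Present (x_j = 1) and absent (x_j = 0) features both contribute.\n",
"        return math.log(prior) + sum(\n",
"            math.log(t) if xj == 1 else math.log(1 - t) for xj, t in zip(x, theta)\n",
"        )\n",
"\n",
"    return max(model, key=log_posterior)\n",
"\n",
"\n",
"# Toy binary data: class 0 tends to activate feature 0, class 1 feature 2.\n",
"X_toy = [[1, 0, 0], [1, 1, 0], [0, 0, 1], [0, 1, 1]]\n",
"y_toy = [0, 0, 1, 1]\n",
"toy_model = fit_bernoulli_nb(X_toy, y_toy)\n",
"\n",
"print(predict_bernoulli_nb(toy_model, [1, 0, 0]))\n",
"print(predict_bernoulli_nb(toy_model, [0, 0, 1]))\n",
"```\n",
"\n",
"Working in log space avoids numerical underflow when the number of features is large, as with the 784 pixels of an MNIST image."
]
},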
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_openml\n",
"from sklearn.naive_bayes import BernoulliNB\n",
"\n",
"# Load the MNIST dataset\n",
"mnist = fetch_openml(\"mnist_784\", version=1, parser=\"auto\")\n",
"X, y = mnist.data, mnist.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 7**\n",
"1. Convert pixel values to $\\{0, 1\\}$. The pixel values range from 0 to 255. All values above 127 are converted to 1, the others to 0. \n",
"\n",
"2. Split into training and test sets (25% for the test set and `random_state=42`). \n",
"\n",
"3. Initialize and train a Bernoulli Naive Bayes classifier. \n",
"\n",
"4. Make the predictions on the test set and compute the accuracy score. "
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8375428571428571"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_copy = (X.copy() >= 127).astype(int)\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X_copy, y, test_size=0.25, random_state=42\n",
")\n",
"\n",
"ber_bayes = BernoulliNB()\n",
"ber_bayes.fit(X_train, y_train)\n",
"\n",
"y_pred = ber_bayes.predict(X_test)\n",
"accuracy_score(y_test, y_pred)"
]
},
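{
"cell_type": "markdown",
"metadata": {},
"source": [
"The thresholding rule matters only at the boundary: `> 127` and `>= 127` differ for pixels exactly equal to 127. A small sketch on a hypothetical list of pixel values (the names `pixels`, `strict` and `non_strict` are illustrative):\n",
"\n",
"```python\n",
"# Hypothetical pixel intensities around the threshold.\n",
"pixels = [0, 126, 127, 128, 255]\n",
"\n",
"# Strict rule: only values above 127 become 1.\n",
"strict = [int(p > 127) for p in pixels]\n",
"\n",
"# Non-strict rule: 127 itself also becomes 1.\n",
"non_strict = [int(p >= 127) for p in pixels]\n",
"\n",
"print(strict)  # [0, 0, 0, 1, 1]\n",
"print(non_strict)  # [0, 0, 1, 1, 1]\n",
"```\n",
"\n",
"`BernoulliNB` also exposes a `binarize` parameter that applies a threshold to the input features internally, which can replace the manual conversion."
]
},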
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For **cifar10**, we will do the same. First let us import the dataset. "
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from keras.datasets import cifar10\n",
"\n",
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(50000, 32, 32, 3)\n",
"(50000, 1)\n"
]
}
],
"source": [
"# Reminder: each image is a 32 x 32 RGB image\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to convert the RGB values into grayscale. For that, we use `cvtColor` from the `cv2` package. You may need to install OpenCV (for example, the `opencv-python` package with pip). \n",
"\n",
"Remark: `cvtColor` takes a single image, whereas `cifar10.load_data` returns arrays of images, so we convert the images one by one. "
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Convert images to grayscale\n",
"import cv2\n",
"\n",
"x_train_gray = np.array([cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in x_train])\n",
"x_test_gray = np.array([cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in x_test])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 8** Implement a Bernoulli Naive Bayes classifier on the grayscale images as we did for MNIST and compute the accuracy score on the test images. "
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Naive Bayes Classification Accuracy: 0.2359\n"
]
}
],
"source": [
"x_train_binarized = (x_train_gray > 127).astype(int)\n",
"x_test_binarized = (x_test_gray > 127).astype(int)\n",
"\n",
"# Flatten the images into vectors of size 1024 (32x32)\n",
"x_train_flattened = x_train_binarized.reshape(x_train_binarized.shape[0], -1)\n",
"x_test_flattened = x_test_binarized.reshape(x_test_binarized.shape[0], -1)\n",
"\n",
"# Initialize and train a Naive Bayes classifier\n",
"nb_classifier = BernoulliNB()\n",
"nb_classifier.fit(x_train_flattened, y_train.ravel())\n",
"\n",
"# Make predictions on the test set\n",
"y_pred = nb_classifier.predict(x_test_flattened)\n",
"\n",
"# Compute accuracy\n",
"accuracy = accuracy_score(y_test.ravel(), y_pred)\n",
"print(f\"Naive Bayes Classification Accuracy: {accuracy:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(**Optional**) If you want to know more about the RGB-to-grayscale conversion, the formula is as follows: \n",
"\n",
"$$Y = 0.299R + 0.587G + 0.114B$$\n",
"\n",
"where:\n",
"\n",
"* $Y$ is the resulting grayscale value;\n",
"* $R$, $G$, and $B$ are the red, green, and blue color channel values, respectively.\n",
"\n",
"This formula, also known as the weighted method, takes into account the human eye's different sensitivities to red, green, and blue light. Green is given the highest weight (0.587) because the human eye is most sensitive to green light, followed by red (0.299), and then blue (0.114)."
]
}
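,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The weighted sum above can be sketched in plain Python. The helper `rgb_to_gray` is hypothetical and only illustrates the formula; `cv2.cvtColor` may round or use a different internal precision, so its values can differ slightly.\n",
"\n",
"```python\n",
"def rgb_to_gray(r, g, b):\n",
"    # Weighted grayscale value Y = 0.299 R + 0.587 G + 0.114 B.\n",
"    return 0.299 * r + 0.587 * g + 0.114 * b\n",
"\n",
"\n",
"white = rgb_to_gray(255, 255, 255)  # weights sum to 1, so white stays near 255\n",
"red = rgb_to_gray(255, 0, 0)\n",
"green = rgb_to_gray(0, 255, 0)\n",
"blue = rgb_to_gray(0, 0, 255)\n",
"\n",
"print(white, red, green, blue)\n",
"```\n",
"\n",
"With equal channel intensities, pure green maps to the brightest gray and pure blue to the darkest, in line with the weights."
]
}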
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}