mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-01-14 15:54:13 +01:00
- Adjusted indentation and line breaks for better clarity in function definitions and import statements. - Standardized string quotes for consistency across the codebase. - Enhanced readability of DataFrame creation and manipulation by breaking long lines into multiple lines. - Cleaned up print statements and comments for improved understanding. - Ensured consistent use of whitespace around operators and after commas.
2645 lines
446 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"source": [
|
||
"# TP4 Ridge, Lasso, CV\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"### Table of Contents\n",
|
||
"\n",
|
||
"* [0. Data Preparation ](#chapter0)\n",
|
||
"* [1. Ridge and Lasso Regression ](#chapter1)\n",
|
||
"* [2. Cross validation for the hyperparameters $\\alpha$ of Ridge and Lasso](#chapter2)\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## 0. Data Preparation <a class=\"anchor\" id=\"chapter0\"></a>\n",
"\n",
"We will use the `Hitters` dataset to predict the salary of baseball players.\n",
"\n",
"Reference: James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. *An Introduction to Statistical Learning*. Vol. 112. New York: Springer, 2013."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:52.177602Z",
|
||
"start_time": "2025-03-26T10:33:52.174201Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import warnings\n",
|
||
"\n",
|
||
"warnings.filterwarnings(\"ignore\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:34:01.222808Z",
|
||
"start_time": "2025-03-26T10:34:01.200754Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>AtBat</th>\n",
|
||
" <th>Hits</th>\n",
|
||
" <th>HmRun</th>\n",
|
||
" <th>Runs</th>\n",
|
||
" <th>RBI</th>\n",
|
||
" <th>Walks</th>\n",
|
||
" <th>Years</th>\n",
|
||
" <th>CAtBat</th>\n",
|
||
" <th>CHits</th>\n",
|
||
" <th>CHmRun</th>\n",
|
||
" <th>CRuns</th>\n",
|
||
" <th>CRBI</th>\n",
|
||
" <th>CWalks</th>\n",
|
||
" <th>League</th>\n",
|
||
" <th>Division</th>\n",
|
||
" <th>PutOuts</th>\n",
|
||
" <th>Assists</th>\n",
|
||
" <th>Errors</th>\n",
|
||
" <th>Salary</th>\n",
|
||
" <th>NewLeague</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Name</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>-Andy Allanson</th>\n",
|
||
" <td>293</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>293</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>446</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>20</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Alan Ashby</th>\n",
|
||
" <td>315</td>\n",
|
||
" <td>81</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>24</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>3449</td>\n",
|
||
" <td>835</td>\n",
|
||
" <td>69</td>\n",
|
||
" <td>321</td>\n",
|
||
" <td>414</td>\n",
|
||
" <td>375</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>632</td>\n",
|
||
" <td>43</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>475.0</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Alvin Davis</th>\n",
|
||
" <td>479</td>\n",
|
||
" <td>130</td>\n",
|
||
" <td>18</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>72</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1624</td>\n",
|
||
" <td>457</td>\n",
|
||
" <td>63</td>\n",
|
||
" <td>224</td>\n",
|
||
" <td>266</td>\n",
|
||
" <td>263</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>880</td>\n",
|
||
" <td>82</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>480.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Andre Dawson</th>\n",
|
||
" <td>496</td>\n",
|
||
" <td>141</td>\n",
|
||
" <td>20</td>\n",
|
||
" <td>65</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>5628</td>\n",
|
||
" <td>1575</td>\n",
|
||
" <td>225</td>\n",
|
||
" <td>828</td>\n",
|
||
" <td>838</td>\n",
|
||
" <td>354</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>200</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>500.0</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Andres Galarraga</th>\n",
|
||
" <td>321</td>\n",
|
||
" <td>87</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>42</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>396</td>\n",
|
||
" <td>101</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>48</td>\n",
|
||
" <td>46</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>805</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>91.5</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Willie McGee</th>\n",
|
||
" <td>497</td>\n",
|
||
" <td>127</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>65</td>\n",
|
||
" <td>48</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>2703</td>\n",
|
||
" <td>806</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>379</td>\n",
|
||
" <td>311</td>\n",
|
||
" <td>138</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>325</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>700.0</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Willie Randolph</th>\n",
|
||
" <td>492</td>\n",
|
||
" <td>136</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>94</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>5511</td>\n",
|
||
" <td>1511</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>897</td>\n",
|
||
" <td>451</td>\n",
|
||
" <td>875</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>313</td>\n",
|
||
" <td>381</td>\n",
|
||
" <td>20</td>\n",
|
||
" <td>875.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Wayne Tolleson</th>\n",
|
||
" <td>475</td>\n",
|
||
" <td>126</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>61</td>\n",
|
||
" <td>43</td>\n",
|
||
" <td>52</td>\n",
|
||
" <td>6</td>\n",
|
||
" <td>1700</td>\n",
|
||
" <td>433</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>217</td>\n",
|
||
" <td>93</td>\n",
|
||
" <td>146</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>113</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>385.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Willie Upshaw</th>\n",
|
||
" <td>573</td>\n",
|
||
" <td>144</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>85</td>\n",
|
||
" <td>60</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>8</td>\n",
|
||
" <td>3198</td>\n",
|
||
" <td>857</td>\n",
|
||
" <td>97</td>\n",
|
||
" <td>470</td>\n",
|
||
" <td>420</td>\n",
|
||
" <td>332</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>1314</td>\n",
|
||
" <td>131</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>960.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>-Willie Wilson</th>\n",
|
||
" <td>631</td>\n",
|
||
" <td>170</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>77</td>\n",
|
||
" <td>44</td>\n",
|
||
" <td>31</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>4908</td>\n",
|
||
" <td>1457</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>775</td>\n",
|
||
" <td>357</td>\n",
|
||
" <td>249</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>408</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1000.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>322 rows × 20 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits \\\n",
|
||
"Name \n",
|
||
"-Andy Allanson 293 66 1 30 29 14 1 293 66 \n",
|
||
"-Alan Ashby 315 81 7 24 38 39 14 3449 835 \n",
|
||
"-Alvin Davis 479 130 18 66 72 76 3 1624 457 \n",
|
||
"-Andre Dawson 496 141 20 65 78 37 11 5628 1575 \n",
|
||
"-Andres Galarraga 321 87 10 39 42 30 2 396 101 \n",
|
||
"... ... ... ... ... ... ... ... ... ... \n",
|
||
"-Willie McGee 497 127 7 65 48 37 5 2703 806 \n",
|
||
"-Willie Randolph 492 136 5 76 50 94 12 5511 1511 \n",
|
||
"-Wayne Tolleson 475 126 3 61 43 52 6 1700 433 \n",
|
||
"-Willie Upshaw 573 144 9 85 60 78 8 3198 857 \n",
|
||
"-Willie Wilson 631 170 9 77 44 31 11 4908 1457 \n",
|
||
"\n",
|
||
" CHmRun CRuns CRBI CWalks League Division PutOuts \\\n",
|
||
"Name \n",
|
||
"-Andy Allanson 1 30 29 14 A E 446 \n",
|
||
"-Alan Ashby 69 321 414 375 N W 632 \n",
|
||
"-Alvin Davis 63 224 266 263 A W 880 \n",
|
||
"-Andre Dawson 225 828 838 354 N E 200 \n",
|
||
"-Andres Galarraga 12 48 46 33 N E 805 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"-Willie McGee 32 379 311 138 N E 325 \n",
|
||
"-Willie Randolph 39 897 451 875 A E 313 \n",
|
||
"-Wayne Tolleson 7 217 93 146 A W 37 \n",
|
||
"-Willie Upshaw 97 470 420 332 A E 1314 \n",
|
||
"-Willie Wilson 30 775 357 249 A W 408 \n",
|
||
"\n",
|
||
" Assists Errors Salary NewLeague \n",
|
||
"Name \n",
|
||
"-Andy Allanson 33 20 NaN A \n",
|
||
"-Alan Ashby 43 10 475.0 N \n",
|
||
"-Alvin Davis 82 14 480.0 A \n",
|
||
"-Andre Dawson 11 3 500.0 N \n",
|
||
"-Andres Galarraga 40 4 91.5 N \n",
|
||
"... ... ... ... ... \n",
|
||
"-Willie McGee 9 3 700.0 N \n",
|
||
"-Willie Randolph 381 20 875.0 A \n",
|
||
"-Wayne Tolleson 113 7 385.0 A \n",
|
||
"-Willie Upshaw 131 12 960.0 A \n",
|
||
"-Willie Wilson 4 3 1000.0 A \n",
|
||
"\n",
|
||
"[322 rows x 20 columns]"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd # dataframes are in pandas\n",
|
||
"\n",
|
||
"hitters = pd.read_csv(\"data/Hitters.csv\", index_col=\"Name\")\n",
|
||
"\n",
|
||
"hitters"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"**Exercise 1** : \n",
"\n",
"In `pd.read_csv(\"data/Hitters.csv\", index_col=\"Name\")`, what does `index_col=\"Name\"` do? Try loading the file without `index_col=\"Name\"`.\n",
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:52.585039Z",
|
||
"start_time": "2025-03-26T10:33:52.568658Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Name</th>\n",
|
||
" <th>AtBat</th>\n",
|
||
" <th>Hits</th>\n",
|
||
" <th>HmRun</th>\n",
|
||
" <th>Runs</th>\n",
|
||
" <th>RBI</th>\n",
|
||
" <th>Walks</th>\n",
|
||
" <th>Years</th>\n",
|
||
" <th>CAtBat</th>\n",
|
||
" <th>CHits</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>CRuns</th>\n",
|
||
" <th>CRBI</th>\n",
|
||
" <th>CWalks</th>\n",
|
||
" <th>League</th>\n",
|
||
" <th>Division</th>\n",
|
||
" <th>PutOuts</th>\n",
|
||
" <th>Assists</th>\n",
|
||
" <th>Errors</th>\n",
|
||
" <th>Salary</th>\n",
|
||
" <th>NewLeague</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>-Andy Allanson</td>\n",
|
||
" <td>293</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>293</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>446</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>20</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>-Alan Ashby</td>\n",
|
||
" <td>315</td>\n",
|
||
" <td>81</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>24</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>3449</td>\n",
|
||
" <td>835</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>321</td>\n",
|
||
" <td>414</td>\n",
|
||
" <td>375</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>632</td>\n",
|
||
" <td>43</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>475.0</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>-Alvin Davis</td>\n",
|
||
" <td>479</td>\n",
|
||
" <td>130</td>\n",
|
||
" <td>18</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>72</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1624</td>\n",
|
||
" <td>457</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>224</td>\n",
|
||
" <td>266</td>\n",
|
||
" <td>263</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>880</td>\n",
|
||
" <td>82</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>480.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>-Andre Dawson</td>\n",
|
||
" <td>496</td>\n",
|
||
" <td>141</td>\n",
|
||
" <td>20</td>\n",
|
||
" <td>65</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>5628</td>\n",
|
||
" <td>1575</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>828</td>\n",
|
||
" <td>838</td>\n",
|
||
" <td>354</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>200</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>500.0</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>-Andres Galarraga</td>\n",
|
||
" <td>321</td>\n",
|
||
" <td>87</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>42</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>396</td>\n",
|
||
" <td>101</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>48</td>\n",
|
||
" <td>46</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>805</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>91.5</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>317</th>\n",
|
||
" <td>-Willie McGee</td>\n",
|
||
" <td>497</td>\n",
|
||
" <td>127</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>65</td>\n",
|
||
" <td>48</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>2703</td>\n",
|
||
" <td>806</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>379</td>\n",
|
||
" <td>311</td>\n",
|
||
" <td>138</td>\n",
|
||
" <td>N</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>325</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>700.0</td>\n",
|
||
" <td>N</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>318</th>\n",
|
||
" <td>-Willie Randolph</td>\n",
|
||
" <td>492</td>\n",
|
||
" <td>136</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>94</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>5511</td>\n",
|
||
" <td>1511</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>897</td>\n",
|
||
" <td>451</td>\n",
|
||
" <td>875</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>313</td>\n",
|
||
" <td>381</td>\n",
|
||
" <td>20</td>\n",
|
||
" <td>875.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>319</th>\n",
|
||
" <td>-Wayne Tolleson</td>\n",
|
||
" <td>475</td>\n",
|
||
" <td>126</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>61</td>\n",
|
||
" <td>43</td>\n",
|
||
" <td>52</td>\n",
|
||
" <td>6</td>\n",
|
||
" <td>1700</td>\n",
|
||
" <td>433</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>217</td>\n",
|
||
" <td>93</td>\n",
|
||
" <td>146</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>113</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>385.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>320</th>\n",
|
||
" <td>-Willie Upshaw</td>\n",
|
||
" <td>573</td>\n",
|
||
" <td>144</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>85</td>\n",
|
||
" <td>60</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>8</td>\n",
|
||
" <td>3198</td>\n",
|
||
" <td>857</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>470</td>\n",
|
||
" <td>420</td>\n",
|
||
" <td>332</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>E</td>\n",
|
||
" <td>1314</td>\n",
|
||
" <td>131</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>960.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>321</th>\n",
|
||
" <td>-Willie Wilson</td>\n",
|
||
" <td>631</td>\n",
|
||
" <td>170</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>77</td>\n",
|
||
" <td>44</td>\n",
|
||
" <td>31</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>4908</td>\n",
|
||
" <td>1457</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>775</td>\n",
|
||
" <td>357</td>\n",
|
||
" <td>249</td>\n",
|
||
" <td>A</td>\n",
|
||
" <td>W</td>\n",
|
||
" <td>408</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1000.0</td>\n",
|
||
" <td>A</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>322 rows × 21 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Name AtBat Hits HmRun Runs RBI Walks Years CAtBat \\\n",
|
||
"0 -Andy Allanson 293 66 1 30 29 14 1 293 \n",
|
||
"1 -Alan Ashby 315 81 7 24 38 39 14 3449 \n",
|
||
"2 -Alvin Davis 479 130 18 66 72 76 3 1624 \n",
|
||
"3 -Andre Dawson 496 141 20 65 78 37 11 5628 \n",
|
||
"4 -Andres Galarraga 321 87 10 39 42 30 2 396 \n",
|
||
".. ... ... ... ... ... ... ... ... ... \n",
|
||
"317 -Willie McGee 497 127 7 65 48 37 5 2703 \n",
|
||
"318 -Willie Randolph 492 136 5 76 50 94 12 5511 \n",
|
||
"319 -Wayne Tolleson 475 126 3 61 43 52 6 1700 \n",
|
||
"320 -Willie Upshaw 573 144 9 85 60 78 8 3198 \n",
|
||
"321 -Willie Wilson 631 170 9 77 44 31 11 4908 \n",
|
||
"\n",
|
||
" CHits ... CRuns CRBI CWalks League Division PutOuts Assists \\\n",
|
||
"0 66 ... 30 29 14 A E 446 33 \n",
|
||
"1 835 ... 321 414 375 N W 632 43 \n",
|
||
"2 457 ... 224 266 263 A W 880 82 \n",
|
||
"3 1575 ... 828 838 354 N E 200 11 \n",
|
||
"4 101 ... 48 46 33 N E 805 40 \n",
|
||
".. ... ... ... ... ... ... ... ... ... \n",
|
||
"317 806 ... 379 311 138 N E 325 9 \n",
|
||
"318 1511 ... 897 451 875 A E 313 381 \n",
|
||
"319 433 ... 217 93 146 A W 37 113 \n",
|
||
"320 857 ... 470 420 332 A E 1314 131 \n",
|
||
"321 1457 ... 775 357 249 A W 408 4 \n",
|
||
"\n",
|
||
" Errors Salary NewLeague \n",
|
||
"0 20 NaN A \n",
|
||
"1 10 475.0 N \n",
|
||
"2 14 480.0 A \n",
|
||
"3 3 500.0 N \n",
|
||
"4 4 91.5 N \n",
|
||
".. ... ... ... \n",
|
||
"317 3 700.0 N \n",
|
||
"318 20 875.0 A \n",
|
||
"319 7 385.0 A \n",
|
||
"320 12 960.0 A \n",
|
||
"321 3 1000.0 A \n",
|
||
"\n",
|
||
"[322 rows x 21 columns]"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"hitters_bis = pd.read_csv(\"data/Hitters.csv\")\n",
|
||
"\n",
|
||
"hitters_bis"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
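"Answer for ex. 1 : \n",
"\n",
"\n",
"`index_col=\"Name\"` uses the `Name` column as the DataFrame index (the row labels) instead of the default integer index `0, 1, ..., 321`. Without it, `Name` stays an ordinary column and the DataFrame has 21 columns instead of 20.\n"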
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"**Exercise 2** :\n",
"\n",
"(1) What is the sample size of `Hitters`? How many features does it have?\n",
"\n",
"(2) What are the features in `Hitters`? \n",
"\n",
"(3) Are all the features in $\\mathbb{R}$? \n",
"\n",
"(4) Are there many missing values? `print` the number of missing values for each feature. \n",
"\n",
"- Hint : \n",
"    - (2) and (3) Use `pandas.DataFrame.dtypes`. https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.dtypes.html\n",
"    - (4) Use `pandas.DataFrame.isnull`.\n",
"    https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.isnull.html\n",
"    See the example below. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:52.804278Z",
|
||
"start_time": "2025-03-26T10:33:52.784889Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"data : \n",
|
||
" nom age exam\n",
|
||
"0 Alice 19.0 15.0\n",
|
||
"1 Nicolas NaN 14.0\n",
|
||
"2 Jean NaN NaN\n",
|
||
"First result : \n",
|
||
" nom age exam\n",
|
||
"0 False False False\n",
|
||
"1 False True False\n",
|
||
"2 False True True\n",
|
||
"Second result : \n",
|
||
" nom 0\n",
|
||
"age 2\n",
|
||
"exam 1\n",
|
||
"dtype: int64\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"# Hint for Question (4) :\n",
"ex = pd.DataFrame(\n",
"    {\n",
"        \"nom\": [\"Alice\", \"Nicolas\", \"Jean\"],\n",
"        \"age\": [19, np.nan, np.nan],  # np.NaN was removed in NumPy 2.0; use np.nan\n",
"        \"exam\": [15, 14, np.nan],\n",
"    }\n",
")\n",
"\n",
"print(\"data : \\n\", ex)\n",
"print(\"First result : \\n\", ex.isnull())\n",
"print(\"Second result : \\n\", ex.isnull().sum())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.092776Z",
|
||
"start_time": "2025-03-26T10:33:53.053583Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"(322, 20)\n",
|
||
"Index(['AtBat', 'Hits', 'HmRun', 'Runs', 'RBI', 'Walks', 'Years', 'CAtBat',\n",
|
||
" 'CHits', 'CHmRun', 'CRuns', 'CRBI', 'CWalks', 'League', 'Division',\n",
|
||
" 'PutOuts', 'Assists', 'Errors', 'Salary', 'NewLeague'],\n",
|
||
" dtype='object')\n",
|
||
" AtBat Hits HmRun Runs RBI Walks \\\n",
|
||
"count 322.000000 322.000000 322.000000 322.000000 322.000000 322.000000 \n",
|
||
"mean 380.928571 101.024845 10.770186 50.909938 48.027950 38.742236 \n",
|
||
"std 153.404981 46.454741 8.709037 26.024095 26.166895 21.639327 \n",
|
||
"min 16.000000 1.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
||
"25% 255.250000 64.000000 4.000000 30.250000 28.000000 22.000000 \n",
|
||
"50% 379.500000 96.000000 8.000000 48.000000 44.000000 35.000000 \n",
|
||
"75% 512.000000 137.000000 16.000000 69.000000 64.750000 53.000000 \n",
|
||
"max 687.000000 238.000000 40.000000 130.000000 121.000000 105.000000 \n",
|
||
"\n",
|
||
" Years CAtBat CHits CHmRun CRuns \\\n",
|
||
"count 322.000000 322.00000 322.000000 322.000000 322.000000 \n",
|
||
"mean 7.444099 2648.68323 717.571429 69.490683 358.795031 \n",
|
||
"std 4.926087 2324.20587 654.472627 86.266061 334.105886 \n",
|
||
"min 1.000000 19.00000 4.000000 0.000000 1.000000 \n",
|
||
"25% 4.000000 816.75000 209.000000 14.000000 100.250000 \n",
|
||
"50% 6.000000 1928.00000 508.000000 37.500000 247.000000 \n",
|
||
"75% 11.000000 3924.25000 1059.250000 90.000000 526.250000 \n",
|
||
"max 24.000000 14053.00000 4256.000000 548.000000 2165.000000 \n",
|
||
"\n",
|
||
" CRBI CWalks PutOuts Assists Errors \\\n",
|
||
"count 322.000000 322.000000 322.000000 322.000000 322.000000 \n",
|
||
"mean 330.118012 260.239130 288.937888 106.913043 8.040373 \n",
|
||
"std 333.219617 267.058085 280.704614 136.854876 6.368359 \n",
|
||
"min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
|
||
"25% 88.750000 67.250000 109.250000 7.000000 3.000000 \n",
|
||
"50% 220.500000 170.500000 212.000000 39.500000 6.000000 \n",
|
||
"75% 426.250000 339.250000 325.000000 166.000000 11.000000 \n",
|
||
"max 1659.000000 1566.000000 1378.000000 492.000000 32.000000 \n",
|
||
"\n",
|
||
" Salary \n",
|
||
"count 263.000000 \n",
|
||
"mean 535.925882 \n",
|
||
"std 451.118681 \n",
|
||
"min 67.500000 \n",
|
||
"25% 190.000000 \n",
|
||
"50% 425.000000 \n",
|
||
"75% 750.000000 \n",
|
||
"max 2460.000000 \n",
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"Index: 322 entries, -Andy Allanson to -Willie Wilson\n",
|
||
"Data columns (total 20 columns):\n",
|
||
" # Column Non-Null Count Dtype \n",
|
||
"--- ------ -------------- ----- \n",
|
||
" 0 AtBat 322 non-null int64 \n",
|
||
" 1 Hits 322 non-null int64 \n",
|
||
" 2 HmRun 322 non-null int64 \n",
|
||
" 3 Runs 322 non-null int64 \n",
|
||
" 4 RBI 322 non-null int64 \n",
|
||
" 5 Walks 322 non-null int64 \n",
|
||
" 6 Years 322 non-null int64 \n",
|
||
" 7 CAtBat 322 non-null int64 \n",
|
||
" 8 CHits 322 non-null int64 \n",
|
||
" 9 CHmRun 322 non-null int64 \n",
|
||
" 10 CRuns 322 non-null int64 \n",
|
||
" 11 CRBI 322 non-null int64 \n",
|
||
" 12 CWalks 322 non-null int64 \n",
|
||
" 13 League 322 non-null object \n",
|
||
" 14 Division 322 non-null object \n",
|
||
" 15 PutOuts 322 non-null int64 \n",
|
||
" 16 Assists 322 non-null int64 \n",
|
||
" 17 Errors 322 non-null int64 \n",
|
||
" 18 Salary 263 non-null float64\n",
|
||
" 19 NewLeague 322 non-null object \n",
|
||
"dtypes: float64(1), int64(16), object(3)\n",
|
||
"memory usage: 52.8+ KB\n",
|
||
"None\n",
|
||
"AtBat 0\n",
|
||
"Hits 0\n",
|
||
"HmRun 0\n",
|
||
"Runs 0\n",
|
||
"RBI 0\n",
|
||
"Walks 0\n",
|
||
"Years 0\n",
|
||
"CAtBat 0\n",
|
||
"CHits 0\n",
|
||
"CHmRun 0\n",
|
||
"CRuns 0\n",
|
||
"CRBI 0\n",
|
||
"CWalks 0\n",
|
||
"League 0\n",
|
||
"Division 0\n",
|
||
"PutOuts 0\n",
|
||
"Assists 0\n",
|
||
"Errors 0\n",
|
||
"Salary 59\n",
|
||
"NewLeague 0\n",
|
||
"dtype: int64\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(hitters.shape)\n",
|
||
"print(hitters.columns)\n",
|
||
"print(hitters.describe())\n",
|
||
"print(hitters.info())\n",
|
||
"print(hitters.isnull().sum())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
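"There are 322 players and 20 columns: 19 features plus the target `Salary`. Not all features are in $\\mathbb{R}$: `League`, `Division` and `NewLeague` are categorical (dtype `object`); the other 16 features are integer-valued. The only column with missing values is `Salary`, with 59."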
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"As `Salary` is the **target** we want to predict, if it is missing for a player, we will remove this player.\n",
|
||
"\n",
|
||
"To simplify here, **we only take numeric features** and ignore factors (i.e. categorical attributes) like `League`, `Division` and `NewLeague`. \n",
|
||
"\n",
|
||
"**Remark :**\n",
|
||
"To handle the categorical features, one can use one-hot encoding, see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 3** :\n",
|
||
"\n",
|
||
"(1) Remove the players for whom `Salary` is missing. \n",
|
||
"\n",
|
||
"- Hint : use `pandas.DataFrame.dropna`. https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.dropna.html\n",
|
||
"\n",
|
||
"\n",
"(2) `Salary` is the **target**, denoted `Y` in the next cell. For the features (`X` in the next cell), we remove `League`, `Division` and `NewLeague`. \n",
"\n",
"- Hint : (1) You can use `dtypes == 'int64'` to select the integer-valued features. Alternative: use `select_dtypes` with `include='number'`, then drop `Salary`.\n",
" (2) Use `pandas.DataFrame.loc` to access the dataframe. \n",
"https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.loc.html"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.191743Z",
|
||
"start_time": "2025-03-26T10:33:53.186755Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"0\n",
|
||
"(263, 16)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# We remove the players for whom Salary is missing\n",
|
||
"hitters.dropna(subset=[\"Salary\"], inplace=True)\n",
|
||
"\n",
|
||
"X = hitters.select_dtypes(include=int)\n",
|
||
"Y = hitters[\"Salary\"]\n",
|
||
"\n",
|
||
"# check-point\n",
|
||
"print(Y.isnull().sum()) # should be 0\n",
|
||
"print(X.shape) # should be (322-59, 20-4)=(263,16)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"**Exercise 4**\n",
"\n",
"Split the data into a train set and a test set. Use 30% of the data for the test set and `random_state=42`. (At the end, you can try values other than 42.) "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.367971Z",
|
||
"start_time": "2025-03-26T10:33:53.348332Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Answer for Exercise 4\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.3, random_state=42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
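"**Standardization**\n",
"\n",
"We standardize the features before applying Lasso or Ridge, as is usually advised: both penalties act on the size of the coefficients, which depends on the scale of each feature. For this we use the transformer `StandardScaler`, fitted on the training set only. "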
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.539084Z",
|
||
"start_time": "2025-03-26T10:33:53.531578Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"\n",
|
||
"scaler = StandardScaler()\n",
|
||
"XtrainScaled = scaler.fit_transform(Xtrain)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Let us check that the columns are now standardized (standard deviation 1 and mean 0, up to floating-point error):"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.574140Z",
|
||
"start_time": "2025-03-26T10:33:53.568355Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n",
|
||
" array([ 2.41352831e-17, 9.65411326e-18, 1.93082265e-17, -1.20676416e-17,\n",
|
||
" 1.01368189e-16, 1.61706397e-16, 7.72329061e-17, 6.75787928e-17,\n",
|
||
" -1.93082265e-17, -3.86164530e-17, 5.79246795e-17, -3.86164530e-17,\n",
|
||
" 2.89623398e-17, -3.25826322e-17, -6.75787928e-17, 8.68870193e-17]))"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"XtrainScaled.std(axis=0), XtrainScaled.mean(axis=0)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Note that the scaled data is now a NumPy array rather than a DataFrame:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.625854Z",
|
||
"start_time": "2025-03-26T10:33:53.619268Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"numpy.ndarray"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"type(XtrainScaled)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Let us turn `Ytrain` into a NumPy array as well:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.657967Z",
|
||
"start_time": "2025-03-26T10:33:53.654162Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"Ytrain = Ytrain.to_numpy()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. Ridge and Lasso <a class=\"anchor\" id=\"chapter1\"></a>"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Ridge**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 5**: (1) Create a Ridge regression model named `ridge`, with `fit_intercept=True` and the default value `alpha=1`. Check what `alpha` corresponds to.\n",
|
||
"\n",
|
||
"Hint: use `sklearn.linear_model.Ridge`.\n",
|
||
"\n",
|
||
"https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html\n",
|
||
"\n",
|
||
"(2) Fit `ridge` on the data `(XtrainScaled,Ytrain)`.\n",
|
||
"\n",
|
||
"(3) Display the estimated coefficients (including the intercept, `intercept_`)."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.735321Z",
|
||
"start_time": "2025-03-26T10:33:53.685415Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[-215.12785148 290.20873256 53.45833136 -7.60173315 -38.54941926\n",
|
||
" 73.36945017 32.60415669 -323.39710533 116.22147257 26.20493566\n",
|
||
" 353.14584561 122.1676183 -121.03308298 93.78447846 36.78702805\n",
|
||
" -14.33540126] 535.83825\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import Ridge\n",
|
||
"\n",
|
||
"ridge = Ridge(fit_intercept=True, alpha=1)\n",
|
||
"ridge.fit(XtrainScaled, Ytrain)\n",
|
||
"\n",
|
||
"print(ridge.coef_, ridge.intercept_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 6**: (1) Create a Lasso regression model named `lasso`, with `alpha=1` and `fit_intercept=True`. What does `alpha` correspond to?\n",
|
||
"\n",
|
||
"Hint: use `sklearn.linear_model.Lasso`. See: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html\n",
|
||
"\n",
|
||
"(2) Fit `lasso` on the data `(XtrainScaled,Ytrain)`.\n",
|
||
"\n",
|
||
"(3) Display the estimated coefficients. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.759155Z",
|
||
"start_time": "2025-03-26T10:33:53.751938Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array([-230.44564808, 301.60792783, 40.17406635, -10.38454116,\n",
|
||
" -22.84228549, 70.6672065 , 40.56803687, -395.18954578,\n",
|
||
" 36.12817932, -0. , 486.21184889, 170.22079734,\n",
|
||
" -133.5656457 , 94.08801971, 35.12092631, -8.15322646])"
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import Lasso\n",
|
||
"\n",
|
||
"lasso = Lasso(fit_intercept=True, alpha=1)\n",
|
||
"lasso.fit(XtrainScaled, Ytrain)\n",
|
||
"\n",
|
||
"lasso.coef_"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"\n",
|
||
"-------------------\n",
|
||
"\n",
|
||
"In the next exercise, we will plot how the Ridge coefficients vary with the regularization strength (`alpha`, the *shrinkage parameter*).\n",
|
||
"Let us generate 50 values of `alpha`, evenly spaced on a logarithmic scale from $10^{-4}$ to $10^{6}$:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.780921Z",
|
||
"start_time": "2025-03-26T10:33:53.777186Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"alphas = np.logspace(-4, 6, 50)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
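||
"*Reminder about `np.logspace`* (if needed): `np.logspace(a, b, n)` returns `n` values evenly spaced on a logarithmic scale between $10^a$ and $10^b$, i.e. `10 ** np.linspace(a, b, n)`."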
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.807330Z",
|
||
"start_time": "2025-03-26T10:33:53.802462Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array([ 1., 10., 100., 1000.])"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"np.logspace(0, 3, 4)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.850865Z",
|
||
"start_time": "2025-03-26T10:33:53.845506Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array([ 1., 10., 100., 1000.])"
|
||
]
|
||
},
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"10 ** (np.linspace(0, 3, 4))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:53.872536Z",
|
||
"start_time": "2025-03-26T10:33:53.868382Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array([ True, True, True, True, True])"
|
||
]
|
||
},
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"np.isclose(np.logspace(-2, 2, 5), 10 ** (np.linspace(-2, 2, 5)))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.012679Z",
|
||
"start_time": "2025-03-26T10:33:53.899578Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[1.00000000e-04 1.59985872e-04 2.55954792e-04 4.09491506e-04\n",
|
||
" 6.55128557e-04 1.04811313e-03 1.67683294e-03 2.68269580e-03\n",
|
||
" 4.29193426e-03 6.86648845e-03 1.09854114e-02 1.75751062e-02\n",
|
||
" 2.81176870e-02 4.49843267e-02 7.19685673e-02 1.15139540e-01\n",
|
||
" 1.84206997e-01 2.94705170e-01 4.71486636e-01 7.54312006e-01\n",
|
||
" 1.20679264e+00 1.93069773e+00 3.08884360e+00 4.94171336e+00\n",
|
||
" 7.90604321e+00 1.26485522e+01 2.02358965e+01 3.23745754e+01\n",
|
||
" 5.17947468e+01 8.28642773e+01 1.32571137e+02 2.12095089e+02\n",
|
||
" 3.39322177e+02 5.42867544e+02 8.68511374e+02 1.38949549e+03\n",
|
||
" 2.22299648e+03 3.55648031e+03 5.68986603e+03 9.10298178e+03\n",
|
||
" 1.45634848e+04 2.32995181e+04 3.72759372e+04 5.96362332e+04\n",
|
||
" 9.54095476e+04 1.52641797e+05 2.44205309e+05 3.90693994e+05\n",
|
||
" 6.25055193e+05 1.00000000e+06]\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[<matplotlib.lines.Line2D at 0x30a321970>]"
|
||
]
|
||
},
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"print(alphas)\n",
|
||
"\n",
|
||
"fig, (ax1, ax2) = plt.subplots(1, 2)\n",
|
||
"ax1.plot(alphas, \".\")\n",
|
||
"ax2.plot(np.log10(alphas), \".\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 7**: (1) For each `alpha` in `alphas`, fit a Ridge model. \n",
|
||
"You can use `ridge.set_params`. See: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html\n",
|
||
" \n",
|
||
"\n",
|
||
"(2) Plot the coefficients (without the intercept) against `log10(alpha)`. What do you observe?\n",
|
||
"\n",
|
||
"(3) (**Optional**) Create a linear regression model and compare the coefficients of the OLS estimator with those of a Ridge model with a small `alpha`.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.103061Z",
|
||
"start_time": "2025-03-26T10:33:54.029474Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import LinearRegression\n",
|
||
"\n",
|
||
"coefs_ridge = []\n",
|
||
"for alpha in alphas:\n",
|
||
" ridge = Ridge(fit_intercept=True, alpha=alpha)\n",
|
||
" ridge.fit(XtrainScaled, Ytrain)\n",
|
||
" coefs_ridge.append(ridge.coef_)\n",
|
||
"\n",
|
||
"plt.plot(np.log10(alphas), coefs_ridge)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.141323Z",
|
||
"start_time": "2025-03-26T10:33:54.136808Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[-212.17552979 307.84554012 70.19893491 -33.94121908 -61.98261284\n",
|
||
" 80.93907097 87.9354101 -854.87222296 301.36091121 -32.55986525\n",
|
||
" 575.96472518 276.24716079 -147.21083025 92.98658147 49.1730284\n",
|
||
" -8.23077394]\n",
|
||
"[-212.18873763 307.85651187 70.19704929 -33.94181179 -61.97778652\n",
|
||
" 80.94031337 87.91930611 -854.70162148 301.23727914 -32.56030861\n",
|
||
" 575.95219431 276.23085108 -147.21581673 92.98728123 49.16991958\n",
|
||
" -8.231394 ]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"linReg = LinearRegression(fit_intercept=True)\n",
|
||
"linReg.fit(XtrainScaled, Ytrain)\n",
|
||
"print(linReg.coef_)  # OLS coefficients\n",
|
||
"print(coefs_ridge[0])  # Ridge coefficients for the smallest alpha (1e-4)\n",
|
||
"coefs_ridge = np.array(coefs_ridge)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Lasso**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We now repeat the questions of Exercise 7 for Lasso. \n",
|
||
"\n",
|
||
"For that, we will use `sklearn.linear_model.lasso_path`. \n",
|
||
"See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.lasso_path.html \n",
|
||
"\n",
|
||
"Note that the output of `lasso_path` is a `tuple`. \n",
|
||
"You can also read this example: https://scikit-learn.org/stable/auto_examples/linear_model/plot_lasso_lasso_lars_elasticnet_path.html#sphx-glr-auto-examples-linear-model-plot-lasso-lasso-lars-elasticnet-path-py\n",
|
||
"\n",
|
||
"(Also note the shape of the array of coefficients returned by `lasso_path`.)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.178048Z",
|
||
"start_time": "2025-03-26T10:33:54.165251Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"((16, 50), (50, 16))"
|
||
]
|
||
},
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import lasso_path\n",
|
||
"\n",
|
||
"alphas_lasso, coefs_lasso, _ = lasso_path(XtrainScaled, Ytrain, n_alphas=50)\n",
|
||
"coefs_lasso.shape, coefs_ridge.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"(In `coefs_ridge`, the rows correspond to the values of `alpha` and the columns to the features; it is the opposite for the output of `lasso_path`.)\n",
|
||
"\n",
|
||
"We used 50 values of `alpha`; they are determined automatically by `lasso_path`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.259091Z",
|
||
"start_time": "2025-03-26T10:33:54.196219Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"coefs_lasso = coefs_lasso.T\n",
|
||
"plt.plot(np.log10(alphas_lasso), coefs_lasso)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.349032Z",
|
||
"start_time": "2025-03-26T10:33:54.282613Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[<matplotlib.lines.Line2D at 0x30a3f6c60>]"
|
||
]
|
||
},
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# the values of alpha chosen by default are also evenly spaced on a logarithmic scale\n",
|
||
"plt.plot(np.log10(alphas_lasso), \".\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.502618Z",
|
||
"start_time": "2025-03-26T10:33:54.416713Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Text(0, 0.5, 'Lasso coefficients')"
|
||
]
|
||
},
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x600 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"fig, ax = plt.subplots(figsize=(8, 6))\n",
|
||
"ax.plot(np.log10(alphas_lasso), coefs_lasso)\n",
|
||
"ax.set_xlabel(\"log10(alpha)\")\n",
|
||
"ax.set_ylabel(\"Lasso coefficients\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"-------------------\n",
|
||
"\n",
|
||
"Now let us show that **Lasso helps us select features**: more and more coefficients are set exactly to 0 as `alpha` increases.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 8**: Compute the number `nb` of zeros among the **Lasso** coefficients for each `alpha`. (Hint in the next cell.)\n",
|
||
"\n",
|
||
"Plot `nb` w.r.t. `log10(alphas)`.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.530858Z",
|
||
"start_time": "2025-03-26T10:33:54.527857Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"1.\n",
|
||
" [[0 1 1]\n",
|
||
" [0 1 1]\n",
|
||
" [0 1 0]]\n",
|
||
"2.\n",
|
||
" [[ True False False]\n",
|
||
" [ True False False]\n",
|
||
" [ True False True]]\n",
|
||
"3. The number of zeros in each column is:\n",
|
||
" [3 0 1]\n",
|
||
"4. The number of zeros in each row is:\n",
|
||
" [1 1 2]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"ind = np.array([[0, 1, 1], [0, 1, 1], [0, 1, 0]])\n",
|
||
"\n",
|
||
"print(\"1.\\n\", ind)\n",
|
||
"print(\"2.\\n\", ind == 0)\n",
|
||
"print(\"3. The number of zeros in each column is:\\n\", (ind == 0).sum(axis=0))\n",
|
||
"print(\"4. The number of zeros in each row is:\\n\", (ind == 0).sum(axis=1))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.614964Z",
|
||
"start_time": "2025-03-26T10:33:54.561469Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAe50lEQVR4nO3dDZCV1X0/8N8VdhfWgQ1gfaEuK+n4EoNBK76WGEgoSpVoHFqt1lLTdrTxJbgZI1S0apNs7bSWTqgyZjqaNjVxGoWYmop0IqADpu4q6UsaCQkBJ0ppHMrKSxdc7n+ep//durILu8u9nN17P5+ZM7v3uefec/Zwdp8v53me+xSKxWIxAAASOSZVwwAAGWEEAEhKGAEAkhJGAICkhBEAIClhBABIShgBAJISRgCApEbGEHPgwIF48803Y8yYMVEoFFJ3BwDoh+wzVN95552YOHFiHHPMMcM7jGRBpLGxMXU3AIBBeOONN+Lkk08e3mEkWxHp+mHGjh2bujsAQD+0t7fniwld+/FhHUa6Ds1kQUQYAYDhZTCnWDiBFQBIShgBAJISRgCApIQRACApYQQASEoYAQCSEkYAgKSEEQAgKWEEABheYWTt2rUxd+7c/EY42aesrVix4qA6//Ef/xGf/OQno6GhIf9Y2AsvvDC2bt1aqj4DANUcRnbv3h1Tp06NpUuX9vr8j3/845g+fXqcccYZsXr16vj+978f99xzT4waNaoU/QUAKkyhmN3zd7AvLhRi+fLlcdVVV3Vvu/baa6Ompib+9m//dtA32slWVHbu3OneNABQQtkuf+/+zvz70TUjBnUfmXLsv0t6zsiBAwfi2WefjdNOOy0uvfTSOP744+OCCy7o9VBOl46OjvwHeG8BAEovCyJn3rsyL12hZCgoaRjZvn177Nq1K/7kT/4kLrvssnj++efjU5/6VFx99dWxZs2aXl/T0tKSJ6mukt1+GACoHiVfGclceeWVcccdd8TZZ58dCxcujCuuuCKWLVvW62sWLVqUL+l0lTfeeKOUXQIAhriRpXyz4447LkaOHBlnnnlmj+0f+tCH4qWXXur1NXV1dXkBAKpTSVdGamtr47zzzovXX3+9x/aNGzdGU1NTKZsCAKp1ZSQ7J2TTpk3djzdv3hwbNmyI8ePHx6RJk+LOO++Ma665Ji655JKYOXNmPPfcc/Htb387v8wXAOCIw0hra2seMro0NzfnX+fPnx+PP/54fsJqdn5IdmLq7bffHqeffno89dRT+WePAAAccRiZMWNGfp3yoXz605/OCwDA4bg3DQCQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAMMrjKxduzbmzp0bEydOjEKhECtWrOiz7k033ZTXWbJkyZH2EwCoUAMOI7t3746pU6fG0qVLD1kvCynf+9738tACANCXkTFAc+bMycuh/OxnP4tbb701Vq5cGZdffvlAmwAABqFYLMbe/Z19Pr9nX9/PDaswcjgHDhyIG264Ie6888748Ic/fNj6HR0deenS3t5e6i4BQFUEkXnL1kfblh0R1X4C64MPPhgjR46M22+/vV/1W1paoqGhobs0NjaWuksAUPH27u/sdxCZ1jQuRteMiIpcGWlra4u//Mu/jFdffTU/cbU/Fi1aFM3NzT1WRgQSABi81sWzor6277CRBZH+7qeHXRh58cUXY/v27TFp0qTubZ2dnfG5z30uv6Lmpz/96UGvqaurywsAUBpZEKmvLfmZGGVT0p5m54rMmjWrx7ZLL700337jjTeWsikAoEIMOIzs2rUrNm3a1P148+bNsWHDhhg/fny+IjJhwoQe9Wtqau
LEE0+M008/vTQ9BgCqO4y0trbGzJkzux93ne8xf/78ePzxx0vbOwCg4g04jMyYMSO/fKi/ejtPBACgi3vTAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIADK8b5QEA/ZPdWHbv/s6j0taefUennXIQRgCgTEFk3rL10bZlR+quDHkO0wBAGWQrIimCyLSmcTG6ZkQMJ1ZGAKDMWhfPivraoxMQRteMiEKhEMOJMAIAZZYFkfpau9y+OEwDACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAAwyuMrF27NubOnRsTJ06MQqEQK1as6H5u//79cdddd8VZZ50Vxx57bF7nt3/7t+PNN98sdb8BgGoNI7t3746pU6fG0qVLD3puz5498eqrr8Y999yTf3366adj48aN8clPfrJU/QUAKszIgb5gzpw5eelNQ0NDrFq1qse2L3/5y3H++efH1q1bY9KkSYPvKQCUQLFYjL37O8vezp595W+jasPIQO3cuTM/nPOBD3yg1+c7Ojry0qW9vb3cXQKgioPIvGXro23LjtRd4WidwPo///M/sXDhwrjuuuti7NixvdZpaWnJV1S6SmNjYzm7BEAVy1ZEjnYQmdY0LkbXjDiqbQ43ZVsZyU5mvfbaa+PAgQPx8MMP91lv0aJF0dzc3GNlRCABoNxaF8+K+tryh4QsiGRHCDjKYSQLIr/xG78Rmzdvju9+97t9ropk6urq8gIAR1MWROpry362Av0wslxB5Ec/+lG88MILMWHChFI3AQBUcxjZtWtXbNq0qftxtvqxYcOGGD9+fP65IvPmzcsv6/2Hf/iH6OzsjG3btuX1sudra2tL23sAoPrCSGtra8ycObP7cdf5HvPnz4/77rsvnnnmmfzx2Wef3eN12SrJjBkzjrzHAEB1h5EsUGSXRvXlUM8BALyfe9MAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCAAwvMLI2rVrY+7cuTFx4sQoFAqxYsWKHs8Xi8W477778udHjx4dM2bMiH//938vZZ8BgGoOI7t3746pU6fG0qVLe33+T//0T+Ohhx7Kn3/llVfixBNPjF/91V+Nd955pxT9BQAqzMiBvmDOnDl56U22KrJkyZK4++674+qrr863ffWrX40TTjghnnjiibjpppuOvMcADBvZfmHv/s4YKvbsGzp94QjCyKFs3rw5tm3bFrNnz+7eVldXFx/72Mdi3bp1vYaRjo6OvHRpb28vZZcASBhE5i1bH21bdqTuCtV0AmsWRDLZSsh7ZY+7nnu/lpaWaGho6C6NjY2l7BIAiWQrIkM1iExrGheja0ak7gblWBnpkp3Y+v50/P5tXRYtWhTNzc09VkYEEoDK0rp4VtTXDp2dfxZE+tovMczDSHayaiZbBTnppJO6t2/fvv2g1ZL3HsbJCgCVKwsi9bVl+f8vFaCkh2kmT56cB5JVq1Z1b9u3b1+sWbMmLr744lI2BQBUiAHH1F27dsWmTZt6nLS6YcOGGD9+fEyaNCkWLFgQX/rSl+LUU0/NS/Z9fX19XHfddaXuOwBQjWGktbU1Zs6c2f2463yP+fPnx+OPPx6f//znY+/evfGZz3wmduzYERdccEE8//zzMWbMmNL2HACoCIVidnbpEJKdwJpdVbNz584YO3Zs6u4AMEh79r0bZ967Mv/+Bw9c6pyRCtd+BPtv96YBAJISRgCApIQRACApYQQASEoYAQCSEkYAgKSEEQAgKWEEAEhKGAEAkhJGAI
CkhBEAIClhBABIyl2LAMokuw/p3v2dUa327Kven52BEUYAyhRE5i1bH21bdqTuCgx5DtMAlEG2IiKI/K9pTeNidM2I1N1gCLMyAlBmrYtnRX1t9e6MsyBSKBRSd4MhTBgBKLMsiNTX+nMLfXGYBgBIShgBAJISRgCApIQRACApYQQASEoYAQCSEkYAgKSEEQAgKWEEAEhKGAEAkhJGAICkhBEAIClhBABIShgBACorjLz77ruxePHimDx5cowePTo++MEPxgMPPBAHDhwodVMAQAUYWeo3fPDBB2PZsmXx1a9+NT784Q9Ha2tr3HjjjdHQ0BCf/exnS90cADDMlTyMrF+/Pq688sq4/PLL88ennHJKfP3rX89DCcD7FYvF2Lu/MyrNnn2V9zPBsAkj06dPz1dGNm7cGKeddlp8//vfj5deeimWLFnSa/2Ojo68dGlvby91l4AhHETmLVsfbVt2pO4KUElh5K677oqdO3fGGWecESNGjIjOzs744he/GL/5m7/Za/2Wlpa4//77S90NYBjIVkQqPYhMaxoXo2tGpO4GVFcYefLJJ+NrX/taPPHEE/k5Ixs2bIgFCxbExIkTY/78+QfVX7RoUTQ3N/dYGWlsbCx1t4AhrnXxrKivrbyddhZECoVC6m5AdYWRO++8MxYuXBjXXntt/viss86KLVu25CsgvYWRurq6vADVLQsi9bUl/5MEVOOlvXv27Iljjun5ttnhGpf2AgC9Kfl/Q+bOnZufIzJp0qT8MM1rr70WDz30UHz6058udVMAQAUoeRj58pe/HPfcc0985jOfie3bt+fnitx0001x7733lropAKAClDyMjBkzJr+Mt69LeQEA3su9aQCApIQRACApYQQASEoYAQCSEkYAgKSEEQAgKWEEAEhKGAEAkhJGAICkhBEAIClhBABIShgBACrrRnnA0FUsFmPv/s4YKvbsGzp9AdIRRqCKgsi8ZeujbcuO1F0B6MFhGqgS2YrIUA0i05rGxeiaEam7ASRiZQSqUOviWVFfO3R2/lkQKRQKqbsBJCKMQBXKgkh9rV9/YGhwmAYASEoYAQCSEkYAgKSEEQAgKWEEAEhKGAEAkhJGAICkhBEAIClhBABIShgBAJISRgCApIQRACApYQQASEoYAQAqL4z87Gc/i9/6rd+KCRMmRH19fZx99tnR1tZWjqYAgGFuZKnfcMeOHfErv/IrMXPmzPjHf/zHOP744+PHP/5xfOADHyh1UwBABSh5GHnwwQejsbExHnvsse5tp5xySqmbAQAqRMkP0zzzzDMxbdq0+PVf//V8VeScc86Jr3zlK33W7+joiPb29h4FAKgeJQ8jP/nJT+KRRx6JU089NVauXBk333xz3H777fE3f/M3vdZvaWmJhoaG7pKtqgAA1aNQLBaLpXzD2trafGVk3bp13duyMPLKK6/E+vXre10ZyUqXbGUkCyQ7d+6MsWPHlrJrUNX27Hs3zrx3Zf79Dx64NOprS36UFqhi7e3t+aLCYPbfJV8ZOemkk+LMM8/sse1DH/pQbN26tdf6dXV1eaffWwCA6lHyMJJdSfP666/32LZx48ZoamoqdVMAQAUoeRi544474uWXX44vfelLsWnTpnjiiSfi0UcfjVtuuaXUTQEAFaDkYeS8886L5cuXx9e//vWYMmVK/PEf/3EsWbIkrr/++lI3BQBUgLKcwXbFFVfkBQDgcNybBgBIShgBAJISRgCApIQRACApYQQASEoYAQCSEkYAgKSEEQAgKWEEAEhKGAEAkhJGAICkhBEAoPJulAf0X7FYjL37O8vezp595W8DYDCEEUgcROYtWx9tW3ak7gpAMg7TQELZisjRDiLTmsbF6JoRR7VNgEOxMgJDROviWVFfW/6QkAWRQqFQ9nYA+ksYgSEiCyL1tX4lgerjMA0AkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwg
gAkJQwAgBUdhhpaWmJQqEQCxYsKHdTAMAwVNYw8sorr8Sjjz4aH/nIR8rZDAAwjI0s1xvv2rUrrr/++vjKV74SX/jCF8rVDCRRLBZj7/7OI36fPfuO/D0AhruyhZFbbrklLr/88pg1a9Yhw0hHR0deurS3t5erS1CyIDJv2fpo27IjdVcAKkJZwsg3vvGNePXVV/PDNP05p+T+++8vRzegLLIVkVIHkWlN42J0zYiSvidA1YaRN954Iz772c/G888/H6NGjTps/UWLFkVzc3OPlZHGxsZSdwvKonXxrKivPfIQkQWR7ERvgGpU8jDS1tYW27dvj3PPPbd7W2dnZ6xduzaWLl2aH5IZMeL//njX1dXlBYajLIjU15btaCdAVSj5X9FPfOIT8a//+q89tt14441xxhlnxF133dUjiAAAlDyMjBkzJqZMmdJj27HHHhsTJkw4aDsAgE9gBQCSOioHu1evXn00mgEAhiErIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAJV/ozw4EsViMfbu7yzZ+42uGRGFQmHQ7e3ZV7q+ACCMMMRlwWDesvXRtmVHyd5zWtO4+PubL+o1kJSjPQAOzWEahrRshaLUwaB1y44+Vz4G0l4WarJVFgCOjJURho3WxbOivnbwO//s8Mq0L/xTydo73OEeAPpHGGHYyIJBfe3Iim0PoFo5TAMAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAACVFUZaWlrivPPOizFjxsTxxx8fV111Vbz++uulbgYAqBAlDyNr1qyJW265JV5++eVYtWpVvPvuuzF79uzYvXt3qZsCACrAyFK/4XPPPdfj8WOPPZavkLS1tcUll1xS6uYY5orFYuzd39nn83v29f3ckejrfcvVHgBHMYy8386dO/Ov48eP7/X5jo6OvHRpb28vd5cYQkFk3rL10bZlx1Fve9oX/umotwlAghNYs51Nc3NzTJ8+PaZMmdLnOSYNDQ3dpbGxsZxdYgjJVkT6G0SmNY2L0TUjjqi97PXZ+xyt9gDon0IxSwxlkp078uyzz8ZLL70UJ598cr9XRrJAkq2ojB07tlxdYwjYs+/dOPPelfn3rYtnRX1t3zv/LBgUCoWyHxYqdXsA1aK9vT1fVBjM/rtsh2luu+22eOaZZ2Lt2rV9BpFMXV1dXqhuWRCpry37UcM8YByNdgDov5L/Vc7+55kFkeXLl8fq1atj8uTJpW4CAKggI8txaOaJJ56Ib33rW/lnjWzbti3fni3djB49utTNAQDDXMlPYH3kkUfy40UzZsyIk046qbs8+eSTpW4KAKgAZTlMAwDQX+5NAwAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAUFk3yqPnTQP37u9M3Y0ha88+YwOAMFLWIDJv2fpo27IjdVcAYEhzmKZMshURQaR/pjWNi9E1I1J3A4BErIwcBa2LZ0V9rZ1tX7IgUigUUncDgESEkaMgCyL1tYYaAHrjMA0AkJQwAgAkJYwAAEkJIwBAUsIIAJCUMAIAJCWMAABJCSMAQFLCCACQlDACACQljAAASQkjAEBSwggAkJQwAgBUZhh5+OGHY/LkyTFq1Kg499xz48UXXyxXUwDAMFaWMPLkk0/GggUL4u67747XXnstPvrRj8acOXNi69at5WgOABjGRpbjTR966KH43d/93fi93/u9/PGSJUti5cqV8cgjj0RLS0ukUCwWY+/+zqPW3p59R68tABjOSh5G9u3bF21tbbFw4cIe22fPnh3r1q07qH5HR0deurS3t0c5ZEHkzHtXluW9AYAhdJjm5z//eXR2dsYJJ5zQY3v2eNu2bQfVz1ZKGhoauktjY2NUkmlN42J0zYjU3QCA6jpMky
kUCgcdJnn/tsyiRYuiubm5x8pIOQJJFgh+8MClJX/f/rTb288NAJQpjBx33HExYsSIg1ZBtm/fftBqSaauri4v5ZYFgvrasmUvAGCoHKapra3NL+VdtWpVj+3Z44svvrjUzQEAw1xZlgqywy433HBDTJs2LS666KJ49NFH88t6b7755nI0BwAMY2UJI9dcc028/fbb8cADD8Rbb70VU6ZMie985zvR1NRUjuYAgGGsUMzOLB1CshNYs6tqdu7cGWPHjk3dHQCgzPtv96YBAJISRgCApIQRACApYQQASEoYAQCSEkYAgKSEEQAgKWEEAEhKGAEAkhpyt7Ht+kDY7JPcAIDhoWu/PZgPdh9yYeSdd97JvzY2NqbuCgAwiP149rHww/reNAcOHIg333wzxowZE4VCIYZa6stC0htvvOG+OYdhrPrPWPWfsRoY49V/xurIxyqLE1kQmThxYhxzzDHDe2Uk+wFOPvnkGMqywTdZ+8dY9Z+x6j9jNTDGq/+M1ZGN1UBXRLo4gRUASEoYAQCSEkYGoK6uLv7oj/4o/8qhGav+M1b9Z6wGxnj1n7FKO1ZD7gRWAKC6WBkBAJISRgCApIQRACApYQQASEoYOYwdO3bEDTfckH+QS1ay7//7v//7kK/5nd/5nfzTY99bLrzwwqg0Dz/8cEyePDlGjRoV5557brz44ouHrL9mzZq8Xlb/gx/8YCxbtiyqxUDGavXq1QfNn6z88Ic/jEq3du3amDt3bv4JjtnPvGLFisO+plrn1UDHqprnVUtLS5x33nn5J3sff/zxcdVVV8Xrr79+2NdV49xqGcRYlWJuCSOHcd1118WGDRviueeey0v2fRZIDueyyy6Lt956q7t85zvfiUry5JNPxoIFC+Luu++O1157LT760Y/GnDlzYuvWrb3W37x5c/zar/1aXi+r/4d/+Idx++23x1NPPRWVbqBj1SX7A/DeOXTqqadGpdu9e3dMnTo1li5d2q/61TyvBjpW1TyvslBxyy23xMsvvxyrVq2Kd999N2bPnp2PYV+qdW6tGcRYlWRuZZf20rsf/OAH2WXPxZdffrl72/r16/NtP/zhD/t83fz584tXXnllsZKdf/75xZtvvrnHtjPOOKO4cOHCXut//vOfz59/r5tuuql44YUXFivdQMfqhRdeyOfYjh07itUsG4Ply5cfsk41z6uBjpV59X+2b9+ej8WaNWv6rGNu9X+sSjG3rIwcwvr16/NDMxdccEH3tuxwS7Zt3bp1h122ypa4TjvttPj93//92L59e1SKffv2RVtbW56W3yt73Ne4ZGP5/vqXXnpptLa2xv79+6NSDWasupxzzjlx0kknxSc+8Yl44YUXytzT4ala59WRMK8idu7cmX8dP358n3XMrf6PVSnmljByCNu2bcsDxftl27Ln+pItwf/d3/1dfPe7340///M/j1deeSU+/vGPR0dHR1SCn//859HZ2RknnHBCj+3Z477GJdveW/1sCTB7v0o1mLHKfpkfffTRfDn46aefjtNPPz3/5c7OEaCnap1Xg2Fe/a9sIam5uTmmT58eU6ZM6bOeuRX9HqtSzK0hd9feo+G+++6L+++//5B1sgCRyU7C6e0fqLftXa655pru77N/wGnTpkVTU1M8++yzcfXVV0eleP8YHG5ceqvf2/ZKNJCxyn6Rs9Lloosuym/V/Wd/9mdxySWXlL2vw001z6uBMK/+16233hr/8i//Ei+99NJh61b73Lq1n2NVirlVlWEkG+Brr732kHVOOeWU/B/hP//zPw967r/+678OSsyHkqXGLIz86Ec/ikpw3HHHxYgRIw76n312KKqvcTnxxBN7rT9y5MiYMGFCVKrBjFVvssODX/va18rQw+GtWudVqVTbvLrtttvimWeeyf/HfvLJJx+ybrXPrdsGMFalmFtVGUayHURWDidLd9nxsn/+53+O888/P9/2ve99L9928cUX97u9t99+O0+JWSipBLW1tfnlbt
mZ1p/61Ke6t2ePr7zyyj7H8tvf/naPbc8//3y+alRTUxOVajBj1ZvsbP5KmT+lVK3zqlSqZV5lKxrZznX58uX5+XzZZfaHU61zqziIsSrJ3Br0qa9V4rLLLit+5CMfya+iycpZZ51VvOKKK3rUOf3004tPP/10/v0777xT/NznPldct25dcfPmzflZxhdddFHxF3/xF4vt7e3FSvGNb3yjWFNTU/zrv/7r/KqjBQsWFI899tjiT3/60/z57EqRG264obv+T37yk2J9fX3xjjvuyOtnr8te/81vfrNY6QY6Vn/xF3+RXxmxcePG4r/927/lz2e/qk899VSx0mW/P6+99lpesp/5oYceyr/fsmVL/rx5NfixquZ59Qd/8AfFhoaG4urVq4tvvfVWd9mzZ093HXNr8GNVirkljBzG22+/Xbz++uuLY8aMyUv2/fsvX8oG/bHHHsu/z/7BZs+eXfyFX/iFfOJOmjQpv9R369atxUrzV3/1V8WmpqZibW1t8Zd/+Zd7XPqV/cwf+9jHetTPJvc555yT1z/llFOKjzzySLFaDGSsHnzwweIv/dIvFUeNGlUcN25ccfr06cVnn322WA26LhF8f8nGKGNeDX6sqnle9TZO7/27nTG3Bj9WpZhbhf/fOABAEi7tBQCSEkYAgKSEEQAgKWEEAEhKGAEAkhJGAICkhBEAIClhBABIShgBAJISRgCApIQRACApYQQAiJT+H3FIfHQBLrjJAAAAAElFTkSuQmCC",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"nb_lasso = np.array(coefs_lasso == 0).sum(axis=1)\n",
|
||
"plt.step(np.log10(alphas_lasso), nb_lasso)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 9**: Compute the number `nb` of zeros among the **Ridge** coefficients for each `alpha`. \n",
|
||
"\n",
|
||
"Plot `nb` w.r.t. `log10(alphas)`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.687251Z",
|
||
"start_time": "2025-03-26T10:33:54.639809Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"nb_ridge = np.array(coefs_ridge == 0).sum(axis=1)\n",
|
||
"plt.plot(np.log10(alphas), nb_ridge)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"source": [
|
||
"----------------\n",
|
||
"\n",
|
||
"\n",
|
||
"## 2. Cross Validation for the Lasso and Ridge hyperparameter $\\alpha$ <a class=\"anchor\" id=\"chapter2\"></a>"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"In the next exercises, we will select an optimal `alpha` among the values in `alphas` by cross-validation.\n",
|
||
"\n",
|
||
"We will use `sklearn.linear_model.RidgeCV` and `sklearn.linear_model.LassoCV`.\n",
|
||
"\n",
|
||
"Reference :\n",
|
||
"\n",
|
||
"1. `sklearn.linear_model.RidgeCV` : https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.RidgeCV.html\n",
|
||
"\n",
|
||
"2. `sklearn.linear_model.LassoCV` : https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.710558Z",
|
||
"start_time": "2025-03-26T10:33:54.707998Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"alphas = np.logspace(-4, 6, 50)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
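||
"**Exercise 10**: (1) Create a `RidgeCV` model named `ridgeCV`, with parameters `alphas=alphas` and `store_cv_values=True`. Fit `ridgeCV` on `(XtrainScaled, Ytrain)`. (In scikit-learn 1.5 and later, this parameter is named `store_cv_results` and the attribute `cv_values_` becomes `cv_results_`; check the documentation for your installed version.) \n",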
|
||
"\n",
|
||
"Remark: to be fully rigorous, we should have built a pipeline containing the scaler, so that the scaling is learned only from the training part of each split; otherwise the held-out data influences the preprocessing (\"data leakage\"). We ignore this detail here to keep things simple. \n",
|
||
"\n",
|
||
"(2) We have 50 different values of `alpha`. For each `alpha`, how many scores do we get from the cross-validation? Read the description of the attribute `cv_values_` of `RidgeCV` in https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.RidgeCV.html \n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.738765Z",
|
||
"start_time": "2025-03-26T10:33:54.732482Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[1.00000000e-04 1.59985872e-04 2.55954792e-04 4.09491506e-04\n",
|
||
" 6.55128557e-04 1.04811313e-03 1.67683294e-03 2.68269580e-03\n",
|
||
" 4.29193426e-03 6.86648845e-03 1.09854114e-02 1.75751062e-02\n",
|
||
" 2.81176870e-02 4.49843267e-02 7.19685673e-02 1.15139540e-01\n",
|
||
" 1.84206997e-01 2.94705170e-01 4.71486636e-01 7.54312006e-01\n",
|
||
" 1.20679264e+00 1.93069773e+00 3.08884360e+00 4.94171336e+00\n",
|
||
" 7.90604321e+00 1.26485522e+01 2.02358965e+01 3.23745754e+01\n",
|
||
" 5.17947468e+01 8.28642773e+01 1.32571137e+02 2.12095089e+02\n",
|
||
" 3.39322177e+02 5.42867544e+02 8.68511374e+02 1.38949549e+03\n",
|
||
" 2.22299648e+03 3.55648031e+03 5.68986603e+03 9.10298178e+03\n",
|
||
" 1.45634848e+04 2.32995181e+04 3.72759372e+04 5.96362332e+04\n",
|
||
" 9.54095476e+04 1.52641797e+05 2.44205309e+05 3.90693994e+05\n",
|
||
" 6.25055193e+05 1.00000000e+06]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import RidgeCV\n",
|
||
"\n",
|
||
"ridgeCV = RidgeCV(alphas=alphas, store_cv_values=True)\n",
|
||
"ridgeCV.fit(XtrainScaled, Ytrain)\n",
|
||
"\n",
|
||
"print(ridgeCV.alphas)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 11** : For each alpha, give the mean of the scores, and call it `alpha_score`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.765343Z",
|
||
"start_time": "2025-03-26T10:33:54.762563Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"1\n",
|
||
" (2, 3)\n",
|
||
"2\n",
|
||
" [2.5 3.5 4.5]\n",
|
||
"3\n",
|
||
" [2. 5.]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"ind2 = np.array([[1, 2, 3], [4, 5, 6]])\n",
|
||
"\n",
|
||
"print(\"1\\n\", ind2.shape)\n",
|
||
"print(\"2\\n\", ind2.mean(axis=0))\n",
|
||
"print(\"3\\n\", ind2.mean(axis=1))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.804392Z",
|
||
"start_time": "2025-03-26T10:33:54.799816Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"(50,)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"alpha_score = np.array(ridgeCV.cv_values_).mean(axis=0)\n",
|
||
"print(alpha_score.shape)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise 12** : Plot `alpha_score` w.r.t. `log10(alpha_s)`. Which `alpha` are we going to choose ? Display the coefficients for the chosen `alpha`.\n",
|
||
"\n",
|
||
"Hint : read the attributes `alpha_` and `coef_` of ridgecv."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.888467Z",
|
||
"start_time": "2025-03-26T10:33:54.839691Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.plot(np.log10(alphas), alpha_score)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:54.917915Z",
|
||
"start_time": "2025-03-26T10:33:54.915253Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"best alpha for ridge : 132.57113655901108\n",
|
||
"ridge coef for the best alpha : [18.60099315 44.83916896 19.65511565 31.2895848 19.53291699 23.82494567\n",
|
||
" 17.58025351 24.95507189 34.79628953 25.54796552 36.50612186 32.17517034\n",
|
||
" 17.53213942 56.00516056 0.58648412 -8.49246715]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(\"best alpha for ridge : \", ridgeCV.alpha_)\n",
|
||
"print(\"ridge coef for the best alpha : \", ridgeCV.coef_)"
|
||
]
|
||
},
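{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the answer to Exercise 12, the sketch below recovers the selected `alpha` directly from `alpha_score`. It assumes the default `scoring=None`, for which `cv_values_` holds squared leave-one-out errors (so smaller is better), and that `alphas` and `alpha_score` still refer to the Ridge grid and the Ridge scores. The name `best_idx` is hypothetical. Under these assumptions the two printed values are expected to coincide."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"best_idx = np.argmin(alpha_score)  # index of the smallest mean squared LOO error\n",
"print(\"alpha minimizing alpha_score :\", alphas[best_idx])\n",
"print(\"alpha stored in ridgeCV.alpha_ :\", ridgeCV.alpha_)"
]
},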
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"We will do the same for **Lasso** using `LassoCV`. \n",
|
||
"\n",
|
||
"**Remark** : By default, `LassoCV` uses 5-fold cross-validation, which is different from *Leave-One-Out Cross-Validation* used in `RidgeCV`. \n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:55.011780Z",
|
||
"start_time": "2025-03-26T10:33:54.966551Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(50, 5)"
|
||
]
|
||
},
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import LassoCV\n",
|
||
"\n",
|
||
"lassoCV = LassoCV(n_alphas=50)\n",
|
||
"lassoCV.fit(XtrainScaled, Ytrain)\n",
|
||
"lassoCV.mse_path_.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Il y a **50** alphas et **5** K"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"\"5\" corresponds to the 5-folds and \"50\" corresponds to the 50 values of alpha. Here we let LassoCV choose those 50 values. \n",
|
||
"\n",
|
||
"Pay attention to the fact that, again, the rows and the columns are reversed compared to RidgeCV output : that is why we use the parameter \"axis=1\" in the next cell. \n",
|
||
"\n",
|
||
"NB : the default value for cv is 5 here (using leave-one-out for Lasso would be too costly (no fast formula, contrary to Ridge)). "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:55.023571Z",
|
||
"start_time": "2025-03-26T10:33:55.020025Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(50,)"
|
||
]
|
||
},
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"alpha_score = lassoCV.mse_path_.mean(axis=1)\n",
|
||
"alpha_score.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:55.115836Z",
|
||
"start_time": "2025-03-26T10:33:55.049527Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[<matplotlib.lines.Line2D at 0x30a361f70>]"
|
||
]
|
||
},
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"alphas = lassoCV.alphas_\n",
|
||
"plt.plot(np.log10(alphas), alpha_score)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:55.145416Z",
|
||
"start_time": "2025-03-26T10:33:55.142140Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"best alpha for lasso : 4.783179587705159\n",
|
||
"lasso coef for the best alpha : [-99.92030224 202.5757594 25.01239288 0. -0.\n",
|
||
" 32.47945631 0. -0. 0. 27.91148338\n",
|
||
" 150.0851833 27.07622842 -0. 86.82862045 0.\n",
|
||
" -8.03825832]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(\"best alpha for lasso : \", lassoCV.alpha_)\n",
|
||
"print(\"lasso coef for the best alpha : \", lassoCV.coef_)"
|
||
]
|
||
},
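{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike Ridge, Lasso sets some coefficients exactly to zero, as the output above shows. A short sketch to count the features kept at the selected `alpha` (the names `kept` and `n_kept` are hypothetical) :"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kept = lassoCV.coef_ != 0  # boolean mask of features with a nonzero coefficient\n",
"n_kept = int(kept.sum())\n",
"print(f\"{n_kept} of {kept.size} coefficients are nonzero\")"
]
},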
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## OPTIONAL : Comparing estimators"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Remark about the standardization : the standardization needs to be done in three steps : \n",
|
||
"- first we use `fit` on the train set, which means the standardizer computes the mean and the std for each feature in your train set.\n",
|
||
"- Then you transform the data : it means, for each column, you remove the mean computed before (i.e. the means computed on the **train** set) and divide by the std computed before (i.e. the std computed on the **train** set).\n",
|
||
"- But pay attention to the fact that you must transform the train set **and** the test set. But the fit has to be done on the train set only !\n",
|
||
"\n",
|
||
"Previously, we used the shorter syntax `fit_transform` that allows you to do the `transform` calculation and then to `fit`. We could do that because we only used the train set. But for the test set, we must only use `transform`. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:33:55.172132Z",
|
||
"start_time": "2025-03-26T10:33:55.168656Z"
|
||
}
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"XtestScaled = scaler.transform(Xtest)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Prediction of Salary on the test set using the OLS estimator**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:39:51.250274Z",
|
||
"start_time": "2025-03-26T10:39:51.153010Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(-47.125, 2529.625, -84.28768799283341, 2492.4623120071665)"
|
||
]
|
||
},
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import LinearRegression\n",
|
||
"\n",
|
||
"linReg = LinearRegression()\n",
|
||
"linReg.fit(\n",
|
||
" Xtrain, Ytrain\n",
|
||
") # no need to scale for OLS if you just want to predict (unless the solver works best with scaled data)\n",
|
||
"# the predictions should not be different with or without standardization (could differ only owing to numerical problems)\n",
|
||
"hatY_LinReg = linReg.predict(Xtest)\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots()\n",
|
||
"ax.scatter(Ytest, hatY_LinReg, s=5)\n",
|
||
"ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls=\"--\", c=\"gray\")\n",
|
||
"ax.set_xlabel(\"Ytest\")\n",
|
||
"ax.set_ylabel(\"hatY\")\n",
|
||
"ax.set_title(\"Predicted vs true salaries for OLS estimator\")\n",
|
||
"ax.axis(\"square\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Optional Exercise 1**\n",
|
||
"Do the same with the Ridge estimator (with alpha chosen by CV). "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:42:58.124288Z",
|
||
"start_time": "2025-03-26T10:42:58.032948Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(-47.125, 2529.625, -96.55458820082124, 2480.1954117991786)"
|
||
]
|
||
},
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_ridge = Ridge(fit_intercept=True, alpha=132.57)\n",
|
||
"best_ridge.fit(XtrainScaled, Ytrain)\n",
|
||
"\n",
|
||
"hatY_ridge = best_ridge.predict(XtestScaled)\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots()\n",
|
||
"ax.scatter(Ytest, hatY_ridge, s=5)\n",
|
||
"ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls=\"--\", c=\"gray\")\n",
|
||
"ax.set_xlabel(\"Ytest\")\n",
|
||
"ax.set_ylabel(\"hatY\")\n",
|
||
"ax.set_title(\"Predicted vs true salaries for Ridge estimator\")\n",
|
||
"ax.axis(\"square\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Optional Exercise 2** Do the same with Lasso with alpha chosen by CV. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:43:01.103239Z",
|
||
"start_time": "2025-03-26T10:43:01.012916Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(-47.125, 2529.625, -130.82214503665753, 2445.9278549633423)"
|
||
]
|
||
},
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_lasso = Lasso(fit_intercept=True, alpha=4.78)\n",
|
||
"best_lasso.fit(XtrainScaled, Ytrain)\n",
|
||
"\n",
|
||
"hatY_lasso = best_lasso.predict(XtestScaled)\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots()\n",
|
||
"ax.scatter(Ytest, hatY_lasso, s=5)\n",
|
||
"ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls=\"--\", c=\"gray\")\n",
|
||
"ax.set_xlabel(\"Ytest\")\n",
|
||
"ax.set_ylabel(\"hatY\")\n",
|
||
"ax.set_title(\"Predicted vs true salaries for Ridge estimator\")\n",
|
||
"ax.axis(\"square\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Calculation of the best alpha with a \"BIC criterion\".**\n",
|
||
"Here is an alternative method to choose the \"best\" alpha : "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:43:03.416471Z",
|
||
"start_time": "2025-03-26T10:43:03.407079Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"best alpha chosen by BIC criterion : 17.81143051916605\n",
|
||
"best alpha chosen by CV : 4.783179587705159\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.linear_model import LassoLarsIC\n",
|
||
"\n",
|
||
"lassoBIC = LassoLarsIC(criterion=\"bic\")\n",
|
||
"lassoBIC.fit(XtrainScaled, Ytrain)\n",
|
||
"print(\"best alpha chosen by BIC criterion :\", lassoBIC.alpha_)\n",
|
||
"print(\"best alpha chosen by CV :\", lassoCV.alpha_)\n",
|
||
"\n",
|
||
"alpha_BIC = lassoBIC.alpha_"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Comparison between predicted salary and true salary for \"LassoBIC\" estimator :** "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:43:14.238270Z",
|
||
"start_time": "2025-03-26T10:43:14.158168Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(-47.125, 2529.625, -111.67387145852295, 2465.076128541477)"
|
||
]
|
||
},
|
||
"execution_count": 43,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"lasso.set_params(alpha=alpha_BIC)\n",
|
||
"lasso.fit(XtrainScaled, Ytrain)\n",
|
||
"\n",
|
||
"hatY_BIC = lasso.predict(XtestScaled)\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots()\n",
|
||
"ax.scatter(Ytest, hatY_BIC, s=5)\n",
|
||
"ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls=\"--\", c=\"gray\")\n",
|
||
"ax.set_xlabel(\"Ytest\")\n",
|
||
"ax.set_ylabel(\"hatY\")\n",
|
||
"ax.set_title(\"Predicted vs true salaries for LassoBIC estimator\")\n",
|
||
"ax.axis(\"square\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Optional Exercise 3** Compute the MSE for these four different estimators (LassoCV, LassoBIC, OLS, RidgeCV)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 44,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:43:16.115247Z",
|
||
"start_time": "2025-03-26T10:43:16.108424Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"MSE for LassoCV : 8747438999.110947\n",
|
||
"MSE for LassoBIC : 13652129197.77105\n",
|
||
"MSE for RidgeCV : 34959235357.83314\n",
|
||
"MSE for OLS : 150074.5486842358\n",
|
||
"150074.5486842358\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import mean_squared_error\n",
|
||
"\n",
|
||
"MSEs = []\n",
|
||
"for name, estimator in zip(\n",
|
||
" [\"LassoCV\", \"LassoBIC\", \"RidgeCV\", \"OLS\"],\n",
|
||
" [lassoCV, lassoBIC, ridgeCV, linReg],\n",
|
||
" strict=False,\n",
|
||
"):\n",
|
||
" y_pred = estimator.predict(Xtest)\n",
|
||
" MSE = mean_squared_error(Ytest, y_pred)\n",
|
||
" print(f\"MSE for {name} : {MSE}\")\n",
|
||
" MSEs.append(MSE)\n",
|
||
"\n",
|
||
"print(np.min(MSEs))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Optional Exercise 4** Display the boxplot of the absolute errors for each estimator."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 45,
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-26T10:43:18.421455Z",
|
||
"start_time": "2025-03-26T10:43:18.334991Z"
|
||
}
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1000x600 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Boxplot the absolute error for LassoCV, LassoBIC, RidgeCV, OLS\n",
|
||
"ridge_cv_errors = np.abs(Ytest - ridgeCV.predict(Xtest))\n",
|
||
"lasso_cv_errors = np.abs(Ytest - lassoCV.predict(Xtest))\n",
|
||
"lasso_bic_errors = np.abs(Ytest - lassoBIC.predict(Xtest))\n",
|
||
"ols_errors = np.abs(Ytest - linReg.predict(Xtest))\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots(figsize=(10, 6))\n",
|
||
"ax.boxplot(\n",
|
||
" [ridge_cv_errors, lasso_cv_errors, lasso_bic_errors, ols_errors],\n",
|
||
" labels=[\"RidgeCV\", \"LassoCV\", \"LassoBIC\", \"OLS\"],\n",
|
||
")\n",
|
||
"ax.set_title(\"Boxplot of Absolute Errors\")\n",
|
||
"ax.set_ylabel(\"Absolute Error\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Optional Exercise 5**\n",
|
||
"Based on the above information, which estimator would you recommend?"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 46,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# answer fot Optional Exercise 5"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "base",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.2"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|