mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-02-01 06:29:33 +01:00
Refactor code formatting and improve readability in Jupyter notebooks for TP_4 and TP_5
- Adjusted indentation and line breaks for better clarity in function definitions and import statements.
- Standardized string quotes for consistency across the codebase.
- Enhanced readability of DataFrame creation and manipulation by breaking long lines into multiple lines.
- Cleaned up print statements and comments for improved understanding.
- Ensured consistent use of whitespace around operators and after commas.
This commit is contained in:
@@ -128,7 +128,7 @@
|
||||
"\n",
|
||||
"# Configuration des graphiques\n",
|
||||
"plt.style.use(\"seaborn-v0_8-darkgrid\")\n",
|
||||
"sns.set_palette(\"husl\")\n"
|
||||
"sns.set_palette(\"husl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -162,20 +162,38 @@
|
||||
"## Chargement du dataset Adult Income (Census)\n",
|
||||
"\n",
|
||||
"print(\"=== Chargement du dataset Adult Income ===\\n\")\n",
|
||||
"print(\"Dataset classique de Kaggle/UCI qui illustre parfaitement les forces de CatBoost\")\n",
|
||||
"print(\"Objectif : Prédire si le revenu annuel > 50K$ basé sur des caractéristiques socio-démographiques\\n\")\n",
|
||||
"print(\n",
|
||||
" \"Dataset classique de Kaggle/UCI qui illustre parfaitement les forces de CatBoost\"\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" \"Objectif : Prédire si le revenu annuel > 50K$ basé sur des caractéristiques socio-démographiques\\n\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Chargement depuis UCI\n",
|
||||
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
|
||||
"\n",
|
||||
"column_names = [\n",
|
||||
" 'age', 'workclass', 'fnlwgt', 'education', 'education_num',\n",
|
||||
" 'marital_status', 'occupation', 'relationship', 'race', 'sex',\n",
|
||||
" 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country', 'income'\n",
|
||||
" \"age\",\n",
|
||||
" \"workclass\",\n",
|
||||
" \"fnlwgt\",\n",
|
||||
" \"education\",\n",
|
||||
" \"education_num\",\n",
|
||||
" \"marital_status\",\n",
|
||||
" \"occupation\",\n",
|
||||
" \"relationship\",\n",
|
||||
" \"race\",\n",
|
||||
" \"sex\",\n",
|
||||
" \"capital_gain\",\n",
|
||||
" \"capital_loss\",\n",
|
||||
" \"hours_per_week\",\n",
|
||||
" \"native_country\",\n",
|
||||
" \"income\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" df = pd.read_csv(url, names=column_names, sep=r',\\s*', engine='python', na_values='?')\n",
|
||||
" df = pd.read_csv(\n",
|
||||
" url, names=column_names, sep=r\",\\s*\", engine=\"python\", na_values=\"?\"\n",
|
||||
" )\n",
|
||||
" print(\"Dataset chargé depuis UCI repository\")\n",
|
||||
"except: # noqa: E722\n",
|
||||
" print(\"Impossible de charger depuis UCI, création d'un dataset simulé similaire...\")\n",
|
||||
@@ -183,44 +201,121 @@
|
||||
" np.random.seed(42)\n",
|
||||
" n_samples = 32561\n",
|
||||
"\n",
|
||||
" df = pd.DataFrame({\n",
|
||||
" 'age': np.random.randint(17, 90, n_samples),\n",
|
||||
" 'workclass': np.random.choice(['Private', 'Self-emp-not-inc', 'Local-gov', 'State-gov', 'Self-emp-inc',\n",
|
||||
" 'Federal-gov', 'Without-pay'], n_samples, p=[0.73, 0.08, 0.06, 0.04, 0.03, 0.03, 0.03]),\n",
|
||||
" 'fnlwgt': np.random.randint(12285, 1484705, n_samples),\n",
|
||||
" 'education': np.random.choice(\n",
|
||||
" ['HS-grad', 'Some-college', 'Bachelors', 'Masters', 'Assoc-voc',\n",
|
||||
" 'Doctorate', '11th', '9th', '7th-8th'], n_samples,\n",
|
||||
" p=[0.32, 0.22, 0.16, 0.05, 0.04, 0.01, 0.04, 0.03, 0.13]),\n",
|
||||
" 'education_num': np.random.randint(1, 16, n_samples),\n",
|
||||
" 'marital_status': np.random.choice(['Married-civ-spouse', 'Never-married', 'Divorced', 'Separated',\n",
|
||||
" 'Widowed'], n_samples, p=[0.46, 0.33, 0.14, 0.03, 0.04]),\n",
|
||||
" 'occupation': np.random.choice(['Prof-specialty', 'Craft-repair', 'Exec-managerial', 'Adm-clerical',\n",
|
||||
" 'Sales', 'Other-service', 'Machine-op-inspct', 'Tech-support'], n_samples,\n",
|
||||
" p=[0.13, 0.13, 0.13, 0.12, 0.11, 0.10, 0.06, 0.22]),\n",
|
||||
" 'relationship': np.random.choice(['Husband', 'Not-in-family', 'Own-child', 'Unmarried', 'Wife', 'Other-relative'],\n",
|
||||
" n_samples, p=[0.40, 0.26, 0.16, 0.10, 0.05, 0.03]),\n",
|
||||
" 'race': np.random.choice(['White', 'Black', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other'],\n",
|
||||
" n_samples, p=[0.85, 0.10, 0.03, 0.01, 0.01]),\n",
|
||||
" 'sex': np.random.choice(['Male', 'Female'], n_samples, p=[0.67, 0.33]),\n",
|
||||
" 'capital_gain': np.where(np.random.random(n_samples) < 0.92, 0, np.random.randint(1, 99999, n_samples)),\n",
|
||||
" 'capital_loss': np.where(np.random.random(n_samples) < 0.95, 0, np.random.randint(1, 4356, n_samples)),\n",
|
||||
" 'hours_per_week': np.random.randint(1, 99, n_samples),\n",
|
||||
" 'native_country': np.random.choice(['United-States', 'Mexico', 'Philippines', 'Germany', 'Canada',\n",
|
||||
" 'India', 'Other'], n_samples, p=[0.90, 0.02, 0.01, 0.01, 0.01, 0.01, 0.04])\n",
|
||||
" })\n",
|
||||
" df = pd.DataFrame(\n",
|
||||
" {\n",
|
||||
" \"age\": np.random.randint(17, 90, n_samples),\n",
|
||||
" \"workclass\": np.random.choice(\n",
|
||||
" [\n",
|
||||
" \"Private\",\n",
|
||||
" \"Self-emp-not-inc\",\n",
|
||||
" \"Local-gov\",\n",
|
||||
" \"State-gov\",\n",
|
||||
" \"Self-emp-inc\",\n",
|
||||
" \"Federal-gov\",\n",
|
||||
" \"Without-pay\",\n",
|
||||
" ],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.73, 0.08, 0.06, 0.04, 0.03, 0.03, 0.03],\n",
|
||||
" ),\n",
|
||||
" \"fnlwgt\": np.random.randint(12285, 1484705, n_samples),\n",
|
||||
" \"education\": np.random.choice(\n",
|
||||
" [\n",
|
||||
" \"HS-grad\",\n",
|
||||
" \"Some-college\",\n",
|
||||
" \"Bachelors\",\n",
|
||||
" \"Masters\",\n",
|
||||
" \"Assoc-voc\",\n",
|
||||
" \"Doctorate\",\n",
|
||||
" \"11th\",\n",
|
||||
" \"9th\",\n",
|
||||
" \"7th-8th\",\n",
|
||||
" ],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.32, 0.22, 0.16, 0.05, 0.04, 0.01, 0.04, 0.03, 0.13],\n",
|
||||
" ),\n",
|
||||
" \"education_num\": np.random.randint(1, 16, n_samples),\n",
|
||||
" \"marital_status\": np.random.choice(\n",
|
||||
" [\n",
|
||||
" \"Married-civ-spouse\",\n",
|
||||
" \"Never-married\",\n",
|
||||
" \"Divorced\",\n",
|
||||
" \"Separated\",\n",
|
||||
" \"Widowed\",\n",
|
||||
" ],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.46, 0.33, 0.14, 0.03, 0.04],\n",
|
||||
" ),\n",
|
||||
" \"occupation\": np.random.choice(\n",
|
||||
" [\n",
|
||||
" \"Prof-specialty\",\n",
|
||||
" \"Craft-repair\",\n",
|
||||
" \"Exec-managerial\",\n",
|
||||
" \"Adm-clerical\",\n",
|
||||
" \"Sales\",\n",
|
||||
" \"Other-service\",\n",
|
||||
" \"Machine-op-inspct\",\n",
|
||||
" \"Tech-support\",\n",
|
||||
" ],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.13, 0.13, 0.13, 0.12, 0.11, 0.10, 0.06, 0.22],\n",
|
||||
" ),\n",
|
||||
" \"relationship\": np.random.choice(\n",
|
||||
" [\n",
|
||||
" \"Husband\",\n",
|
||||
" \"Not-in-family\",\n",
|
||||
" \"Own-child\",\n",
|
||||
" \"Unmarried\",\n",
|
||||
" \"Wife\",\n",
|
||||
" \"Other-relative\",\n",
|
||||
" ],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.40, 0.26, 0.16, 0.10, 0.05, 0.03],\n",
|
||||
" ),\n",
|
||||
" \"race\": np.random.choice(\n",
|
||||
" [\"White\", \"Black\", \"Asian-Pac-Islander\", \"Amer-Indian-Eskimo\", \"Other\"],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.85, 0.10, 0.03, 0.01, 0.01],\n",
|
||||
" ),\n",
|
||||
" \"sex\": np.random.choice([\"Male\", \"Female\"], n_samples, p=[0.67, 0.33]),\n",
|
||||
" \"capital_gain\": np.where(\n",
|
||||
" np.random.random(n_samples) < 0.92,\n",
|
||||
" 0,\n",
|
||||
" np.random.randint(1, 99999, n_samples),\n",
|
||||
" ),\n",
|
||||
" \"capital_loss\": np.where(\n",
|
||||
" np.random.random(n_samples) < 0.95,\n",
|
||||
" 0,\n",
|
||||
" np.random.randint(1, 4356, n_samples),\n",
|
||||
" ),\n",
|
||||
" \"hours_per_week\": np.random.randint(1, 99, n_samples),\n",
|
||||
" \"native_country\": np.random.choice(\n",
|
||||
" [\n",
|
||||
" \"United-States\",\n",
|
||||
" \"Mexico\",\n",
|
||||
" \"Philippines\",\n",
|
||||
" \"Germany\",\n",
|
||||
" \"Canada\",\n",
|
||||
" \"India\",\n",
|
||||
" \"Other\",\n",
|
||||
" ],\n",
|
||||
" n_samples,\n",
|
||||
" p=[0.90, 0.02, 0.01, 0.01, 0.01, 0.01, 0.04],\n",
|
||||
" ),\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Création de la cible avec logique réaliste\n",
|
||||
" income_score = (\n",
|
||||
" (df['age'] > 35).astype(int) * 20 +\n",
|
||||
" (df['education_num'] > 12).astype(int) * 30 +\n",
|
||||
" (df['hours_per_week'] > 40).astype(int) * 15 +\n",
|
||||
" (df['capital_gain'] > 0).astype(int) * 25 +\n",
|
||||
" (df['marital_status'] == 'Married-civ-spouse').astype(int) * 20 +\n",
|
||||
" (df['occupation'].isin(['Exec-managerial', 'Prof-specialty'])).astype(int) * 15 +\n",
|
||||
" np.random.normal(0, 15, n_samples)\n",
|
||||
" (df[\"age\"] > 35).astype(int) * 20\n",
|
||||
" + (df[\"education_num\"] > 12).astype(int) * 30\n",
|
||||
" + (df[\"hours_per_week\"] > 40).astype(int) * 15\n",
|
||||
" + (df[\"capital_gain\"] > 0).astype(int) * 25\n",
|
||||
" + (df[\"marital_status\"] == \"Married-civ-spouse\").astype(int) * 20\n",
|
||||
" + (df[\"occupation\"].isin([\"Exec-managerial\", \"Prof-specialty\"])).astype(int)\n",
|
||||
" * 15\n",
|
||||
" + np.random.normal(0, 15, n_samples)\n",
|
||||
" )\n",
|
||||
" df['income'] = (income_score > 60).map({True: '>50K', False: '<=50K'})"
|
||||
" df[\"income\"] = (income_score > 60).map({True: \">50K\", False: \"<=50K\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -358,7 +453,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Encodage de la cible en 0/1\n",
|
||||
"df['income'] = (df['income'] == '>50K').astype(int)"
|
||||
"df[\"income\"] = (df[\"income\"] == \">50K\").astype(int)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -468,7 +563,7 @@
|
||||
" \"race\",\n",
|
||||
" \"sex\",\n",
|
||||
" \"native_country\",\n",
|
||||
"]\n"
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -497,7 +592,7 @@
|
||||
"source": [
|
||||
"print(\"\\n Cardinalité des variables catégorielles :\")\n",
|
||||
"for col in cat_features:\n",
|
||||
" print(f\" {col}: {df[col].nunique()} catégories uniques\")\n"
|
||||
" print(f\" {col}: {df[col].nunique()} catégories uniques\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -526,7 +621,7 @@
|
||||
"# Corrélation avec la cible\n",
|
||||
"print(\"\\n Corrélations avec le revenu >50K :\")\n",
|
||||
"correlations = df[numeric_features].corrwith(df[\"income\"]).sort_values(ascending=False)\n",
|
||||
"print(correlations)\n"
|
||||
"print(correlations)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -583,7 +678,7 @@
|
||||
"axes[1, 2].set_xlabel(\"Revenu >50K\")\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n"
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -647,7 +742,7 @@
|
||||
"\n",
|
||||
"print(\"\\n=== Préparation pour CatBoost ===\\n\")\n",
|
||||
"print(f\"Variables catégorielles : {cat_features}\")\n",
|
||||
"print(f\"Variables numériques : {numeric_features}\")\n"
|
||||
"print(f\"Variables numériques : {numeric_features}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -861,11 +956,15 @@
|
||||
"# Courbe ROC\n",
|
||||
"fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba_baseline)\n",
|
||||
"plt.figure(figsize=(8, 6))\n",
|
||||
"plt.plot(fpr, tpr, label=f'CatBoost (AUC = {roc_auc_score(y_test, y_pred_proba_baseline):.3f})')\n",
|
||||
"plt.plot([0, 1], [0, 1], 'k--', label='Hasard')\n",
|
||||
"plt.xlabel('Taux de faux positifs')\n",
|
||||
"plt.ylabel('Taux de vrais positifs')\n",
|
||||
"plt.title('Courbe ROC')\n",
|
||||
"plt.plot(\n",
|
||||
" fpr,\n",
|
||||
" tpr,\n",
|
||||
" label=f\"CatBoost (AUC = {roc_auc_score(y_test, y_pred_proba_baseline):.3f})\",\n",
|
||||
")\n",
|
||||
"plt.plot([0, 1], [0, 1], \"k--\", label=\"Hasard\")\n",
|
||||
"plt.xlabel(\"Taux de faux positifs\")\n",
|
||||
"plt.ylabel(\"Taux de vrais positifs\")\n",
|
||||
"plt.title(\"Courbe ROC\")\n",
|
||||
"plt.legend()\n",
|
||||
"plt.grid(True)\n",
|
||||
"plt.show()"
|
||||
@@ -1133,7 +1232,7 @@
|
||||
"df[\"montant_credit\"] = np.random.uniform(3000, 450000)\n",
|
||||
"df[\"defaut\"] = df[\"capital_gain\"] < df[\"capital_loss\"]\n",
|
||||
"\n",
|
||||
"df = df.drop([\"capital_gain\", \"capital_loss\"], axis=1)\n"
|
||||
"df = df.drop([\"capital_gain\", \"capital_loss\"], axis=1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1211,7 +1310,7 @@
|
||||
"perte_totale_predite = y_pred_reg.sum()\n",
|
||||
"print(f\"\\\\nPerte totale réelle : {perte_totale_reelle:,.2f}€\")\n",
|
||||
"print(f\"Perte totale prédite : {perte_totale_predite:,.2f}€\")\n",
|
||||
"print(f\"Erreur : {abs(perte_totale_reelle - perte_totale_predite):,.2f}€\")\n"
|
||||
"print(f\"Erreur : {abs(perte_totale_reelle - perte_totale_predite):,.2f}€\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1276,7 +1375,7 @@
|
||||
"plt.title(\"Top 15 variables les plus importantes\")\n",
|
||||
"plt.gca().invert_yaxis()\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n"
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1320,8 +1419,8 @@
|
||||
" depth=6,\n",
|
||||
" random_seed=42,\n",
|
||||
" verbose=0,\n",
|
||||
" auto_class_weights='Balanced',\n",
|
||||
" eval_metric='AUC'\n",
|
||||
" auto_class_weights=\"Balanced\",\n",
|
||||
" eval_metric=\"AUC\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"model_weighted.fit(train_pool, eval_set=test_pool)"
|
||||
@@ -1405,7 +1504,7 @@
|
||||
" random_seed=42,\n",
|
||||
" verbose=0,\n",
|
||||
" scale_pos_weight=scale,\n",
|
||||
" eval_metric='AUC'\n",
|
||||
" eval_metric=\"AUC\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"model_scaled.fit(train_pool, eval_set=test_pool)"
|
||||
@@ -1427,10 +1526,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# CatBoost a un support natif pour SHAP\n",
|
||||
"shap_values = model_baseline.get_feature_importance(\n",
|
||||
" train_pool,\n",
|
||||
" type='ShapValues'\n",
|
||||
")"
|
||||
"shap_values = model_baseline.get_feature_importance(train_pool, type=\"ShapValues\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1483,13 +1579,12 @@
|
||||
"source": [
|
||||
"# Importance SHAP moyenne\n",
|
||||
"shap_importance = np.abs(shap_values[:, :-1]).mean(axis=0)\n",
|
||||
"shap_df = pd.DataFrame({\n",
|
||||
" 'feature': X_train.columns,\n",
|
||||
" 'shap_importance': shap_importance\n",
|
||||
"}).sort_values('shap_importance', ascending=False)\n",
|
||||
"shap_df = pd.DataFrame(\n",
|
||||
" {\"feature\": X_train.columns, \"shap_importance\": shap_importance}\n",
|
||||
").sort_values(\"shap_importance\", ascending=False)\n",
|
||||
"\n",
|
||||
"print(\"\\nImportance SHAP moyenne :\")\n",
|
||||
"print(shap_df.head(10))\n"
|
||||
"print(shap_df.head(10))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1512,9 +1607,9 @@
|
||||
"source": [
|
||||
"# Visualisation\n",
|
||||
"plt.figure(figsize=(10, 8))\n",
|
||||
"plt.barh(shap_df['feature'][:15], shap_df['shap_importance'][:15])\n",
|
||||
"plt.xlabel('|SHAP value| moyen')\n",
|
||||
"plt.title('Importance des features (SHAP)')\n",
|
||||
"plt.barh(shap_df[\"feature\"][:15], shap_df[\"shap_importance\"][:15])\n",
|
||||
"plt.xlabel(\"|SHAP value| moyen\")\n",
|
||||
"plt.title(\"Importance des features (SHAP)\")\n",
|
||||
"plt.gca().invert_yaxis()\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
|
||||
Reference in New Issue
Block a user