{
"cells": [
{
"cell_type": "markdown",
"id": "78b68162",
"metadata": {},
"source": [
"# TP3 - Fine tuning\n",
"\n",
"L'objectif de ce TP est de présenter comment exploiter les modèles disponibles sur HuggingFace, plus particulièrement pour des tâches de classification. Nous mettrons également en pratique l'approche [*Two-stage fine-tuning*](https://arxiv.org/pdf/2207.10858).\n",
"\n",
"## Contexte\n",
"\n",
"Les flux RSS, de l'anglais *Rich Site Summary*, est un flux d'information pour partager simplement des informations. Largement utilisé dans le domaine de la communication par les médias, nous pouvons nous tenir au courant des informations que les journaux partagent par exemple.\n",
"Cependant, le volume d'informations peut être rapidement très important pour pouvoir être lu. On se propose de construire un filtre pour nous notifier uniquement quand une information sera jugée *d'intérêt*.\n",
"\n",
"Pour cela, nous avons sauvegardé puis labellisé les flux RSS de plusieurs journaux mondiaux en anglais, pendant plusieurs semaines. Importons et visualisons la bases de données à disposition."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "7e9e1d29",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7619 news with 10.59% of positive class\n"
]
},
{
"data": {
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
"columns": [
{
"name": "index",
"rawType": "int64",
"type": "integer"
},
{
"name": "provider",
"rawType": "object",
"type": "string"
},
{
"name": "category",
"rawType": "object",
"type": "string"
},
{
"name": "link",
"rawType": "object",
"type": "string"
},
{
"name": "published",
"rawType": "datetime64[ns, UTC]",
"type": "unknown"
},
{
"name": "title",
"rawType": "object",
"type": "string"
},
{
"name": "summary",
"rawType": "object",
"type": "string"
},
{
"name": "target",
"rawType": "int64",
"type": "integer"
}
],
"ref": "46cf2b21-2a0e-4b5c-a13a-fad0480005c4",
"rows": [
[
"0",
"El Pais",
"World",
"https://english.elpais.com/climate/2025-12-27/a-floating-school-teaches-children-how-to-save-lake-atitlan.html",
"2025-12-27 04:00:00+00:00",
"A floating school teaches children how to save Lake Atitlán",
"This is the latest attempt by Guatemalan scientists to make kids aware that the once-pristine lake is in danger of ‘dying’ from the constant discharge of waste water, garbage and erosion",
"0"
],
[
"1",
"El Pais",
"World",
"https://english.elpais.com/climate/2025-12-31/david-king-chemist-there-are-scientists-studying-how-to-cool-the-planet-nobody-should-stop-these-experiments-from-happening.html",
"2025-12-31 13:53:41+00:00",
"David King, chemist: ‘There are scientists studying how to cool the planet; nobody should stop these experiments from happening’",
"The British researcher is calling for global regulations to test extreme measures against climate change, such as marine cloud brightening over the Arctic",
"0"
],
[
"2",
"El Pais",
"World",
"https://english.elpais.com/climate/2026-01-01/cartagena-de-indias-is-sinking-what-can-the-city-do-to-mitigate-it.html",
"2026-01-01 12:24:15+00:00",
"Cartagena de Indias is sinking: What can the city do to mitigate it?",
"Sea levels in the bay of the Colombian resort have risen seven millimeters per year for the past two decades, the second-highest rate in the Caribbean after Haiti",
"0"
],
[
"3",
"El Pais",
"World",
"https://english.elpais.com/climate/2026-01-03/the-secrets-of-the-washingtonia-palm-named-after-the-first-us-president.html",
"2026-01-03 04:45:00+00:00",
"The secrets of the Washingtonia palm, named after the first US president",
"They are among the most widely cultivated plants in the world, and can thrive in complete neglect, able to withstand both harsh summers and low temperatures",
"0"
],
[
"4",
"El Pais",
"World",
"https://english.elpais.com/climate/2026-01-05/a-recipe-for-resistance-indigenous-peoples-politicize-their-struggles-from-the-kitchen.html",
"2026-01-05 17:02:27+00:00",
"A recipe for resistance: Indigenous peoples politicize their struggles from the kitchen",
"A book that compiles the recovery of ancestral recipes from the Americas, from Canada to the Amazon, maps strategies in the face of climate change and agribusiness",
"0"
]
],
"shape": {
"columns": 7,
"rows": 5
}
},
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
provider
\n",
"
category
\n",
"
link
\n",
"
published
\n",
"
title
\n",
"
summary
\n",
"
target
\n",
"
\n",
" \n",
" \n",
"
\n",
"
0
\n",
"
El Pais
\n",
"
World
\n",
"
https://english.elpais.com/climate/2025-12-27/...
\n",
"
2025-12-27 04:00:00+00:00
\n",
"
A floating school teaches children how to save...
\n",
"
This is the latest attempt by Guatemalan scien...
\n",
"
0
\n",
"
\n",
"
\n",
"
1
\n",
"
El Pais
\n",
"
World
\n",
"
https://english.elpais.com/climate/2025-12-31/...
\n",
"
2025-12-31 13:53:41+00:00
\n",
"
David King, chemist: ‘There are scientists stu...
\n",
"
The British researcher is calling for global r...
\n",
"
0
\n",
"
\n",
"
\n",
"
2
\n",
"
El Pais
\n",
"
World
\n",
"
https://english.elpais.com/climate/2026-01-01/...
\n",
"
2026-01-01 12:24:15+00:00
\n",
"
Cartagena de Indias is sinking: What can the c...
\n",
"
Sea levels in the bay of the Colombian resort ...
\n",
"
0
\n",
"
\n",
"
\n",
"
3
\n",
"
El Pais
\n",
"
World
\n",
"
https://english.elpais.com/climate/2026-01-03/...
\n",
"
2026-01-03 04:45:00+00:00
\n",
"
The secrets of the Washingtonia palm, named af...
\n",
"
They are among the most widely cultivated plan...
\n",
"
0
\n",
"
\n",
"
\n",
"
4
\n",
"
El Pais
\n",
"
World
\n",
"
https://english.elpais.com/climate/2026-01-05/...
\n",
"
2026-01-05 17:02:27+00:00
\n",
"
A recipe for resistance: Indigenous peoples po...
\n",
"
A book that compiles the recovery of ancestral...
\n",
"
0
\n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" provider category link \\\n",
"0 El Pais World https://english.elpais.com/climate/2025-12-27/... \n",
"1 El Pais World https://english.elpais.com/climate/2025-12-31/... \n",
"2 El Pais World https://english.elpais.com/climate/2026-01-01/... \n",
"3 El Pais World https://english.elpais.com/climate/2026-01-03/... \n",
"4 El Pais World https://english.elpais.com/climate/2026-01-05/... \n",
"\n",
" published \\\n",
"0 2025-12-27 04:00:00+00:00 \n",
"1 2025-12-31 13:53:41+00:00 \n",
"2 2026-01-01 12:24:15+00:00 \n",
"3 2026-01-03 04:45:00+00:00 \n",
"4 2026-01-05 17:02:27+00:00 \n",
"\n",
" title \\\n",
"0 A floating school teaches children how to save... \n",
"1 David King, chemist: ‘There are scientists stu... \n",
"2 Cartagena de Indias is sinking: What can the c... \n",
"3 The secrets of the Washingtonia palm, named af... \n",
"4 A recipe for resistance: Indigenous peoples po... \n",
"\n",
" summary target \n",
"0 This is the latest attempt by Guatemalan scien... 0 \n",
"1 The British researcher is calling for global r... 0 \n",
"2 Sea levels in the bay of the Colombian resort ... 0 \n",
"3 They are among the most widely cultivated plan... 0 \n",
"4 A book that compiles the recovery of ancestral... 0 "
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"data/Complete.csv\")\n",
"\n",
"df[\"published\"] = pd.to_datetime(df[\"published\"])\n",
"df = df.loc[~df[\"target\"].isna()].copy()\n",
"df[\"target\"] = df[\"target\"].astype(int)\n",
"\n",
"\n",
"imbalanced_rate = df[\"target\"].mean()\n",
"print(f\"{df.shape[0]} news with {100 * imbalanced_rate:.2f}% of positive class\")\n",
"\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "19931804",
"metadata": {},
"source": [
"## Préparation\n",
"\n",
"Nous n'allons exploiter que le titre et le contenu du flux RSS, et avant de commencer nous allons nettoyer un peu le texte."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "45710a37",
"metadata": {},
"outputs": [],
"source": [
"def clean_text(title: str, summary: str) -> str:\n",
" \"\"\"Clean the title and summary strings.\"\"\"\n",
"\n",
" def clean_string(string: str) -> str:\n",
" \"\"\"Clean a single string.\"\"\"\n",
" string = string.lower()\n",
" string = string.replace(\"(Source: Bloomberg)\", \"\")\n",
" string = string.replace(\"“\", \"'\")\n",
" string = string.replace(\"”\", \"'\")\n",
" string = string.replace(\"\\n\", \"\")\n",
" string = string.replace(\" \", \" \")\n",
" return string.strip()\n",
"\n",
" title = clean_string(title)\n",
" summary = clean_string(summary)\n",
"\n",
" return f\"Title: {title}\\nSummary: {summary}\"\n",
"\n",
"\n",
"df[\"text\"] = df.apply(lambda row: clean_text(row[\"title\"], row[\"summary\"]), axis=1)"
]
},
{
"cell_type": "markdown",
"id": "467c2960",
"metadata": {},
"source": [
"Puisque le dataset est dépendant du temps, nous allons le découper en jeu d'entraînement et de validation en suivant cela."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "024281d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training size: 6092, imbalanced rate: 10.69%\n",
"Validation size: 1523, imbalanced rate: 10.18%\n"
]
}
],
"source": [
"train_ratio = 0.8\n",
"\n",
"df = df.drop_duplicates(subset=[\"text\"])\n",
"df = df.sort_values(\"published\")\n",
"split_date = df[\"published\"].quantile(train_ratio)\n",
"\n",
"df_train = df[df[\"published\"] <= split_date][[\"text\", \"target\"]]\n",
"df_val = df[df[\"published\"] > split_date][[\"text\", \"target\"]]\n",
"\n",
"df_train = df_train.reset_index(drop=True)\n",
"df_val = df_val.reset_index(drop=True)\n",
"\n",
"print(\n",
" f\"Training size: {len(df_train)}, imbalanced rate: {100 * df_train.target.mean():.2f}%\",\n",
")\n",
"print(\n",
" f\"Validation size: {len(df_val)}, imbalanced rate: {100 * df_val.target.mean():.2f}%\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2d9d353f",
"metadata": {},
"source": [
"Puisque nous avons un déséquilibre, nous allons le reporter dans la fonction de perte. Pour préparer cette étape, nous allons utiliser la fonction [`compute_class_weight`](https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html) et stocker son résultat dans un [`Tensor`](https://docs.pytorch.org/docs/stable/tensors.html) de PyTorch pour l'exploiter dans la fonction de perte."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "71f1c0e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.5598, 4.6790])\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"import torch\n",
"from sklearn.utils.class_weight import compute_class_weight\n",
"\n",
"labels = df_train[\"target\"].to_numpy()\n",
"class_weights = compute_class_weight(\n",
" class_weight=\"balanced\",\n",
" classes=np.array([0, 1]),\n",
" y=labels,\n",
")\n",
"\n",
"class_weights = torch.tensor(class_weights, dtype=torch.float)\n",
"print(class_weights)"
]
},
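{
"cell_type": "markdown",
"id": "c1a55e01",
"metadata": {},
"source": [
"As a rough check of the weights printed above, the sketch below recomputes them by hand. It assumes the `balanced` heuristic documented by scikit-learn, `n_samples / (n_classes * count_per_class)`, and class counts of 5441 negatives and 651 positives. Those counts are not shown in the notebook: they are inferred from the training size (6092) and the printed positive rate (10.69%). The helper `balanced_class_weights` is hypothetical.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"\n",
"def balanced_class_weights(labels: list[int]) -> dict[int, float]:\n",
"    \"\"\"Return n_samples / (n_classes * count) for each class.\"\"\"\n",
"    counts = Counter(labels)\n",
"    n_samples = len(labels)\n",
"    n_classes = len(counts)\n",
"    return {c: n_samples / (n_classes * n) for c, n in sorted(counts.items())}\n",
"\n",
"\n",
"# Hypothetical label vector: counts inferred from the printed statistics.\n",
"toy_labels = [0] * 5441 + [1] * 651\n",
"weights = balanced_class_weights(toy_labels)\n",
"print({c: round(w, 4) for c, w in weights.items()})\n",
"```\n",
"\n",
"The rarer class receives the larger weight, so an error on a positive example costs roughly eight times more than an error on a negative one."
]
},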
{
"cell_type": "markdown",
"id": "f093593f",
"metadata": {},
"source": [
"## Premier entraînement\n",
"\n",
"Il est maintenant temps de basculer du monde classique numpy, pandas, scikit-learn à celui de la librairies transformers d'HuggingFace.\n",
"\n",
"### Dataset d'entraînement\n",
"\n",
"On ne travaille plus avec des DataFrames pandas mais des [Datasets](https://huggingface.co/docs/datasets/use_with_pandas). "
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "2b32e4cc",
"metadata": {},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"\n",
"ds_train = Dataset.from_pandas(df_train)\n",
"ds_val = Dataset.from_pandas(df_val)\n",
"\n",
"ds_train = ds_train.rename_column(\"target\", \"labels\")\n",
"ds_val = ds_val.rename_column(\"target\", \"labels\")"
]
},
{
"cell_type": "markdown",
"id": "824250f7",
"metadata": {},
"source": [
"Le texte contenu dans les datasets n'est pas *compréhensible* en l'état par un modèle: nous devons **tokeniser** le texte. Nous devons le faire en utilisant la même méthode qui a servit à entraîner le modèle que nous utiliserons par la suite : nous choisissons [DistilRoBERTa base](https://huggingface.co/distilbert/distilroberta-base). Le choix s'est fait sur les capacités du modèle et sa taille réduite pour permettre des itérations *rapide*."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "55224f64",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 6092/6092 [00:00<00:00, 14604.07 examples/s]\n",
"Map: 100%|██████████| 1523/1523 [00:00<00:00, 16717.37 examples/s]\n"
]
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"MODEL_NAME = \"distilroberta-base\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
"\n",
"\n",
"def tokenize(batch: dict) -> dict:\n",
" \"\"\"Tokenize a batch of texts.\"\"\"\n",
" return tokenizer(\n",
" batch[\"text\"],\n",
" truncation=True,\n",
" padding=\"max_length\",\n",
" max_length=256,\n",
" )\n",
"\n",
"\n",
"ds_train = ds_train.map(tokenize, batched=True)\n",
"ds_val = ds_val.map(tokenize, batched=True)"
]
},
{
"cell_type": "markdown",
"id": "301f82fc",
"metadata": {},
"source": [
"Dans cette cellule nous avons réutiliser le tokenizer du modèle d'intérêt puis l'avons appliquer à la colonne contenant le texte. A présent, le dataset n'est plus composé de deux colonnes (labels et target) mais :"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "4bc70018",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['text', 'labels', 'input_ids', 'attention_mask'],\n",
" num_rows: 6092\n",
"})"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_train"
]
},
{
"cell_type": "markdown",
"id": "d2a4ad98",
"metadata": {},
"source": [
"Avant de continuer, nous changeons le format des données en tenseur torch."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "056a0baf",
"metadata": {},
"outputs": [],
"source": [
"ds_train.set_format(\n",
" type=\"torch\",\n",
" columns=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
")\n",
"ds_val.set_format(\n",
" type=\"torch\",\n",
" columns=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c290b844",
"metadata": {},
"source": [
"La librairie transformers n'a pas de méthode `fit` comme dans scikit-learn ou Keras, mais une classe [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer).\n",
"Nativement, cette classe ne prend pas en compte le déséquilibre : nous allons la modifier pour le faire. On se propose d'hériter de la classe `Trainer` et de modifier uniquement la méthode `compute_loss` pour qu'elle utilise la [`CrossEntropyLoss`](https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) de PyTorch."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "02a89166",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import Trainer\n",
"\n",
"\n",
"class WeightedTrainer(Trainer):\n",
" \"\"\"Custom Trainer with weighted loss to handle class imbalance.\"\"\"\n",
"\n",
" def compute_loss(\n",
" self,\n",
" model,\n",
" inputs,\n",
" return_outputs = False,\n",
" num_items_in_batch = None,\n",
" ):\n",
" \"\"\"Compute the weighted cross-entropy loss for class imbalance.\"\"\"\n",
" labels = inputs.pop(\"labels\")\n",
" outputs = model(**inputs)\n",
" logits = outputs.get(\"logits\")\n",
"\n",
" # Configuration de la perte avec les poids de classe\n",
" loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))\n",
"\n",
" # Calcul de la perte\n",
" loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))\n",
"\n",
" return (loss, outputs) if return_outputs else loss\n"
]
},
{
"cell_type": "markdown",
"id": "69e3c4c0",
"metadata": {},
"source": [
"Tout au long de l'entraînement du modèle, nous voulons être capable de suivre ses performances, en plus de la valeur de la fonction de perte. Nous définissons une fonction pour répondre à cet objectif :"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "48d9e14c",
"metadata": {},
"outputs": [],
"source": [
"import scipy.special as sp\n",
"\n",
"from sklearn.metrics import (\n",
" accuracy_score,\n",
" average_precision_score,\n",
" precision_recall_fscore_support,\n",
")\n",
"\n",
"\n",
"def compute_metrics(eval_pred: tuple) -> dict:\n",
" \"\"\"Compute evaluation metrics.\"\"\"\n",
" logits, y_true = eval_pred\n",
" probas = sp.softmax(logits, axis=1)[:, 1]\n",
"\n",
" threshold = 0.5\n",
" y_pred = (probas >= threshold).astype(int)\n",
"\n",
" precision, recall, f1, _ = precision_recall_fscore_support(\n",
" y_true,\n",
" y_pred,\n",
" average=\"binary\",\n",
" zero_division=0,\n",
" )\n",
"\n",
" auprc = average_precision_score(y_true, probas, average=\"weighted\")\n",
"\n",
" return {\n",
" \"precision\": precision,\n",
" \"recall\": recall,\n",
" \"f1\": f1,\n",
" \"accuracy\": accuracy_score(y_true, y_pred),\n",
" \"auprc\": auprc,\n",
" }\n"
]
},
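{
"cell_type": "markdown",
"id": "c1a55e02",
"metadata": {},
"source": [
"`compute_metrics` turns the two logits into a positive-class probability with a softmax. The sketch below, in plain Python with hypothetical logit values, illustrates that for two classes this probability is the sigmoid of the logit difference, so a 0.5 threshold amounts to predicting the class with the larger logit. The helpers `softmax` and `sigmoid` are written here for illustration only.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def softmax(logits: list[float]) -> list[float]:\n",
"    \"\"\"Numerically stable softmax over a list of logits.\"\"\"\n",
"    shift = max(logits)\n",
"    exps = [math.exp(x - shift) for x in logits]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"def sigmoid(x: float) -> float:\n",
"    \"\"\"Logistic function.\"\"\"\n",
"    return 1.0 / (1.0 + math.exp(-x))\n",
"\n",
"\n",
"# Hypothetical logits for (negative class, positive class).\n",
"logits = [1.2, -0.4]\n",
"p_positive = softmax(logits)[1]\n",
"\n",
"# Both expressions give the same probability.\n",
"print(p_positive, sigmoid(logits[1] - logits[0]))\n",
"```\n",
"\n",
"Moving the threshold away from 0.5 is therefore equivalent to requiring a minimum margin between the two logits."
]
},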
{
"cell_type": "markdown",
"id": "0b96a2e4",
"metadata": {},
"source": [
"Pour paramétrer la classe `Trainer`, nous allons utiliser la classe [`TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments)."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "69d6ac13",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./results\",\n",
" eval_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" per_device_train_batch_size=32,\n",
" per_device_eval_batch_size=32,\n",
" num_train_epochs=3,\n",
" learning_rate=1e-4,\n",
" weight_decay=0.0,\n",
" warmup_ratio=0.1,\n",
" lr_scheduler_type=\"linear\",\n",
" load_best_model_at_end=True,\n",
" metric_for_best_model=\"f1\",\n",
" greater_is_better=True,\n",
" logging_strategy=\"epoch\",\n",
" report_to=[],\n",
" save_total_limit=1,\n",
")"
]
},
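{
"cell_type": "markdown",
"id": "c1a55e03",
"metadata": {},
"source": [
"To get a feel for `warmup_ratio=0.1` combined with `lr_scheduler_type=\"linear\"`, here is a minimal sketch of a linear warmup followed by a linear decay. It is an illustration, not the transformers implementation: the helper `linear_schedule_lr` is hypothetical, and rounding the warmup length up with `ceil` is an assumption. The step count assumes 3 epochs over 6092 training examples with a batch size of 32.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def linear_schedule_lr(\n",
"    step: int,\n",
"    total_steps: int,\n",
"    base_lr: float,\n",
"    warmup_ratio: float,\n",
") -> float:\n",
"    \"\"\"Learning rate at `step`: linear warmup, then linear decay to zero.\"\"\"\n",
"    warmup_steps = math.ceil(total_steps * warmup_ratio)\n",
"    if step < warmup_steps:\n",
"        return base_lr * step / warmup_steps\n",
"    remaining = max(0, total_steps - step)\n",
"    return base_lr * remaining / max(1, total_steps - warmup_steps)\n",
"\n",
"\n",
"# 3 epochs, each made of ceil(6092 / 32) batches.\n",
"total_steps = 3 * math.ceil(6092 / 32)\n",
"for step in (0, 29, 58, 300, total_steps):\n",
"    print(step, linear_schedule_lr(step, total_steps, base_lr=1e-4, warmup_ratio=0.1))\n",
"```\n",
"\n",
"Under these assumptions the learning rate only reaches its peak of 1e-4 after the first tenth of the run, then decreases steadily until the last step."
]
},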
{
"cell_type": "markdown",
"id": "edf6abad",
"metadata": {},
"source": [
"Il est enfin temps de lancer l'entraînement !\n",
"\n",
"**Consigne** :\n",
"1. Initialiser le modèle à l'aide de la classe [`AutoModelForSequenceClassification`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForSequenceClassification), on valorisera le paramètre `num_labels` à 2 puisque nous réalisons une classification binaire.\n",
"2. Initialiser le `WeightedTrainer` avec :\n",
" * Le modèle définit à l'étape précédente\n",
" * Comme arguments (args) le paramètrage d'entraînement défini dans la cellule précédente\n",
" * Les datasets d'entraînement et de test\n",
" * La fonction de mesure de performance\n",
"3. Lancer l'entraînement avec la méthode `train` du `Trainer`"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "b730681f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"/Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:692: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, device pinned memory won't be used.\n",
" warnings.warn(warn_msg)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:692: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, device pinned memory won't be used.\n",
" warnings.warn(warn_msg)\n",
"/Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:692: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, device pinned memory won't be used.\n",
" warnings.warn(warn_msg)\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=573, training_loss=0.40664318677226496, metrics={'train_runtime': 1411.5979, 'train_samples_per_second': 12.947, 'train_steps_per_second': 0.406, 'total_flos': 1210487088918528.0, 'train_loss': 0.40664318677226496, 'epoch': 3.0})"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"classifier = AutoModelForSequenceClassification.from_pretrained(\n",
" MODEL_NAME,\n",
" num_labels=2,\n",
")\n",
"\n",
"trainer = WeightedTrainer(\n",
" model=classifier,\n",
" args=training_args,\n",
" train_dataset=ds_train,\n",
" eval_dataset=ds_val,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "9db1793d",
"metadata": {},
"source": [
"On aimerait visualiser la performance et la distribution des réponses du modèles. Pour cela, on produit deux graphes :\n",
"1. Graphique de distribution des probabilités du modèle sur le dataset de validation\n",
"2. Graphique de valeurs des principales métriques (précision, recall et f1-score) en fonction du seuil"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "fc202c3b",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"sns.set_style(style=\"whitegrid\")\n",
"\n",
"\n",
"def metrics_vs_threshold(\n",
" y_true: np.ndarray,\n",
" y_proba: np.ndarray,\n",
" thresholds: np.ndarray,\n",
") -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
" \"\"\"Compute precision, recall, and F1-score for various thresholds.\n",
"\n",
" Args:\n",
" y_true (np.ndarray): True binary labels.\n",
" y_proba (np.ndarray): Predicted probabilities for the positive class.\n",
" thresholds (np.ndarray): Array of thresholds to evaluate.\n",
"\n",
" Returns:\n",
" tuple[np.ndarray, np.ndarray, np.ndarray]: Precision, recall, and F1-score arrays.\n",
"\n",
" \"\"\"\n",
" precision, recall, f1 = [], [], []\n",
"\n",
" for t in thresholds:\n",
" y_pred = (y_proba >= t).astype(int)\n",
"\n",
" p, r, f, _ = precision_recall_fscore_support(\n",
" y_true,\n",
" y_pred,\n",
" average=\"binary\",\n",
" zero_division=0,\n",
" )\n",
"\n",
" precision.append(p)\n",
" recall.append(r)\n",
" f1.append(f)\n",
"\n",
" return np.array(precision), np.array(recall), np.array(f1)\n",
"\n",
"\n",
"def make_plot(y_true: np.ndarray, y_proba: np.ndarray) -> None:\n",
" \"\"\"Plot probability distribution and metrics vs. threshold.\n",
"\n",
" Args:\n",
" y_true (np.ndarray): True binary labels.\n",
" y_proba (np.ndarray): Predicted probabilities for the positive class.\n",
"\n",
" \"\"\"\n",
" thresholds = np.linspace(0.0, 1.0, 100)\n",
" dataframe = pd.DataFrame({\"target\": y_true, \"proba\": y_proba})\n",
"\n",
" precision, recall, f1 = metrics_vs_threshold(\n",
" y_true=y_true,\n",
" y_proba=y_proba,\n",
" thresholds=thresholds,\n",
" )\n",
"\n",
" precisions, recalls, f1_scores = [], [], []\n",
"\n",
" for threshold in thresholds:\n",
" y_pred = (y_proba >= threshold).astype(int)\n",
"\n",
" precision, recall, f1_score, _ = precision_recall_fscore_support(\n",
" y_true,\n",
" y_pred,\n",
" average=\"binary\",\n",
" zero_division=0,\n",
" )\n",
"\n",
" precisions.append(precision)\n",
" recalls.append(recall)\n",
" f1_scores.append(f1_score)\n",
"\n",
" precision, recall, f1 = np.array(precisions), np.array(recalls), np.array(f1_scores)\n",
"\n",
" plt.figure(figsize=(12, 5))\n",
"\n",
" plt.subplot(1, 2, 1)\n",
" sns.histplot(\n",
" data=dataframe,\n",
" x=\"proba\",\n",
" hue=\"target\",\n",
" common_norm=False,\n",
" stat=\"probability\",\n",
" )\n",
" plt.xlim(0, 1)\n",
" plt.title(\"Probability distribution\")\n",
"\n",
" plt.subplot(1, 2, 2)\n",
" plt.plot(thresholds, precision, label=\"Precision\")\n",
" plt.plot(thresholds, recall, label=\"Recall\")\n",
" plt.plot(thresholds, f1, label=\"F1-score\")\n",
" plt.xlabel(\"Decision Threshold\")\n",
" plt.ylabel(\"Score\")\n",
" plt.title(\"Threshold Optimization\")\n",
" plt.legend()\n",
"\n",
" plt.suptitle(\"Performance on validation dataset\")\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "b35f11f8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:692: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, device pinned memory won't be used.\n",
" warnings.warn(warn_msg)\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_pred = trainer.predict(ds_val)\n",
"probas = sp.softmax(y_pred.predictions, axis=1)\n",
"y_proba = probas[:, 1]\n",
"\n",
"y_true = df_val[\"target\"].to_numpy()\n",
"make_plot(y_true, y_proba)"
]
},
{
"cell_type": "markdown",
"id": "4b1d7883",
"metadata": {},
"source": [
"Il y a deux principaux problèmes :\n",
"1. **Overfitting** : en regardant les performances au cours de l'entraînement, on voit que la loss pour le train baisse mais pas celle de la validation\n",
"2. **Sur-confiance** : les masses sont aux extrêmes des valeurs de probabilités, rendant difficile le choix du seuil\n",
"\n",
"Cela s'explique parce que nous avons ré-entraîné l'ensemble du modèle (les millions de paramètres) avec un learning rate trop aggressif. \n",
"\n",
"## Amélioration de la procédure\n",
"\n",
"Pour contrer cela, nous allons implémenter plusieurs changement:\n",
"* Découper l'entraînement en deux parties :\n",
" 1. Entraînement en gelant tous les poids, sauf ceux de la couche de classification sur quelques époques\n",
" 2. Entraînement avec l'ensemble des poids sur une à deux époques, avec un learning rate faible\n",
"* Ajouter à la fonction de perte du *label smoothing*\n",
"* Adapter les scheduler de learning rate, quand c'est nécessaire\n",
"* Ajouter du *weight decay*, quand c'est nécessaire\n",
"\n",
"Commençons !\n",
"\n",
"\n",
"\n",
"### Intégration du *label smoothing*\n",
"\n",
"Le *label smoothing* est une technique de régularisation introduite dans un [article](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf) améliorant l'architecture Inception. L'idée est d'adresser à la fois le surapprentissage et la sur confiance dans la prédiction.\n",
"\n",
"Concrétement on remplace le vecteur cible $y$ de $K$ classes par :\n",
"\n",
"$$\\tilde{y} = (1 - \\alpha) y + \\frac{\\alpha}{K}$$\n",
"\n",
"Avec $\\alpha \\in [0,1]$ issue de la loi uniforme. De cette manière on évite au logits de devenir trop grand.\n",
"\n",
"**Consigne** : Reprendre la classe `WeightedTrainer` pour ajouter dans la couche de [`CrossEntropyLoss`](https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) le label smoothin valorisé à 0.05"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "fdfcf227",
"metadata": {},
"outputs": [],
"source": [
"class WeightedTrainer(Trainer):\n",
" \"\"\"Custom Trainer with weighted loss and label smoothing to handle class imbalance.\"\"\"\n",
"\n",
" def compute_loss(\n",
" self,\n",
" model,\n",
" inputs,\n",
" return_outputs=False,\n",
" num_items_in_batch=None,\n",
" ):\n",
" \"\"\"Compute the weighted cross-entropy loss with label smoothing for class imbalance.\"\"\"\n",
" labels = inputs.pop(\"labels\")\n",
" outputs = model(**inputs)\n",
" logits = outputs.get(\"logits\")\n",
"\n",
" loss_fct = torch.nn.CrossEntropyLoss(\n",
" weight=class_weights.to(logits.device),\n",
" label_smoothing=0.05,\n",
" )\n",
" loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))\n",
"\n",
" return (loss, outputs) if return_outputs else loss\n"
]
},
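{
"cell_type": "markdown",
"id": "c1a55e04",
"metadata": {},
"source": [
"As a small worked example of the smoothing formula, the sketch below applies it to a one-hot target in plain Python. The helper `smooth_labels` is hypothetical and only illustrates the arithmetic; in the trainer, the smoothing is applied by `CrossEntropyLoss` itself through its `label_smoothing` argument.\n",
"\n",
"```python\n",
"def smooth_labels(one_hot: list[float], alpha: float) -> list[float]:\n",
"    \"\"\"Mix a one-hot target with the uniform distribution over K classes.\"\"\"\n",
"    k = len(one_hot)\n",
"    return [(1 - alpha) * y + alpha / k for y in one_hot]\n",
"\n",
"\n",
"# Positive example in a binary problem (K = 2) with alpha = 0.05.\n",
"smoothed = smooth_labels([0.0, 1.0], alpha=0.05)\n",
"print(smoothed)  # approximately [0.025, 0.975]\n",
"```\n",
"\n",
"With $\\alpha = 0.05$ the model is no longer pushed toward probabilities of exactly 0 or 1, which limits how large the logits need to grow."
]
},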
{
"cell_type": "markdown",
"id": "92f61fa3",
"metadata": {},
"source": [
"### Premier entraînement\n",
"\n",
"Nous avons besoin de réinitialiser le modèle et geler l'ensemble des couches, sauf celle de classification."
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "83a5631e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"592130 trainable parameters, 0.72% of total model parameters\n"
]
}
],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)\n",
"\n",
"for name, param in model.named_parameters():\n",
" if not name.startswith(\"classifier\"):\n",
" param.requires_grad = False\n",
"\n",
"n_total_parameters = sum(parameters.numel() for parameters in model.parameters())\n",
"n_trainable_parameters = sum(\n",
" parameters.numel() for parameters in model.parameters() if parameters.requires_grad\n",
")\n",
"rate = n_trainable_parameters / n_total_parameters\n",
"\n",
"print(\n",
" f\"{n_trainable_parameters} trainable parameters, {100 * rate:.2f}% of total model parameters\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9f95be92",
"metadata": {},
"source": [
"**Consigne** : Reprendre le paramètrage de `TrainingArguments` et lancer l'entraînement. Stocker les paramétrages, modèles et entraînement dans des variables différentes."
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "f682e048",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/arthurdanjou/Workspace/studies/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:692: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, device pinned memory won't be used.\n",
" warnings.warn(warn_msg)\n"
]
},
{
"data": {
"text/html": [
"\n",
"