mirror of
https://github.com/ArthurDanjou/ArtStudies.git
synced 2026-01-31 04:29:31 +01:00
1992 lines
136 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "15419443",
|
||
"metadata": {},
|
||
"source": [
|
||
"# TP4 - Agents\n",
|
||
"\n",
|
||
"In this notebook we explore the agent methodology using the LangChain framework. Our application is the World Bank Data API, a worldwide source of statistical data describing the state of each country.\n",
|
||
"The difficulty with this API is its richness: indicators are identified by complex codes, countries use their own specific codes, and queries must follow a precise format.\n",
|
||
"\n",
|
||
"To make the best use of the available data, we define an agent whose role is to take a question asked in natural language, query the API, and then analyze the results.\n",
|
||
"\n",
|
||
"## Finding the right indicator\n",
|
||
"\n",
|
||
"We start with the first challenge: translating a question into a list of candidate indicators that can help answer it. To do so, we retrieved every available indicator along with its description."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 54,
|
||
"id": "2eaa0a80",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
|
||
"columns": [
|
||
{
|
||
"name": "index",
|
||
"rawType": "int64",
|
||
"type": "integer"
|
||
},
|
||
{
|
||
"name": "indicator",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "description",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "source",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "source_name",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "index",
|
||
"rawType": "int64",
|
||
"type": "integer"
|
||
},
|
||
{
|
||
"name": "embedding_text",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
}
|
||
],
|
||
"ref": "845e8479-ec9a-438a-96e2-582d02120544",
|
||
"rows": [
|
||
[
|
||
"0",
|
||
"1.0.HCount.1.90usd",
|
||
"Poverty Headcount ($1.90 a day)",
|
||
"topics",
|
||
"Poverty",
|
||
"11",
|
||
"Indicator code: 1.0.HCount.1.90usd\nIndicator description: Poverty Headcount ($1.90 a day)\nIndicator source name: Poverty"
|
||
],
|
||
[
|
||
"1",
|
||
"1.0.HCount.2.5usd",
|
||
"Poverty Headcount ($2.50 a day)",
|
||
"topics",
|
||
"Poverty",
|
||
"11",
|
||
"Indicator code: 1.0.HCount.2.5usd\nIndicator description: Poverty Headcount ($2.50 a day)\nIndicator source name: Poverty"
|
||
],
|
||
[
|
||
"2",
|
||
"1.0.HCount.Mid10to50",
|
||
"Middle Class ($10-50 a day) Headcount",
|
||
"topics",
|
||
"Poverty",
|
||
"11",
|
||
"Indicator code: 1.0.HCount.Mid10to50\nIndicator description: Middle Class ($10-50 a day) Headcount\nIndicator source name: Poverty"
|
||
],
|
||
[
|
||
"3",
|
||
"1.0.HCount.Ofcl",
|
||
"Official Moderate Poverty Rate-National",
|
||
"topics",
|
||
"Poverty",
|
||
"11",
|
||
"Indicator code: 1.0.HCount.Ofcl\nIndicator description: Official Moderate Poverty Rate-National\nIndicator source name: Poverty"
|
||
],
|
||
[
|
||
"4",
|
||
"1.0.HCount.Poor4uds",
|
||
"Poverty Headcount ($4 a day)",
|
||
"topics",
|
||
"Poverty",
|
||
"11",
|
||
"Indicator code: 1.0.HCount.Poor4uds\nIndicator description: Poverty Headcount ($4 a day)\nIndicator source name: Poverty"
|
||
]
|
||
],
|
||
"shape": {
|
||
"columns": 6,
|
||
"rows": 5
|
||
}
|
||
},
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>indicator</th>\n",
|
||
" <th>description</th>\n",
|
||
" <th>source</th>\n",
|
||
" <th>source_name</th>\n",
|
||
" <th>index</th>\n",
|
||
" <th>embedding_text</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>1.0.HCount.1.90usd</td>\n",
|
||
" <td>Poverty Headcount ($1.90 a day)</td>\n",
|
||
" <td>topics</td>\n",
|
||
" <td>Poverty</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>Indicator code: 1.0.HCount.1.90usd\\nIndicator ...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>1.0.HCount.2.5usd</td>\n",
|
||
" <td>Poverty Headcount ($2.50 a day)</td>\n",
|
||
" <td>topics</td>\n",
|
||
" <td>Poverty</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>Indicator code: 1.0.HCount.2.5usd\\nIndicator d...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>1.0.HCount.Mid10to50</td>\n",
|
||
" <td>Middle Class ($10-50 a day) Headcount</td>\n",
|
||
" <td>topics</td>\n",
|
||
" <td>Poverty</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>Indicator code: 1.0.HCount.Mid10to50\\nIndicato...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>1.0.HCount.Ofcl</td>\n",
|
||
" <td>Official Moderate Poverty Rate-National</td>\n",
|
||
" <td>topics</td>\n",
|
||
" <td>Poverty</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>Indicator code: 1.0.HCount.Ofcl\\nIndicator des...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>1.0.HCount.Poor4uds</td>\n",
|
||
" <td>Poverty Headcount ($4 a day)</td>\n",
|
||
" <td>topics</td>\n",
|
||
" <td>Poverty</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>Indicator code: 1.0.HCount.Poor4uds\\nIndicator...</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" indicator description source \\\n",
|
||
"0 1.0.HCount.1.90usd Poverty Headcount ($1.90 a day) topics \n",
|
||
"1 1.0.HCount.2.5usd Poverty Headcount ($2.50 a day) topics \n",
|
||
"2 1.0.HCount.Mid10to50 Middle Class ($10-50 a day) Headcount topics \n",
|
||
"3 1.0.HCount.Ofcl Official Moderate Poverty Rate-National topics \n",
|
||
"4 1.0.HCount.Poor4uds Poverty Headcount ($4 a day) topics \n",
|
||
"\n",
|
||
" source_name index embedding_text \n",
|
||
"0 Poverty 11 Indicator code: 1.0.HCount.1.90usd\\nIndicator ... \n",
|
||
"1 Poverty 11 Indicator code: 1.0.HCount.2.5usd\\nIndicator d... \n",
|
||
"2 Poverty 11 Indicator code: 1.0.HCount.Mid10to50\\nIndicato... \n",
|
||
"3 Poverty 11 Indicator code: 1.0.HCount.Ofcl\\nIndicator des... \n",
|
||
"4 Poverty 11 Indicator code: 1.0.HCount.Poor4uds\\nIndicator... "
|
||
]
|
||
},
|
||
"execution_count": 54,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"def build_embedding_text(row: pd.Series) -> str:\n",
|
||
" \"\"\"Build the text to be embedded for a given row in the indicators dataframe.\"\"\"\n",
|
||
" return f\"\"\"Indicator code: {row[\"indicator\"]}\\nIndicator description: {row[\"description\"]}\\nIndicator source name: {row[\"source_name\"]}\"\"\"\n",
|
||
"\n",
|
||
"\n",
|
||
"df_indicators = pd.read_csv(\"./data/WBData_indicators.csv\")\n",
|
||
"\n",
|
||
"df_indicators[\"embedding_text\"] = df_indicators.apply(build_embedding_text, axis=1)\n",
|
||
"df_indicators.head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bbabba49",
|
||
"metadata": {},
|
||
"source": [
|
||
"We build an embedding of the *embedding_text* column so that we can later query it with FAISS.\n",
|
||
"\n",
|
||
"**Task**: Using the `build_faiss_index` and `retrieve_index` functions from the `rag_utils` module, build the embedding and test it on a question."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 55,
|
||
"id": "9aff2854",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import faiss\n",
"import pandas as pd\n",
"from sentence_transformers import SentenceTransformer\n",
"\n",
"MODEL_NAME = \"all-MiniLM-L6-v2\"\n",
"model_embedding = SentenceTransformer(MODEL_NAME)\n",
"\n",
"\n",
"def build_faiss_index(\n",
"    texts: list[str],\n",
"    show: bool = True,  # noqa: FBT001, FBT002\n",
"    batch_size: int = 64,\n",
") -> faiss.IndexFlatIP:\n",
"    \"\"\"Build a FAISS inner-product index from a list of texts.\n",
"\n",
"    Embeddings are L2-normalized, so the inner product is a cosine similarity.\n",
"    \"\"\"\n",
"    embeddings = model_embedding.encode(\n",
"        texts,\n",
"        batch_size=batch_size,\n",
"        show_progress_bar=show,\n",
"        normalize_embeddings=True,\n",
"    ).astype(\"float32\")  # FAISS works on float32 vectors\n",
"    dimension = embeddings.shape[1]\n",
"    index = faiss.IndexFlatIP(dimension)\n",
"    index.add(embeddings)\n",
"    return index\n",
"\n",
"\n",
"def retrieve_index(\n",
"    dataframe: pd.DataFrame,\n",
"    query: str,\n",
"    index: faiss.IndexFlatIP,\n",
"    k: int = 10,\n",
") -> pd.DataFrame:\n",
"    \"\"\"Retrieve the top k most similar entries from the dataframe for the given query.\"\"\"\n",
"    query_embedding = model_embedding.encode([query], normalize_embeddings=True).astype(\n",
"        \"float32\",\n",
"    )\n",
"    scores, indices = index.search(query_embedding, k)\n",
"\n",
"    # FAISS pads the result with -1 when k exceeds the number of indexed vectors.\n",
"    valid = indices[0] >= 0\n",
"    results = dataframe.iloc[indices[0][valid]].copy()\n",
"    results[\"score\"] = scores[0][valid]\n",
"\n",
"    return results.sort_values(\"score\", ascending=False)\n"
|
||
]
|
||
},
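`build_faiss_index` normalizes every embedding and stores it in an `IndexFlatIP`, so the score returned by `retrieve_index` is an inner product of unit vectors, that is, a cosine similarity. The stdlib-only sketch below illustrates that exhaustive inner-product search on toy 3-dimensional vectors; it involves neither FAISS nor sentence-transformers, and the helper names (`normalize`, `inner_product`, `flat_ip_search`) are hypothetical. It is a didactic stand-in, not how FAISS is implemented internally.

```python
import math


def normalize(vector: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm (the effect of normalize_embeddings=True)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def inner_product(a: list[float], b: list[float]) -> float:
    """Dot product of two vectors of the same length."""
    return sum(x * y for x, y in zip(a, b))


def flat_ip_search(
    corpus: list[list[float]], query: list[float], k: int
) -> list[tuple[int, float]]:
    """Exhaustive inner-product search, in the spirit of faiss.IndexFlatIP.

    Scores every stored vector against the query and returns the k best
    (row position, score) pairs, highest score first.
    """
    scored = [(i, inner_product(vector, query)) for i, vector in enumerate(corpus)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 3-d "embeddings", normalized once at indexing time like build_faiss_index does.
corpus = [normalize(v) for v in ([1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 2.0])]
query = normalize([2.0, 1.0, 0.0])

top = flat_ip_search(corpus, query, k=2)
# Row 1 wins with cosine 3 / sqrt(10) ≈ 0.949, then row 0 with 2 / sqrt(5) ≈ 0.894.
print(top)
```

Because the vectors are normalized before indexing, ranking by inner product gives the same order as ranking by cosine similarity, which is why the notebook can use a plain inner-product index.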
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 56,
|
||
"id": "8716ab99",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Batches: 100%|██████████| 409/409 [00:45<00:00, 9.08it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
|
||
"columns": [
|
||
{
|
||
"name": "index",
|
||
"rawType": "int64",
|
||
"type": "integer"
|
||
},
|
||
{
|
||
"name": "indicator",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "description",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "score",
|
||
"rawType": "float32",
|
||
"type": "float"
|
||
}
|
||
],
|
||
"ref": "6e8d5b84-f1f7-4d37-9a5f-a4696cab9084",
|
||
"rows": [
|
||
[
|
||
"6689",
|
||
"HD.HCI.LAYS",
|
||
"Learning-Adjusted Years of School",
|
||
"0.6565653"
|
||
],
|
||
[
|
||
"14239",
|
||
"SE.PRM.INFR",
|
||
"Basic Infrastructure",
|
||
"0.65512097"
|
||
],
|
||
[
|
||
"17416",
|
||
"UIS.ESG.LOWERSEC.NCOG.ENJO.M",
|
||
"Percentage of students in lower secondary education showing proficiency in knowledge of environmental science and geoscience, Non-cognitive dimension, Enjoyment, male (%)",
|
||
"0.645964"
|
||
],
|
||
[
|
||
"14421",
|
||
"SE.PRM.PEDG.4.M",
|
||
"(De Facto) Percent of teachers with good practices on socioemotional skills (3 or above on Teach Socioemotional Skills score) - Male",
|
||
"0.6426065"
|
||
],
|
||
[
|
||
"14403",
|
||
"SE.PRM.PEDG",
|
||
"Teacher Pedagogical Skills",
|
||
"0.6393548"
|
||
]
|
||
],
|
||
"shape": {
|
||
"columns": 3,
|
||
"rows": 5
|
||
}
|
||
},
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>indicator</th>\n",
|
||
" <th>description</th>\n",
|
||
" <th>score</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>6689</th>\n",
|
||
" <td>HD.HCI.LAYS</td>\n",
|
||
" <td>Learning-Adjusted Years of School</td>\n",
|
||
" <td>0.656565</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14239</th>\n",
|
||
" <td>SE.PRM.INFR</td>\n",
|
||
" <td>Basic Infrastructure</td>\n",
|
||
" <td>0.655121</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>17416</th>\n",
|
||
" <td>UIS.ESG.LOWERSEC.NCOG.ENJO.M</td>\n",
|
||
" <td>Percentage of students in lower secondary educ...</td>\n",
|
||
" <td>0.645964</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14421</th>\n",
|
||
" <td>SE.PRM.PEDG.4.M</td>\n",
|
||
" <td>(De Facto) Percent of teachers with good pract...</td>\n",
|
||
" <td>0.642606</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14403</th>\n",
|
||
" <td>SE.PRM.PEDG</td>\n",
|
||
" <td>Teacher Pedagogical Skills</td>\n",
|
||
" <td>0.639355</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" indicator \\\n",
|
||
"6689 HD.HCI.LAYS \n",
|
||
"14239 SE.PRM.INFR \n",
|
||
"17416 UIS.ESG.LOWERSEC.NCOG.ENJO.M \n",
|
||
"14421 SE.PRM.PEDG.4.M \n",
|
||
"14403 SE.PRM.PEDG \n",
|
||
"\n",
|
||
" description score \n",
|
||
"6689 Learning-Adjusted Years of School 0.656565 \n",
|
||
"14239 Basic Infrastructure 0.655121 \n",
|
||
"17416 Percentage of students in lower secondary educ... 0.645964 \n",
|
||
"14421 (De Facto) Percent of teachers with good pract... 0.642606 \n",
|
||
"14403 Teacher Pedagogical Skills 0.639355 "
|
||
]
|
||
},
|
||
"execution_count": 56,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"index = build_faiss_index(df_indicators[\"embedding_text\"].tolist())\n",
|
||
"query = \"What are the indicators related to education?\"\n",
|
||
"results = retrieve_index(df_indicators, query, index, k=5)\n",
|
||
"results[[\"indicator\", \"description\", \"score\"]]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "206c6280",
|
||
"metadata": {},
|
||
"source": [
|
||
"We now need to wrap this in a function so that it can serve as a tool for our agent.\n",
|
||
"\n",
|
||
"**Task**: Build a function `retrieve_indicators` that, given a question and a number of indicators to return, returns the indicators most relevant to the question (according to the embedding).\n",
|
||
"For simplicity, only three columns are returned: the indicator, its description, and the associated score."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 57,
|
||
"id": "583967a2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Build the index once: encoding every indicator takes tens of seconds,\n",
"# so it should not be redone on every tool call.\n",
"indicators_index = build_faiss_index(df_indicators[\"embedding_text\"].tolist())\n",
"\n",
"\n",
"def retrieve_indicators(query: str, k: int = 10) -> pd.DataFrame:\n",
"    \"\"\"Retrieve the top k most similar indicators for the given query.\"\"\"\n",
"    results = retrieve_index(df_indicators, query, indicators_index, k=k)\n",
"    return results[[\"indicator\", \"description\", \"score\"]]"
|
||
]
|
||
},
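Building the FAISS index re-encodes every indicator (the progress bar above reports about 45 seconds for 409 batches), so it is worth doing at most once. Besides a module-level variable, one option is lazy memoization with `functools.lru_cache`. In the sketch below, `expensive_build`, `get_index`, `retrieve` and `TEXTS` are hypothetical stand-ins for `build_faiss_index`, the indicator corpus and `retrieve_indicators`; a counter shows that the build runs only once.

```python
from functools import lru_cache
from typing import Optional

build_calls = 0  # counts how many times the expensive build really runs


def expensive_build(texts: tuple[str, ...]) -> dict[str, int]:
    """Hypothetical stand-in for build_faiss_index: a toy text -> position table."""
    global build_calls
    build_calls += 1
    return {text: position for position, text in enumerate(texts)}


TEXTS = ("inflation", "education", "poverty")  # hypothetical corpus


@lru_cache(maxsize=1)
def get_index() -> dict[str, int]:
    """Build the index on first use, then keep returning the cached object."""
    return expensive_build(TEXTS)


def retrieve(query: str) -> Optional[int]:
    """Toy tool: every call reuses the same cached index."""
    return get_index().get(query)


print(retrieve("education"), retrieve("poverty"), build_calls)  # 1 2 1
```

Compared with a module-level variable, the cached getter defers the cost until the first tool call. The trade-off is that the cached index is not rebuilt if the underlying dataframe changes, unless `get_index.cache_clear()` is called.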
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c7bcbb11",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can now identify the indicators relevant to a given question. Next, we need the codes of the countries we are interested in.\n",
|
||
"\n",
|
||
"## Identifying country codes\n",
|
||
"\n",
|
||
"As before, we have the full list of available values."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 58,
|
||
"id": "190b9427",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"application/vnd.microsoft.datawrangler.viewer.v0+json": {
|
||
"columns": [
|
||
{
|
||
"name": "index",
|
||
"rawType": "int64",
|
||
"type": "integer"
|
||
},
|
||
{
|
||
"name": "Code",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "Description",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
},
|
||
{
|
||
"name": "embedding_text",
|
||
"rawType": "object",
|
||
"type": "string"
|
||
}
|
||
],
|
||
"ref": "a36c39e7-624f-4363-8af1-fbf22836367b",
|
||
"rows": [
|
||
[
|
||
"0",
|
||
"ABW",
|
||
"Aruba",
|
||
"Code: ABW\nDescription: Aruba"
|
||
],
|
||
[
|
||
"1",
|
||
"AFE",
|
||
"Africa Eastern and Southern",
|
||
"Code: AFE\nDescription: Africa Eastern and Southern"
|
||
],
|
||
[
|
||
"2",
|
||
"AFG",
|
||
"Afghanistan",
|
||
"Code: AFG\nDescription: Afghanistan"
|
||
],
|
||
[
|
||
"3",
|
||
"AFR",
|
||
"Africa",
|
||
"Code: AFR\nDescription: Africa"
|
||
],
|
||
[
|
||
"4",
|
||
"AFW",
|
||
"Africa Western and Central",
|
||
"Code: AFW\nDescription: Africa Western and Central"
|
||
]
|
||
],
|
||
"shape": {
|
||
"columns": 3,
|
||
"rows": 5
|
||
}
|
||
},
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Code</th>\n",
|
||
" <th>Description</th>\n",
|
||
" <th>embedding_text</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>ABW</td>\n",
|
||
" <td>Aruba</td>\n",
|
||
" <td>Code: ABW\\nDescription: Aruba</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>AFE</td>\n",
|
||
" <td>Africa Eastern and Southern</td>\n",
|
||
" <td>Code: AFE\\nDescription: Africa Eastern and Sou...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>AFG</td>\n",
|
||
" <td>Afghanistan</td>\n",
|
||
" <td>Code: AFG\\nDescription: Afghanistan</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>AFR</td>\n",
|
||
" <td>Africa</td>\n",
|
||
" <td>Code: AFR\\nDescription: Africa</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>AFW</td>\n",
|
||
" <td>Africa Western and Central</td>\n",
|
||
" <td>Code: AFW\\nDescription: Africa Western and Cen...</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Code Description \\\n",
|
||
"0 ABW Aruba \n",
|
||
"1 AFE Africa Eastern and Southern \n",
|
||
"2 AFG Afghanistan \n",
|
||
"3 AFR Africa \n",
|
||
"4 AFW Africa Western and Central \n",
|
||
"\n",
|
||
" embedding_text \n",
|
||
"0 Code: ABW\\nDescription: Aruba \n",
|
||
"1 Code: AFE\\nDescription: Africa Eastern and Sou... \n",
|
||
"2 Code: AFG\\nDescription: Afghanistan \n",
|
||
"3 Code: AFR\\nDescription: Africa \n",
|
||
"4 Code: AFW\\nDescription: Africa Western and Cen... "
|
||
]
|
||
},
|
||
"execution_count": 58,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"countries = pd.read_csv(\"./data/WBData_countries.csv\")\n",
|
||
"countries[\"embedding_text\"] = countries.apply(\n",
|
||
" lambda row: f\"\"\"Code: {row[\"Code\"]}\\nDescription: {row[\"Description\"]}\"\"\",\n",
|
||
" axis=1,\n",
|
||
")\n",
|
||
"countries.head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "db235f9f",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Task**: Reproduce the previous work *mutatis mutandis* for the country codes."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 59,
|
||
"id": "09c64a5d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Build the country index once and reuse it across tool calls.\n",
"countries_index = build_faiss_index(countries[\"embedding_text\"].tolist())\n",
"\n",
"\n",
"def retrieve_countries(query: str, k: int = 10) -> pd.DataFrame:\n",
"    \"\"\"Retrieve the top k most similar countries for the given query.\"\"\"\n",
"    results = retrieve_index(countries, query, countries_index, k=k)\n",
"    return results[[\"Code\", \"Description\", \"score\"]]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "cd7f259e",
|
||
"metadata": {},
|
||
"source": [
|
||
"We now have two functions: one to identify the relevant indicators and one to identify the country codes of interest.\n",
|
||
"\n",
|
||
"## Calling the API correctly\n",
|
||
"\n",
|
||
"The goal now is to query the API correctly. For example, for annual consumer price inflation (in %) in France between 2015 and 2020, the request is:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 60,
|
||
"id": "9abc67e3",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"date\n",
|
||
"2020 0.476499\n",
|
||
"2019 1.108255\n",
|
||
"2018 1.850815\n",
|
||
"2017 1.032283\n",
|
||
"2016 0.183335\n",
|
||
"2015 0.037514\n",
|
||
"Name: value, dtype: float64"
|
||
]
|
||
},
|
||
"execution_count": 60,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import wbdata\n",
|
||
"\n",
|
||
"code = \"FP.CPI.TOTL.ZG\"\n",
|
||
"country = \"FRA\"\n",
|
||
"series = wbdata.get_series(code, date=(\"2015\", \"2020\"), country=[country])\n",
|
||
"series"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a9ff0f02",
|
||
"metadata": {},
|
||
"source": [
|
||
"To get it for both France and Germany:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 61,
|
||
"id": "99ff0a12",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"country date\n",
|
||
"Germany 2020 0.144878\n",
|
||
" 2019 1.445660\n",
|
||
" 2018 1.732169\n",
|
||
" 2017 1.509495\n",
|
||
" 2016 0.491747\n",
|
||
" 2015 0.514426\n",
|
||
"France 2020 0.476499\n",
|
||
" 2019 1.108255\n",
|
||
" 2018 1.850815\n",
|
||
" 2017 1.032283\n",
|
||
" 2016 0.183335\n",
|
||
" 2015 0.037514\n",
|
||
"Name: value, dtype: float64"
|
||
]
|
||
},
|
||
"execution_count": 61,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"code = \"FP.CPI.TOTL.ZG\"\n",
|
||
"country = [\"FRA\", \"DEU\"]\n",
|
||
"series = wbdata.get_series(code, date=(\"2015\", \"2020\"), country=country)\n",
|
||
"series"
|
||
]
|
||
},
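Under the hood, `wbdata` talks to the World Bank Indicators API (v2). To our understanding, the equivalent raw request puts the country codes (separated by `;`) and the indicator code in the URL path, and the year range in a `date=start:end` parameter. The sketch below only builds such a URL with the standard library and never sends it; `build_indicator_url` is a hypothetical helper, the `per_page` value is an arbitrary choice, and `wbdata` may assemble its requests differently.

```python
from urllib.parse import urlencode

BASE_URL = "https://api.worldbank.org/v2"


def build_indicator_url(
    indicator: str, countries: list[str], start: str, end: str, per_page: int = 1000
) -> str:
    """Build a World Bank API v2 URL for one indicator over a range of years."""
    country_path = ";".join(countries)  # several countries are separated by ';'
    params = {"date": f"{start}:{end}", "format": "json", "per_page": per_page}
    # safe=":" keeps the year range readable instead of encoding ':' as %3A
    query = urlencode(params, safe=":")
    return f"{BASE_URL}/country/{country_path}/indicator/{indicator}?{query}"


url = build_indicator_url("FP.CPI.TOTL.ZG", ["FRA", "DEU"], "2015", "2020")
print(url)
```

Seeing the URL makes it clearer why the agent has to produce exact indicator and country codes: both end up verbatim in the request path.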
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e427b5a0",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Task**: Complete the `query_api` function, whose purpose is to query the World Bank Data API. Given indicator code(s), a period in years, and country code(s), it should return an easy-to-read DataFrame."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 62,
|
||
"id": "b6a089ce",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def query_api(\n",
"    country_codes: list[str],\n",
"    indicator_codes: list[str],\n",
"    date: tuple[str, str],\n",
") -> pd.DataFrame:\n",
"    \"\"\"Query the World Bank API for the given countries and indicators.\n",
"\n",
"    Returns one row per (country, indicator, year) observation.\n",
"    \"\"\"\n",
"    records = []\n",
"    # wbdata.get_data takes a single indicator code, so query them one by one.\n",
"    for indicator_code in indicator_codes:\n",
"        rows = wbdata.get_data(indicator_code, country=country_codes, date=date)\n",
"        # Each row is one raw API observation, with nested \"indicator\" and\n",
"        # \"country\" objects next to flat \"countryiso3code\", \"date\" and \"value\".\n",
"        for row in rows or []:\n",
"            records.append(\n",
"                {\n",
"                    \"country_code\": row[\"countryiso3code\"],\n",
"                    \"country\": row[\"country\"][\"value\"],\n",
"                    \"indicator_code\": row[\"indicator\"][\"id\"],\n",
"                    \"date\": row[\"date\"],\n",
"                    \"value\": row[\"value\"],\n",
"                },\n",
"            )\n",
"    return pd.DataFrame(records)"
|
||
]
|
||
},
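To make the flattening step concrete, the sketch below applies it to one hand-written row shaped like an observation from the World Bank API v2: nested `indicator` and `country` objects next to flat `countryiso3code`, `date` and `value` fields. The row reuses the France 2020 value shown above, but the exact field set is our assumption about the API's JSON format, and `flatten_rows` is a hypothetical helper; pandas is left out so that the sketch only needs the standard library.

```python
from typing import Any


def flatten_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Turn raw API observations into flat records (one per country/indicator/year)."""
    records = []
    for row in rows:
        if row.get("value") is None:
            continue  # a null value carries no observation, so skip it
        records.append(
            {
                "country_code": row["countryiso3code"],
                "country": row["country"]["value"],
                "indicator_code": row["indicator"]["id"],
                "date": int(row["date"]),
                "value": row["value"],
            }
        )
    return records


# Hypothetical row, shaped like the API's JSON output (assumed format).
sample_rows = [
    {
        "indicator": {"id": "FP.CPI.TOTL.ZG", "value": "Inflation, consumer prices (annual %)"},
        "country": {"id": "FR", "value": "France"},
        "countryiso3code": "FRA",
        "date": "2020",
        "value": 0.476499,
    },
]

print(flatten_rows(sample_rows))
```

Passing the resulting list of dicts to `pd.DataFrame` gives one tidy row per observation, which is easy for both a human and the model to read.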
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "5a5aaeea",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Agent!\n",
|
||
"\n",
|
||
"Let us now move on to designing the agent; we will use LangGraph for this.\n",
|
||
"\n",
|
||
"To begin, let us recap the functions we have built:\n",
|
||
"1. `retrieve_indicators`: retrieve the indicator codes relevant to a given question\n",
|
||
"2. `retrieve_countries`: retrieve the country codes relevant to a given question\n",
|
||
"3. `query_api`: query the World Bank API cleanly"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 63,
|
||
"id": "1b58d076",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"TOOL_REGISTRY = {\n",
|
||
" \"retrieve_indicators\": retrieve_indicators,\n",
|
||
" \"retrieve_countries\": retrieve_countries,\n",
|
||
" \"query_api\": query_api,\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "49459ab3",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Agent structure\n",
|
||
"\n",
|
||
"In LangGraph, an agent is a graph of actions operating on evolving information. We will define:\n",
|
||
"1. A **class** that represents the agent's state. Its purpose is to store the agent's progress so that it can be used at the right moment later on.\n",
|
||
"2. **Methods** that are the nodes of the agent's action graph. They read and update the information stored in the class.\n",
|
||
"3. A **graph** that defines how to move from one node to the next.\n",
|
||
"\n",
|
||
"Let us start with the class. We call the agent with a **question**, which it will **think** about in order to draw up an action plan; here the plan boils down to **calls** to the functions we defined. Then, using the **results** of those calls, the agent can **query** the API and **analyze** the result in light of the original question.\n",
|
||
"\n",
|
||
"This is exactly what we will code: each word in bold corresponds to a piece of information that we store and use at the appropriate time."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 64,
|
||
"id": "08461b31",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from typing import Any, TypedDict\n",
|
||
"\n",
|
||
"from pydantic import BaseModel, Field\n",
|
||
"\n",
|
||
"\n",
|
||
"class AgentState(TypedDict):\n",
|
||
" \"\"\"State of the agent during its reasoning process.\"\"\"\n",
|
||
"\n",
|
||
" question: str\n",
|
||
" tool_thoughts: str\n",
|
||
" query_thoughts: str\n",
|
||
" tool_calls: list[dict]\n",
|
||
" tool_results: list[Any]\n",
|
||
" query_calls: list[dict]\n",
|
||
" query_results: list[dict]\n",
|
||
" answer: str"
|
||
]
|
||
},
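LangGraph's role is to run the nodes in the order the graph defines and to pass the state from one node to the next. The dependency-free sketch below mimics that idea for a purely linear graph. `ToyState`, `plan`, `write_answer` and `run_linear_graph` are hypothetical toys (no model, no API calls); a real implementation would wire the notebook's nodes with LangGraph's `StateGraph` instead of this loop.

```python
from typing import Callable, TypedDict


class ToyState(TypedDict, total=False):
    """Hypothetical, reduced version of AgentState."""

    question: str
    tool_calls: list[dict]
    answer: str


Node = Callable[[ToyState], ToyState]


def plan(state: ToyState) -> ToyState:
    """Toy node: decide which tool to call (a real node would ask the model)."""
    state["tool_calls"] = [{"name": "retrieve_countries", "args": {"query": state["question"]}}]
    return state


def write_answer(state: ToyState) -> ToyState:
    """Toy node: build an answer from what the previous node stored."""
    names = ", ".join(call["name"] for call in state["tool_calls"])
    state["answer"] = f"Planned tools: {names}"
    return state


def run_linear_graph(nodes: list[Node], state: ToyState) -> ToyState:
    """Run each node in order; each one reads and updates the shared state."""
    for node in nodes:
        state = node(state)
    return state


final = run_linear_graph([plan, write_answer], {"question": "Inflation in France?"})
print(final["answer"])  # Planned tools: retrieve_countries
```

Each node only needs to know the state keys it reads and writes, which is what lets the notebook develop the nodes one at a time before assembling the graph.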
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bc6040ac",
|
||
"metadata": {},
|
||
"source": [
|
||
"We use a single model throughout, here Gemma3 in its 12B version."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 65,
|
||
"id": "ebb591c8",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from langchain_ollama import OllamaLLM\n",
|
||
"\n",
|
||
"model = OllamaLLM(model=\"gemma3:12b\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "055fd5fe",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Preparing the call\n",
|
||
"\n",
|
||
"The first step is to prepare the call to the API. We need the model to give us both its reasoning and the tools it wants to use.\n",
|
||
"We therefore define our own *parser* for the model's output:\n",
|
||
"\n",
|
||
"1. **Reasoning**: to audit the model's thinking. Plain text is enough here.\n",
|
||
"2. **Tools**: to know which tools to call and with which arguments. Here we need a list of calls."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 66,
|
||
"id": "081174ae",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"The output should be formatted as a JSON instance that conforms to the JSON schema below.\n",
|
||
"\n",
|
||
"As an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}\n",
|
||
"the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.\n",
|
||
"\n",
|
||
"Here is the output schema:\n",
|
||
"```\n",
|
||
"{\"$defs\": {\"ToolCall\": {\"description\": \"Representation of a tool call.\", \"properties\": {\"name\": {\"description\": \"Name of the tool to call\", \"title\": \"Name\", \"type\": \"string\"}, \"args\": {\"additionalProperties\": true, \"description\": \"Arguments for the tool\", \"title\": \"Args\", \"type\": \"object\"}}, \"required\": [\"name\", \"args\"], \"title\": \"ToolCall\", \"type\": \"object\"}}, \"description\": \"Output schema for reasoning steps.\", \"properties\": {\"thought\": {\"description\": \"Short explanation of reasoning\", \"title\": \"Thought\", \"type\": \"string\"}, \"tools\": {\"description\": \"Tools to call\", \"items\": {\"$ref\": \"#/$defs/ToolCall\"}, \"title\": \"Tools\", \"type\": \"array\"}}, \"required\": [\"thought\", \"tools\"]}\n",
|
||
"```\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from langchain_core.output_parsers import PydanticOutputParser\n",
|
||
"\n",
|
||
"\n",
|
||
"class ToolCall(BaseModel):\n",
|
||
" \"\"\"Representation of a tool call.\"\"\"\n",
|
||
"\n",
|
||
" name: str = Field(description=\"Name of the tool to call\")\n",
|
||
" args: dict[str, Any] = Field(description=\"Arguments for the tool\")\n",
|
||
"\n",
|
||
"\n",
|
||
"class ReasoningOutput(BaseModel):\n",
|
||
" \"\"\"Output schema for reasoning steps.\"\"\"\n",
|
||
"\n",
|
||
" thought: str = Field(description=\"Short explanation of reasoning\")\n",
|
||
" tools: list[ToolCall] = Field(description=\"Tools to call\")\n",
|
||
"\n",
|
||
"\n",
|
||
"parser = PydanticOutputParser(pydantic_object=ReasoningOutput)\n",
|
||
"format_instructions = parser.get_format_instructions()\n",
|
||
"print(format_instructions)"
|
||
]
|
||
},
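The format instructions ask the model for a JSON object with a `thought` string and a `tools` list, which `parser.parse` then validates against `ReasoningOutput`. To show what that validation boils down to, here is a stdlib-only stand-in built on `json` and dataclasses; it is not LangChain's implementation, and `ToolCallSketch`, `ReasoningSketch` and `parse_reasoning` are hypothetical names. It also extracts the outermost `{...}` span from the reply, assuming the model may add text around the JSON.

```python
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class ToolCallSketch:
    name: str
    args: dict[str, Any]


@dataclass
class ReasoningSketch:
    thought: str
    tools: list[ToolCallSketch]


def parse_reasoning(reply: str) -> ReasoningSketch:
    """Extract the JSON object from a model reply and check its fields."""
    start, end = reply.find("{"), reply.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in the reply")
    payload = json.loads(reply[start : end + 1])
    if not isinstance(payload.get("thought"), str):
        raise ValueError("'thought' must be a string")
    if not isinstance(payload.get("tools"), list):
        raise ValueError("'tools' must be a list")
    tools = [ToolCallSketch(name=t["name"], args=t["args"]) for t in payload["tools"]]
    return ReasoningSketch(thought=payload["thought"], tools=tools)


reply = (
    'Here is my plan: {"thought": "Need inflation codes", '
    '"tools": [{"name": "retrieve_indicators", "args": {"query": "inflation"}}]}'
)
decision = parse_reasoning(reply)
print(decision.tools[0].name)  # retrieve_indicators
```

A reply that does not match the schema raises an error instead of silently producing a half-filled decision, which is the same contract the notebook relies on when it calls `parser.parse(response)`.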
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9e4b951b",
|
||
"metadata": {},
|
||
"source": [
|
||
"In addition, we get the instructions to give the model so that it formats its answer accordingly.\n",
|
||
"\n",
|
||
"We can finally create the first node of our graph:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 67,
|
||
"id": "02cd965a",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def preparation_node(state: AgentState) -> AgentState:\n",
|
||
" \"\"\"Prepare node to decide which tools to call based on the question.\"\"\"\n",
|
||
" prompt = f\"\"\"\n",
|
||
"You are an AI data analyst agent using the World Bank Data API to answer the user's question. You MUST use both tools to answer the question.\n",
|
||
"\n",
|
||
"User question:\n",
|
||
"{state[\"question\"]}\n",
|
||
"\n",
|
||
"Available tools:\n",
|
||
"- retrieve_indicators(query: str) : ask a question to see which indicators might help\n",
|
||
"- retrieve_countries(query: str) : ask a question to retrieve the code of a specific country or a geographic zone\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"{format_instructions}\n",
|
||
"\"\"\"\n",
|
||
" response = model.invoke(prompt)\n",
|
||
" decision = parser.parse(response)\n",
|
||
"\n",
|
||
" state[\"tool_thoughts\"] = decision.thought\n",
|
||
" state[\"tool_calls\"] = [\n",
|
||
" {\"name\": tool.name, \"args\": tool.args} for tool in decision.tools\n",
|
||
" ]\n",
|
||
" return state\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "221e51f0",
|
||
"metadata": {},
|
||
"source": [
|
||
"We have called the model, retrieved its answer in an easy-to-use format, and stored the information. We now need to actually call the tools and store the results: this step is fully deterministic.\n",
|
||
"\n",
|
||
"### Tool calls"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 68,
|
||
"id": "48939e28",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def tool_execution_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Execute the tool calls decided in the preparation node.\"\"\"\n",
"    results = []\n",
"    for call in state[\"tool_calls\"]:\n",
"        tool_name = call[\"name\"]\n",
"        tool_args = call.get(\"args\", {})\n",
"\n",
"        tool_function = TOOL_REGISTRY.get(tool_name)\n",
"        if tool_function is None:\n",
"            results.append(\n",
"                {\"tool\": tool_name, \"args\": tool_args, \"error\": \"Unknown tool\"},\n",
"            )\n",
"            continue\n",
"\n",
"        try:\n",
"            result = tool_function(**tool_args)\n",
"            results.append({\"tool\": tool_name, \"args\": tool_args, \"result\": result})\n",
"        except Exception as exc:\n",
"            # Keep a trace of the failure instead of silently dropping the call\n",
"            results.append({\"tool\": tool_name, \"args\": tool_args, \"error\": repr(exc)})\n",
"\n",
"    state[\"tool_results\"] = results\n",
"    return state\n"
|
||
]
|
||
},
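`TOOL_REGISTRY` is defined earlier in the notebook (not shown); it presumably maps each tool name to a Python function. Since this step involves no model call, it can be exercised in isolation. The sketch below uses a toy registry with one fake tool (`fake_retrieve_indicators` is a hypothetical stand-in returning a fixed list) to show three outcomes: a successful call, a call with a wrong argument name, and an unknown tool. Errors are recorded in the results rather than dropped, so the next node can see what went wrong.

```python
from typing import Any, Callable


def fake_retrieve_indicators(query: str) -> list[str]:
    """Hypothetical stand-in for the notebook's indicator search."""
    return ["FP.CPI.TOTL.ZG"] if "inflation" in query else []


TOY_REGISTRY: dict[str, Callable[..., Any]] = {
    "retrieve_indicators": fake_retrieve_indicators,
}


def execute_calls(calls: list[dict], registry: dict[str, Callable[..., Any]]) -> list[dict]:
    """Run each call through the registry, recording errors instead of dropping them."""
    results = []
    for call in calls:
        name, args = call["name"], call.get("args", {})
        func = registry.get(name)
        if func is None:
            results.append({"tool": name, "args": args, "error": "Unknown tool"})
            continue
        try:
            results.append({"tool": name, "args": args, "result": func(**args)})
        except Exception as exc:  # e.g. TypeError for a wrong argument name
            results.append({"tool": name, "args": args, "error": repr(exc)})
    return results


calls = [
    {"name": "retrieve_indicators", "args": {"query": "inflation"}},
    {"name": "retrieve_indicators", "args": {"q": "inflation"}},  # wrong argument name
    {"name": "search_web", "args": {}},  # not in the registry
]
for r in execute_calls(calls, TOY_REGISTRY):
    print(r)
```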
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "612b4ebe",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Query nodes\n",
"\n",
"After the preparation step and the retrieval of information through the tools defined earlier, we can define a new node whose goal is to set up the call to the `query_api` function.\n",
"\n",
"**Exercise**: Taking inspiration from the `preparation_node` node, complete the cell below."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 69,
|
||
"id": "36935e0f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def query_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Node to decide how to query the World Bank API based on tool results.\"\"\"\n",
"    prompt = f\"\"\"\n",
"You are part of an AI agent system. Based on the following information, use the tool to help answer the user question.\n",
"\n",
"User question:\n",
"{state[\"question\"]}\n",
"\n",
"Your reasoning:\n",
"{state[\"tool_thoughts\"]}\n",
"\n",
"Data retrieved from tools:\n",
"{state[\"tool_results\"]}\n",
"\n",
"Available tool: query_api(indicator_codes, country_codes, date) with parameters:\n",
"  - indicator_codes: list of World Bank indicator codes. Example: [\"FR.INR.MMKT\"]\n",
"  - country_codes: list of 3-letter country codes. Examples: [\"FRA\"] or [\"ARE\", \"DEU\"]\n",
"  - date: the year as a string or, for a period, the starting and ending years. Examples: \"2025\" or [\"2020\", \"2023\"]\n",
"\n",
"{format_instructions}\n",
"\"\"\"\n",
"    response = model.invoke(prompt)\n",
"    decision = parser.parse(response)\n",
"\n",
"    # Store the calls so that query_execution_node can run them\n",
"    state[\"query_thoughts\"] = decision.thought\n",
"    state[\"query_calls\"] = [\n",
"        {\"name\": tool.name, \"args\": tool.args} for tool in decision.tools\n",
"    ]\n",
"    return state"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "1170b4f0",
|
||
"metadata": {},
|
||
"source": [
|
||
"All that remains is to execute what this node decided:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 70,
|
||
"id": "f20f5c60",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def query_execution_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Execute the query calls decided in the query node.\"\"\"\n",
"    results = []\n",
"    for call in state[\"query_calls\"]:\n",
"        tool_name = call[\"name\"]\n",
"        tool_args = call.get(\"args\", {})\n",
"\n",
"        tool_function = TOOL_REGISTRY.get(tool_name)\n",
"        if tool_function is None:\n",
"            results.append(\n",
"                {\"tool\": tool_name, \"args\": tool_args, \"error\": \"Unknown tool\"},\n",
"            )\n",
"            continue\n",
"\n",
"        try:\n",
"            result = tool_function(**tool_args)\n",
"            results.append({\"tool\": tool_name, \"args\": tool_args, \"result\": result})\n",
"        except Exception as exc:\n",
"            # A malformed call (wrong argument names, unknown code...) is reported, not fatal\n",
"            results.append({\"tool\": tool_name, \"args\": tool_args, \"error\": repr(exc)})\n",
"\n",
"    state[\"query_results\"] = results\n",
"    return state\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "5986bba8",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Synthesis\n",
"\n",
"After all these calls and the information collected, it is time to turn them into a plain-text answer."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 71,
|
||
"id": "2c751849",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def synthesis_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Synthesis node to produce the final answer to the user question.\"\"\"\n",
"    prompt = f\"\"\"\n",
"You are an expert in data analysis and you used the World Bank Data API.\n",
"\n",
"User question:\n",
"{state[\"question\"]}\n",
"\n",
"Your reasoning:\n",
"{state[\"query_thoughts\"]}\n",
"\n",
"Data retrieved from tools:\n",
"{state[\"query_results\"]}\n",
"\n",
"Write a concise and clear answer to the user question, based ONLY on the information you have. Be thoughtful and write in plain text: avoid heavy formatting.\n",
"\"\"\"\n",
"    state[\"answer\"] = model.invoke(prompt)\n",
"    return state"
|
||
]
|
||
},
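The synthesis prompt embeds `state["query_results"]` through its Python string representation. If `query_api` returns pandas DataFrames (an assumption; its return type is not shown in this section), their printed form may be truncated for long results. One option is to serialise the results into compact text before building the prompt. The sketch assumes a hypothetical row format with `country`, `year` and `value` keys and uses dummy values, not real inflation figures.

```python
from typing import Optional


def fmt_value(value: Optional[float]) -> str:
    """Format a numeric value, marking missing observations explicitly."""
    return "n/a" if value is None else f"{value:.1f}"


def format_results(rows: list[dict]) -> str:
    """Render (country, year, value) rows as one line per country, sorted by year."""
    by_country: dict = {}
    for row in rows:
        by_country.setdefault(row["country"], []).append((row["year"], row["value"]))
    lines = []
    for country in sorted(by_country):
        series = ", ".join(
            f"{year}: {fmt_value(value)}" for year, value in sorted(by_country[country])
        )
        lines.append(f"{country} -> {series}")
    return "\n".join(lines)


# Dummy rows for illustration only (not real World Bank data)
rows = [
    {"country": "FRA", "year": 2016, "value": 2.0},
    {"country": "FRA", "year": 2015, "value": 1.0},
    {"country": "DEU", "year": 2015, "value": 3.0},
    {"country": "DEU", "year": 2016, "value": None},
]
print(format_results(rows))
```

Marking missing years as `n/a` makes gaps visible to the model instead of letting it guess.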
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "15cfb361",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Defining the graph\n",
"\n",
"At this point we have the pieces of the puzzle, but they are not assembled yet. That is what we do now, starting with the nodes:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 72,
|
||
"id": "31e2e6a5",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<langgraph.graph.state.StateGraph at 0x1288a51d0>"
|
||
]
|
||
},
|
||
"execution_count": 72,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from langgraph.graph import StateGraph\n",
|
||
"\n",
|
||
"graph = StateGraph(AgentState)\n",
|
||
"\n",
|
||
"graph.add_node(\"preparation\", preparation_node)\n",
|
||
"graph.add_node(\"preparation_tools\", tool_execution_node)\n",
|
||
"\n",
|
||
"graph.add_node(\"query\", query_node)\n",
|
||
"graph.add_node(\"query_tools\", query_execution_node)\n",
|
||
"\n",
|
||
"graph.add_node(\"synthesis\", synthesis_node)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e80eb961",
|
||
"metadata": {},
|
||
"source": [
|
||
"Then the edges between the nodes:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 73,
|
||
"id": "a2f9aebe",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"graph.set_entry_point(\"preparation\")\n",
|
||
"graph.add_edge(\"preparation\", \"preparation_tools\")\n",
|
||
"graph.add_edge(\"preparation_tools\", \"query\")\n",
|
||
"graph.add_edge(\"query\", \"query_tools\")\n",
|
||
"graph.add_edge(\"query_tools\", \"synthesis\")\n",
|
||
"\n",
|
||
"agent = graph.compile()"
|
||
]
|
||
},
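For this linear chain, the compiled graph behaves roughly like the loop below: start at the entry point, run a node, merge the keys it returns into the state, then follow the single outgoing edge. This is only a mental model in plain Python; LangGraph's actual runtime does more (state schemas, reducers, conditional edges, recursion limits). The toy nodes are hypothetical.

```python
from typing import Callable, Optional

Node = Callable[[dict], dict]


def run_linear_graph(nodes: dict, edges: dict, entry: str, state: dict) -> dict:
    """Run nodes one after another from `entry`, merging each node's output into the state."""
    current: Optional[str] = entry
    while current is not None:
        update = nodes[current](dict(state))  # each node receives a copy of the state
        state = {**state, **update}  # keys returned by the node overwrite older values
        current = edges.get(current)  # no outgoing edge: the run ends
    return state


# Hypothetical toy nodes mirroring preparation -> query -> synthesis
toy_nodes: dict[str, Node] = {
    "preparation": lambda s: {"tool_calls": [{"name": "retrieve_indicators"}]},
    "query": lambda s: {"query_plan": {"n_calls": len(s["tool_calls"])}},
    "synthesis": lambda s: {"answer": f"used {s['query_plan']['n_calls']} tool call(s)"},
}
toy_edges = {"preparation": "query", "query": "synthesis"}

final = run_linear_graph(toy_nodes, toy_edges, "preparation", {"question": "q"})
print(final["answer"])
```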
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "18f33d4a",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now it is time to use it!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 74,
|
||
"id": "80f9aa22",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Batches: 100%|██████████| 409/409 [00:48<00:00, 8.45it/s]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"question = \"What is the evolution of France's inflation between 2015 and 2025? Comment on this value using the inflation in Europe, or in Germany for example.\"\n",
"\n",
"initial_state = {\n",
"    \"question\": question,\n",
"    \"tool_thoughts\": \"\",\n",
"    \"query_thoughts\": \"\",\n",
"    \"tool_calls\": [],\n",
"    \"tool_results\": [],\n",
"    \"query_calls\": [],\n",
"    \"query_results\": [],\n",
"    \"answer\": \"\",\n",
"}\n",
"\n",
"result = agent.invoke(initial_state)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b8e91b0d",
|
||
"metadata": {},
|
||
"source": [
|
||
"We can retrieve the agent's result as well as its intermediate reasoning throughout the process:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 75,
|
||
"id": "7355b174",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Unfortunately, I do not have access to external tools or the World Bank Data API to retrieve the data you requested. Therefore, I cannot provide you with the evolution of France’s inflation between 2015 and 2025, nor compare it to European or German inflation rates.\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"To answer your question accurately, I would need to query the World Bank Data API and process the results. If you can provide the data yourself, I would be happy to analyze it and comment on the trends.\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(result[\"answer\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 94,
|
||
"id": "ff17fb70",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"{'question': \"What the evolution of France's inflation between 2015 and 2025 ? Comment on this value using the inflation in Europe, or in Germany for example.\", 'tool_thoughts': \"First, I need to determine which indicator represents inflation. I will use the 'retrieve_indicators' tool to find appropriate indicators.\", 'tool_calls': [{'name': 'retrieve_indicators', 'args': {'query': 'inflation'}}], 'tool_results': [{'tool': 'retrieve_indicators', 'args': {'query': 'inflation'}, 'result': indicator \\\n",
|
||
"6299 FP.CPI.TOTL.ZG \n",
|
||
"9781 NY.GDP.DEFL.KD.ZG \n",
|
||
"6303 FP.WPI.TOTL.ZG \n",
|
||
"9782 NY.GDP.DEFL.KD.ZG.AD \n",
|
||
"6301 FP.FPI.TOTL.ZG \n",
|
||
"9780 NY.GDP.DEFL.87.ZG \n",
|
||
"6298 FP.CPI.TOTL \n",
|
||
"6275 FM.LBL.MQMY.ZG \n",
|
||
"6313 FR.INR.MMKT \n",
|
||
"10930 PI-16 \n",
|
||
"\n",
|
||
" description score \n",
|
||
"6299 Inflation, consumer prices (annual %) 0.518722 \n",
|
||
"9781 Inflation, GDP deflator (annual %) 0.460076 \n",
|
||
"6303 Inflation, wholesale prices (annual %) 0.456499 \n",
|
||
"9782 Inflation, GDP deflator: linked series (annual %) 0.409883 \n",
|
||
"6301 Inflation, food prices (annual %) 0.402444 \n",
|
||
"9780 Inflation, GDP deflator (annual %) 0.401507 \n",
|
||
"6298 Consumer price index (2010 = 100) 0.358742 \n",
|
||
"6275 Money and quasi money growth (annual %) 0.320473 \n",
|
||
"6313 Money market rate (%) 0.313243 \n",
|
||
"10930 Medium term perspective in expenditure budgeting 0.311095 }], 'query_thoughts': 'Structured query plan generated', 'query_results': [], 'query_plan': {'indicator_code': 'FP.CPI.TOTL.ZG', 'countries': ['FRA', 'DEU'], 'start_year': 2015, 'end_year': 2025}, 'answer': \"```tool_code\\nfrom wbdata import wbdata\\nimport pandas as pd\\n\\n# Define the indicator for inflation (CPI, % change)\\nindicator = 'CPI-ACP-RD'\\n\\n# Define the country for France\\ncountry = 'FRA'\\n\\n# Define the year range\\nstart_year = 2015\\nend_year = 2025\\n\\n# Fetch the data for France\\nfrance_inflation = wbdata.get_series(indicator, country, start_year=start_year, end_year=end_year)\\n\\n# Fetch the data for Europe (aggregate region)\\neurope_inflation = wbdata.get_series(indicator, 'EUU', start_year=start_year, end_year=end_year)\\n\\n# Fetch the data for Germany\\ngermany_inflation = wbdata.get_series(indicator, 'DEU', start_year=start_year, end_year=end_year)\\n\\n\\n# Convert to pandas DataFrames for easier handling\\nfrance_inflation_df = france_inflation.to_frame(name='France')\\neurope_inflation_df = europe_inflation.to_frame(name='Europe')\\ngermany_inflation_df = germany_inflation.to_frame(name='Germany')\\n\\n# Combine the dataframes\\ncombined_df = pd.concat([france_inflation_df, europe_inflation_df, germany_inflation_df], axis=1)\\n\\n# Print the combined data\\nprint(combined_df)\\n```\\n\", 'human_feedback': {'approved': True}}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(result)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "ffd8baa7",
|
||
"metadata": {},
|
||
"source": [
|
||
"Not bad! But we were probably lucky: the query was well formed. That will not always be the case, so we need more safeguards.\n",
"\n",
"## More robust query generation\n",
"\n",
"We propose to improve this aspect by modifying the graph:\n",
"- Add a **deterministic validation** module for the selected query parameters\n",
"- Add a node that reviews the parameters when there is an error, in order to **repair** them\n",
"\n",
"These two changes require modifying both the agent state class and the graph.\n",
"\n",
"### Deterministic validation\n",
"\n",
"We start with the `QueryPlan` class:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 76,
|
||
"id": "274e67a5",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class QueryPlan(BaseModel):\n",
"    \"\"\"Plan for querying the World Bank API.\"\"\"\n",
"\n",
"    indicator_code: str = Field(description=\"World Bank indicator code\")\n",
"    countries: list[str] = Field(description=\"List of 3-letter country codes\")\n",
"    start_year: int = Field(description=\"Start year of the query\")\n",
"    end_year: int = Field(description=\"End year of the query\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "683ddb56",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Exercise**: Modify the agent state class to replace *query_calls* and *query_results* with:\n",
"* *query_plan*, an optional dictionary\n",
"* *validation_errors*, an optional list of strings, to help with debugging"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 77,
|
||
"id": "36c8ca29",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from langchain_core.output_parsers import PydanticOutputParser\n",
"\n",
"\n",
"def query_plan(question: str) -> QueryPlan:\n",
"    \"\"\"Generate a query plan based on the user question.\"\"\"\n",
"    # Parse into QueryPlan (the generic `parser` targets the tool-calling format)\n",
"    plan_parser = PydanticOutputParser(pydantic_object=QueryPlan)\n",
"    prompt = f\"\"\"You are an AI agent data analyst using the World Bank Data API to answer the user question. Based on the question, create a plan to query the World Bank API.\n",
"User question:\n",
"{question}\n",
"\n",
"{plan_parser.get_format_instructions()}\n",
"\"\"\"\n",
"    response = model.invoke(prompt)\n",
"    return plan_parser.parse(response)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "50ba9a56",
|
||
"metadata": {},
|
||
"source": [
|
||
"We need to modify the `query_node` function so that it produces something easier to validate:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 80,
|
||
"id": "287fe8ec",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from langchain_core.output_parsers import PydanticOutputParser\n",
|
||
"\n",
|
||
"query_plan_parser = PydanticOutputParser(pydantic_object=QueryPlan)\n",
|
||
"query_plan_format = query_plan_parser.get_format_instructions()\n",
|
||
"\n",
|
||
"\n",
|
||
"def query_node(state: AgentState) -> AgentState:\n",
|
||
"    \"\"\"Node to decide how to query the World Bank API based on tool results.\"\"\"\n",
"    prompt = f\"\"\"\n",
|
||
"You are an AI agent building a structured query plan for the World Bank Data API.\n",
|
||
"\n",
|
||
"User question:\n",
|
||
"{state[\"question\"]}\n",
|
||
"\n",
|
||
"Retrieved indicators and countries:\n",
|
||
"{state[\"tool_results\"]}\n",
|
||
"\n",
|
||
"Return exactly one query plan with:\n",
|
||
"- indicator_code\n",
|
||
"- countries (list)\n",
|
||
"- start_year\n",
|
||
"- end_year\n",
|
||
"\n",
|
||
"{query_plan_format}\n",
|
||
"\"\"\"\n",
|
||
"    response = model.invoke(prompt)\n",
"    plan = query_plan_parser.parse(response)\n",
"\n",
"    state[\"query_thoughts\"] = \"Structured query plan generated\"\n",
"    state[\"query_plan\"] = plan.model_dump()\n",
"    return state\n"
|
||
]
|
||
},
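`query_plan_parser` turns the model's reply into a `QueryPlan`, and `model_dump()` stores it as a plain dict, which is what the validation node reads. Pydantic's default (lax) mode also coerces values such as the string `"2015"` into the integer `2015` for an `int` field. The standard-library sketch below imitates that parse, coerce and dump sequence with a dataclass; `QueryPlanSketch` and `parse_plan` are hypothetical names, and upper-casing country codes is an extra normalisation not performed by the notebook.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class QueryPlanSketch:
    """Plain-dataclass stand-in for the QueryPlan model."""

    indicator_code: str
    countries: list
    start_year: int
    end_year: int


def parse_plan(raw: str) -> dict:
    """Parse a JSON reply into a plan dict, coercing years to int."""
    data = json.loads(raw)
    plan = QueryPlanSketch(
        indicator_code=str(data["indicator_code"]),
        countries=[str(c).upper() for c in data["countries"]],  # hypothetical normalisation
        start_year=int(data["start_year"]),  # "2015" -> 2015
        end_year=int(data["end_year"]),
    )
    return asdict(plan)  # plain dict, like plan.model_dump()


reply = '{"indicator_code": "FP.CPI.TOTL.ZG", "countries": ["fra", "DEU"], "start_year": "2015", "end_year": 2025}'
print(parse_plan(reply))
```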
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "278c2df8",
|
||
"metadata": {},
|
||
"source": [
|
||
"Then we need a deterministic validation module:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 81,
|
||
"id": "c16a8f6b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"VALID_INDICATORS = set(df_indicators[\"indicator\"].values)\n",
"VALID_COUNTRIES = set(countries[\"Code\"].values)\n",
"MIN_YEAR = 1960\n",
"MAX_YEAR = 2025\n",
"\n",
"\n",
"def validation_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Validate the generated query plan.\"\"\"\n",
"    plan = state.get(\"query_plan\")\n",
"    errors = []\n",
"\n",
"    if not plan:  # covers both None and the empty dict of the initial state\n",
"        errors.append(\"No query plan found.\")\n",
"        state[\"validation_errors\"] = errors\n",
"        return state\n",
"\n",
"    if plan[\"indicator_code\"] not in VALID_INDICATORS:\n",
"        errors.append(f\"Invalid indicator code: {plan['indicator_code']}\")\n",
"\n",
"    for c in plan[\"countries\"]:\n",
"        if c not in VALID_COUNTRIES:\n",
"            errors.append(f\"Invalid country code: {c}\")\n",
"\n",
"    if not (MIN_YEAR <= plan[\"start_year\"] <= MAX_YEAR):\n",
"        errors.append(f\"start_year out of bounds: {plan['start_year']}\")\n",
"\n",
"    if not (MIN_YEAR <= plan[\"end_year\"] <= MAX_YEAR):\n",
"        errors.append(f\"end_year out of bounds: {plan['end_year']}\")\n",
"\n",
"    if plan[\"start_year\"] > plan[\"end_year\"]:\n",
"        errors.append(f\"start_year {plan['start_year']} is after end_year {plan['end_year']}\")\n",
"\n",
"    state[\"validation_errors\"] = errors if errors else None\n",
"    return state\n"
|
||
]
|
||
},
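Because this validation is deterministic, it can be tested without calling the model. The sketch below restates the same checks as a standalone function over toy sets of valid codes (the real node uses `VALID_INDICATORS` and `VALID_COUNTRIES` built from the notebook's data) and also rejects a plan whose start year comes after its end year. `validate_plan` and the toy sets are hypothetical.

```python
MIN_YEAR, MAX_YEAR = 1960, 2025


def validate_plan(plan: dict, valid_indicators: set, valid_countries: set) -> list:
    """Return a list of human-readable errors; an empty list means the plan is valid."""
    errors = []
    if plan.get("indicator_code") not in valid_indicators:
        errors.append(f"Invalid indicator code: {plan.get('indicator_code')}")
    for code in plan.get("countries", []):
        if code not in valid_countries:
            errors.append(f"Invalid country code: {code}")
    years_ok = True
    for key in ("start_year", "end_year"):
        year = plan.get(key)
        if not isinstance(year, int) or not MIN_YEAR <= year <= MAX_YEAR:
            errors.append(f"{key} out of bounds: {year}")
            years_ok = False
    # Only compare the two years once both are known to be valid integers
    if years_ok and plan["start_year"] > plan["end_year"]:
        errors.append("start_year is after end_year")
    return errors


# Toy reference sets (the notebook builds the real ones from its data)
indicators = {"FP.CPI.TOTL.ZG"}
country_codes = {"FRA", "DEU"}

good = {"indicator_code": "FP.CPI.TOTL.ZG", "countries": ["FRA", "DEU"], "start_year": 2015, "end_year": 2025}
bad = {"indicator_code": "CPI-ACP-RD", "countries": ["FRA", "XYZ"], "start_year": 2015, "end_year": 2030}
reversed_years = {**good, "start_year": 2025, "end_year": 2015}

for plan in (good, bad, reversed_years):
    print(validate_plan(plan, indicators, country_codes))
```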
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e063ee6c",
|
||
"metadata": {},
|
||
"source": [
|
||
"We also need a node that adjusts the query plan when it fails validation.\n",
"\n",
"**Exercise**: Complete the function below."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d98d3031",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def repair_node(state: AgentState) -> AgentState:\n",
|
||
"    \"\"\"Ask the model to fix a query plan that failed validation.\"\"\"\n",
"    prompt = f\"\"\"\n",
|
||
"The following query plan is invalid:\n",
|
||
"\n",
|
||
"{state[\"query_plan\"]}\n",
|
||
"\n",
|
||
"Errors:\n",
|
||
"{state[\"validation_errors\"]}\n",
|
||
"\n",
|
||
"Fix the query plan so that:\n",
|
||
"- The indicator exists\n",
|
||
"- Country codes are valid\n",
|
||
"- Years are within bounds\n",
|
||
"\n",
|
||
"Return a corrected query plan only.\n",
|
||
"\n",
|
||
"{query_plan_format}\n",
|
||
"\"\"\"\n",
|
||
"    response = model.invoke(prompt)\n",
"    plan = query_plan_parser.parse(response)\n",
"    state[\"query_plan\"] = plan.model_dump()\n",
"    state[\"validation_errors\"] = None\n",
"    return state\n"
|
||
]
|
||
},
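The repair prompt only tells the model what is wrong, not which codes would be acceptable. One possible refinement, not part of the original notebook, is to add a few close valid candidates to the prompt. The sketch uses `difflib.get_close_matches` from the standard library for this; `suggest_codes` and the toy code sets are hypothetical. Lexical similarity knows nothing about meaning, as the second example shows: `GER`, a plausible but invalid guess for Germany, is closest to `GBR` rather than `DEU`, so suggestions should be presented to the model as hints, not answers.

```python
import difflib


def suggest_codes(bad_code: str, valid_codes: set, n: int = 3) -> list:
    """Return up to n valid codes that are lexically close to bad_code, best first."""
    return difflib.get_close_matches(bad_code, sorted(valid_codes), n=n, cutoff=0.6)


toy_indicators = {"FP.CPI.TOTL.ZG", "FP.CPI.TOTL", "NY.GDP.DEFL.KD.ZG"}
toy_countries = {"DEU", "FRA", "GBR"}

print(suggest_codes("FP.CPI.TOTL.Z", toy_indicators))
print(suggest_codes("GER", toy_countries))

# A hint line that could be appended to the repair prompt
hint = "Closest valid indicator codes: " + ", ".join(suggest_codes("FP.CPI.TOTL.Z", toy_indicators))
print(hint)
```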
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0653b30",
|
||
"metadata": {},
|
||
"source": [
|
||
"The `query_execution_node` function can be greatly simplified now that we have made all these changes.\n",
"\n",
"**Exercise**: Redefine the `query_execution_node` function so that it uses the plan stored in the agent's state."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 83,
|
||
"id": "e9d888ec",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def query_execution_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Execute the validated query plan stored in the agent state.\"\"\"\n",
"    plan = state.get(\"query_plan\")\n",
"    if not plan:  # None or empty dict: nothing to query\n",
"        state[\"query_results\"] = []\n",
"        return state\n",
"\n",
"    result = query_api(\n",
"        country_codes=plan[\"countries\"],\n",
"        indicator_codes=[plan[\"indicator_code\"]],\n",
"        date=(str(plan[\"start_year\"]), str(plan[\"end_year\"])),\n",
"    )\n",
"    state[\"query_results\"] = [result]\n",
"    return state"
|
||
]
|
||
},
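Since the node now only translates a validated plan into one `query_api` call, that translation can be isolated and tested without network access by passing the API function as a parameter. In the sketch, `plan_to_query_args` reproduces the argument mapping used in the cell above, while `fake_query_api` and `execute_plan` are hypothetical stand-ins; the real `query_api` return type is not shown in this section.

```python
def plan_to_query_args(plan: dict) -> dict:
    """Map a validated plan onto the keyword arguments passed to query_api above."""
    return {
        "country_codes": list(plan["countries"]),
        "indicator_codes": [plan["indicator_code"]],
        "date": (str(plan["start_year"]), str(plan["end_year"])),
    }


def fake_query_api(country_codes, indicator_codes, date):
    """Hypothetical offline stand-in: one empty observation per country, indicator and year."""
    start, end = (int(year) for year in date)
    return [
        {"country": c, "indicator": i, "year": y, "value": None}
        for c in country_codes
        for i in indicator_codes
        for y in range(start, end + 1)
    ]


def execute_plan(state: dict, api=fake_query_api) -> dict:
    """Same logic as query_execution_node, with the API injected for testing."""
    plan = state.get("query_plan")
    if not plan:
        return {**state, "query_results": []}
    return {**state, "query_results": [api(**plan_to_query_args(plan))]}


state = {"query_plan": {"indicator_code": "FP.CPI.TOTL.ZG", "countries": ["FRA", "DEU"], "start_year": 2015, "end_year": 2017}}
rows = execute_plan(state)["query_results"][0]
print(len(rows), rows[0])
```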
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "abc0e490",
|
||
"metadata": {},
|
||
"source": [
|
||
"Finally, we can rewrite the graph with our new nodes and, above all, introduce optional nodes through conditional edges!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 84,
|
||
"id": "259e91ae",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<langgraph.graph.state.CompiledStateGraph object at 0x15047de50>"
|
||
]
|
||
},
|
||
"execution_count": 84,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from langgraph.graph import StateGraph\n",
|
||
"\n",
|
||
"graph = StateGraph(AgentState)\n",
|
||
"\n",
|
||
"graph.add_node(\"preparation\", preparation_node)\n",
|
||
"graph.add_node(\"preparation_tools\", tool_execution_node)\n",
|
||
"graph.add_node(\"query\", query_node)\n",
|
||
"graph.add_node(\"validate\", validation_node)\n",
|
||
"graph.add_node(\"repair\", repair_node)\n",
|
||
"graph.add_node(\"query_tools\", query_execution_node)\n",
|
||
"graph.add_node(\"synthesis\", synthesis_node)\n",
|
||
"\n",
|
||
"graph.set_entry_point(\"preparation\")\n",
|
||
"graph.add_edge(\"preparation\", \"preparation_tools\")\n",
|
||
"graph.add_edge(\"preparation_tools\", \"query\")\n",
|
||
"graph.add_edge(\"query\", \"validate\")\n",
|
||
"\n",
|
||
"\n",
|
||
"def validation_router(state: AgentState) -> str:\n",
"    if state.get(\"validation_errors\"):\n",
"        return \"repair\"\n",
"    return \"query_tools\"\n",
"\n",
"\n",
"graph.add_conditional_edges(\n",
"    \"validate\",\n",
"    validation_router,\n",
"    {\n",
"        \"repair\": \"repair\",\n",
"        \"query_tools\": \"query_tools\",\n",
"    },\n",
")\n",
|
||
"\n",
|
||
"graph.add_edge(\"repair\", \"validate\")\n",
|
||
"graph.add_edge(\"query_tools\", \"synthesis\")\n",
|
||
"\n",
|
||
"agent = graph.compile()\n",
|
||
"agent\n"
|
||
]
|
||
},
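As written, `validate` and `repair` can loop for as long as the model keeps producing an invalid plan; LangGraph should then stop the run with a recursion error once its step limit is reached (25 steps by default, as far as we know), which ends the run without an answer. The `retry_count` field of `AgentState` could be used to cap repairs explicitly. The sketch below simulates such a router in plain Python; `MAX_REPAIRS`, `bounded_validation_router` and the fallback route to `synthesis` are hypothetical choices, and in the real graph that route would also have to be added to the mapping passed to `add_conditional_edges`.

```python
MAX_REPAIRS = 2  # hypothetical limit


def bounded_validation_router(state: dict) -> str:
    """Route to repair while errors remain, but give up after MAX_REPAIRS attempts."""
    if not state.get("validation_errors"):
        return "query_tools"
    if state.get("retry_count", 0) >= MAX_REPAIRS:
        return "synthesis"  # hypothetical fallback: report the failure instead of looping
    return "repair"


def toy_repair(state: dict) -> dict:
    """Toy repair step that never fixes the plan but counts its attempts."""
    return {**state, "retry_count": state.get("retry_count", 0) + 1}


state = {"validation_errors": ["Invalid indicator code: CPI-ACP-RD"], "retry_count": 0}
visited = []
route = bounded_validation_router(state)
while route == "repair":
    visited.append(route)
    state = toy_repair(state)
    route = bounded_validation_router(state)
print(visited, route)
```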
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8c8786d9",
|
||
"metadata": {},
|
||
"source": [
|
||
"This is a more complex structure, and in principle a more robust way of calling the API. But since the agent takes time to answer, we would like to make sure it is on the right track.\n",
"\n",
"## *Human in the loop*\n",
"\n",
"It is time to bring a human into the process to confirm, or not, that the identified plan is correct. If it is not, the human can edit the plan or even rephrase the question.\n",
"\n",
"To start, we need to modify the agent state class once again.\n",
"\n",
"**Exercise**: Add an optional dictionary named *human_feedback*."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 88,
|
||
"id": "7d6d2da3",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from typing import Any, TypedDict\n",
"\n",
"\n",
"class AgentState(TypedDict):\n",
"    \"\"\"State of the agent during its reasoning process.\"\"\"\n",
"\n",
"    question: str\n",
"\n",
"    tool_thoughts: str\n",
"    tool_calls: list[dict]\n",
"    tool_results: list[Any]\n",
"\n",
"    query_thoughts: str\n",
"    query_calls: list[dict]\n",
"    query_results: list[dict]\n",
"    query_plan: dict | None\n",
"    # Needed so that validation_router can see the errors set by validation_node\n",
"    validation_errors: list[str] | None\n",
"    retry_count: int\n",
"\n",
"    answer: str\n",
"    human_feedback: dict | None\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "dee8fb70",
|
||
"metadata": {},
|
||
"source": [
|
||
"Next comes the node that calls on the human. We offer three options:\n",
"1. **Approve**: continue with this plan\n",
"2. **Edit**: modify the plan proposed by the agent, then go back to the validation module\n",
"3. **Reject**: rephrase the question, then go back to the preparation step\n",
"\n",
"To avoid being asked again after editing the plan and having it pass validation, we add a check that the human has not already been consulted."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 89,
|
||
"id": "44ea5c40",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import json\n",
"\n",
"\n",
"def human_approval_node(state: AgentState) -> AgentState:\n",
"    \"\"\"Ask a human to approve, edit or reject the proposed query plan.\"\"\"\n",
"    feedback = state.get(\"human_feedback\") or {}\n",
"    if feedback.get(\"approved\"):\n",
"        return state\n",
"    if \"edited_plan\" in feedback:\n",
"        # The human's own edit has just passed validation: do not ask again\n",
"        state[\"human_feedback\"] = {\"approved\": True}\n",
"        return state\n",
"\n",
"    plan = state[\"query_plan\"]\n",
"\n",
"    print(\"\\nProposed query plan:\")\n",
"    print(f\"Indicator: {plan['indicator_code']}\")\n",
"    print(f\"Countries: {', '.join(plan['countries'])}\")\n",
"    print(f\"Years: {plan['start_year']}–{plan['end_year']}\")\n",
"\n",
"    decision = input(\"\\nApprove? (y = yes / e = edit / n = reject): \").strip().lower()\n",
"\n",
"    if decision == \"y\":\n",
"        state[\"human_feedback\"] = {\"approved\": True}\n",
"\n",
"    elif decision == \"e\":\n",
"        edited = input(\"\\nEnter corrected query plan as JSON:\\n\")\n",
"        # Apply the edit so that the validation node checks the new plan\n",
"        state[\"query_plan\"] = json.loads(edited)\n",
"        state[\"human_feedback\"] = {\"approved\": False, \"edited_plan\": edited}\n",
"\n",
"    else:\n",
"        new_q = input(\"\\nPlease rephrase your question:\\n\")\n",
"        # Restart from the preparation node with the rephrased question\n",
"        state[\"question\"] = new_q\n",
"        state[\"human_feedback\"] = {\"approved\": False, \"new_question\": new_q}\n",
"\n",
"    return state\n"
|
||
]
|
||
},
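Because `human_approval_node` reads from `input()`, the approve, edit and reject branches are hard to test. A common workaround is to inject the function that asks the question. The sketch below does this with a scripted `ask` callable and returns updated copies of the state; `human_approval` and `scripted` are hypothetical names. This sketch writes the edited plan into `query_plan` and the rephrased question into `question`, so that the validation and preparation steps act on the human's input.

```python
import json
from typing import Callable


def human_approval(state: dict, ask: Callable[[str], str] = input) -> dict:
    """Approve, edit or reject the plan; `ask` replaces input() so it can be scripted."""
    decision = ask("Approve? (y = yes / e = edit / n = reject): ").strip().lower()
    if decision == "y":
        return {**state, "human_feedback": {"approved": True}}
    if decision == "e":
        edited = json.loads(ask("Enter corrected query plan as JSON:\n"))
        return {
            **state,
            "query_plan": edited,  # the validation node will check the edited plan
            "human_feedback": {"approved": False, "edited_plan": edited},
        }
    new_q = ask("Please rephrase your question:\n")
    return {
        **state,
        "question": new_q,  # the preparation node will start from the new question
        "human_feedback": {"approved": False, "new_question": new_q},
    }


def scripted(*answers: str) -> Callable[[str], str]:
    """Return an ask() function that replays the given answers in order."""
    replies = iter(answers)
    return lambda prompt: next(replies)


state = {"question": "q", "query_plan": {"indicator_code": "FP.CPI.TOTL.ZG"}}
approved = human_approval(state, scripted("y"))
edited = human_approval(state, scripted("e", '{"indicator_code": "FP.CPI.TOTL"}'))
rejected = human_approval(state, scripted("n", "What was France's CPI inflation?"))
print(approved["human_feedback"], edited["query_plan"], rejected["question"])
```

In the notebook itself, the default `ask=input` keeps the interactive behaviour unchanged.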
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "74cbce56",
|
||
"metadata": {},
|
||
"source": [
|
||
"This gives us a new graph:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 90,
|
||
"id": "e4ae950a",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<langgraph.graph.state.CompiledStateGraph object at 0x1511e2d70>"
|
||
]
|
||
},
|
||
"execution_count": 90,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from langgraph.graph import StateGraph\n",
|
||
"\n",
|
||
"graph = StateGraph(AgentState)\n",
|
||
"\n",
|
||
"graph.add_node(\"preparation\", preparation_node)\n",
|
||
"graph.add_node(\"preparation_tools\", tool_execution_node)\n",
|
||
"graph.add_node(\"query\", query_node)\n",
|
||
"graph.add_node(\"validate\", validation_node)\n",
|
||
"graph.add_node(\"repair\", repair_node)\n",
|
||
"\n",
|
||
"graph.add_node(\"human_approval\", human_approval_node)\n",
|
||
"\n",
|
||
"graph.add_node(\"query_tools\", query_execution_node)\n",
|
||
"graph.add_node(\"synthesis\", synthesis_node)\n",
|
||
"\n",
|
||
"graph.set_entry_point(\"preparation\")\n",
|
||
"graph.add_edge(\"preparation\", \"preparation_tools\")\n",
|
||
"graph.add_edge(\"preparation_tools\", \"query\")\n",
|
||
"graph.add_edge(\"query\", \"validate\")\n",
|
||
"\n",
|
||
"\n",
|
||
"def validation_router(state: AgentState) -> str:\n",
"    return \"repair\" if state.get(\"validation_errors\") else \"human_approval\"\n",
"\n",
"\n",
"graph.add_conditional_edges(\n",
"    \"validate\",\n",
"    validation_router,\n",
"    {\n",
"        \"repair\": \"repair\",\n",
"        \"human_approval\": \"human_approval\",\n",
"    },\n",
")\n",
|
||
"\n",
|
||
"graph.add_edge(\"repair\", \"validate\")\n",
|
||
"\n",
|
||
"\n",
|
||
"def approval_router(state: AgentState) -> str:\n",
"    fb = state.get(\"human_feedback\") or {}\n",
"\n",
"    if fb.get(\"approved\"):\n",
"        return \"query_tools\"\n",
"    if \"edited_plan\" in fb:\n",
"        return \"validate\"\n",
"    # Rejected plan (new question) or missing feedback: start over\n",
"    return \"preparation\"\n",
"\n",
"\n",
"graph.add_conditional_edges(\n",
"    \"human_approval\",\n",
"    approval_router,\n",
"    {\n",
"        \"query_tools\": \"query_tools\",\n",
"        \"validate\": \"validate\",\n",
"        \"preparation\": \"preparation\",\n",
"    },\n",
")\n",
|
||
"\n",
|
||
"\n",
|
||
"graph.add_edge(\"query_tools\", \"synthesis\")\n",
|
||
"\n",
|
||
"agent = graph.compile()\n",
|
||
"agent"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bfb38954",
|
||
"metadata": {},
|
||
"source": [
|
||
"Let's test it!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 91,
|
||
"id": "22e64782",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Batches: 100%|██████████| 409/409 [00:47<00:00, 8.52it/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"Proposed query plan:\n",
|
||
"Indicator: FP.CPI.TOTL.ZG\n",
|
||
"Countries: FRA, DEU\n",
|
||
"Years: 2015–2025\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"question = \"What is the evolution of France's inflation between 2015 and 2025? Comment on this value using the inflation in Europe, or in Germany for example.\"\n",
"\n",
"initial_state = {\n",
"    \"question\": question,\n",
"    \"tool_thoughts\": \"\",\n",
"    \"tool_calls\": [],\n",
"    \"tool_results\": [],\n",
"    \"query_thoughts\": \"\",\n",
"    \"query_plan\": {},\n",
"    \"validation_errors\": [],\n",
"    \"human_feedback\": {},\n",
"    \"query_results\": [],\n",
"    \"answer\": \"\",\n",
"}\n",
"\n",
"result = agent.invoke(initial_state)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "9f3b8cf9",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"```tool_code\n",
|
||
"from wbdata import wbdata\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"# Define the indicator for inflation (CPI, % change)\n",
|
||
"indicator = 'CPI-ACP-RD'\n",
|
||
"\n",
|
||
"# Define the country for France\n",
|
||
"country = 'FRA'\n",
|
||
"\n",
|
||
"# Define the year range\n",
|
||
"start_year = 2015\n",
|
||
"end_year = 2025\n",
|
||
"\n",
|
||
"# Fetch the data for France\n",
|
||
"france_inflation = wbdata.get_series(indicator, country, start_year=start_year, end_year=end_year)\n",
|
||
"\n",
|
||
"# Fetch the data for Europe (aggregate region)\n",
|
||
"europe_inflation = wbdata.get_series(indicator, 'EUU', start_year=start_year, end_year=end_year)\n",
|
||
"\n",
|
||
"# Fetch the data for Germany\n",
|
||
"germany_inflation = wbdata.get_series(indicator, 'DEU', start_year=start_year, end_year=end_year)\n",
|
||
"\n",
|
||
"\n",
|
||
"# Convert to pandas DataFrames for easier handling\n",
|
||
"france_inflation_df = france_inflation.to_frame(name='France')\n",
|
||
"europe_inflation_df = europe_inflation.to_frame(name='Europe')\n",
|
||
"germany_inflation_df = germany_inflation.to_frame(name='Germany')\n",
|
||
"\n",
|
||
"# Combine the dataframes\n",
|
||
"combined_df = pd.concat([france_inflation_df, europe_inflation_df, germany_inflation_df], axis=1)\n",
|
||
"\n",
|
||
"# Print the combined data\n",
|
||
"print(combined_df)\n",
|
||
"```\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(result[\"answer\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 93,
|
||
"id": "2a5677ed",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"{'question': \"What the evolution of France's inflation between 2015 and 2025 ? Comment on this value using the inflation in Europe, or in Germany for example.\", 'tool_thoughts': \"First, I need to determine which indicator represents inflation. I will use the 'retrieve_indicators' tool to find appropriate indicators.\", 'tool_calls': [{'name': 'retrieve_indicators', 'args': {'query': 'inflation'}}], 'tool_results': [{'tool': 'retrieve_indicators', 'args': {'query': 'inflation'}, 'result': indicator \\\n",
|
||
"6299 FP.CPI.TOTL.ZG \n",
|
||
"9781 NY.GDP.DEFL.KD.ZG \n",
|
||
"6303 FP.WPI.TOTL.ZG \n",
|
||
"9782 NY.GDP.DEFL.KD.ZG.AD \n",
|
||
"6301 FP.FPI.TOTL.ZG \n",
|
||
"9780 NY.GDP.DEFL.87.ZG \n",
|
||
"6298 FP.CPI.TOTL \n",
|
||
"6275 FM.LBL.MQMY.ZG \n",
|
||
"6313 FR.INR.MMKT \n",
|
||
"10930 PI-16 \n",
|
||
"\n",
|
||
" description score \n",
"6299 Inflation, consumer prices (annual %) 0.518722 \n",
"9781 Inflation, GDP deflator (annual %) 0.460076 \n",
"6303 Inflation, wholesale prices (annual %) 0.456499 \n",
"9782 Inflation, GDP deflator: linked series (annual %) 0.409883 \n",
"6301 Inflation, food prices (annual %) 0.402444 \n",
"9780 Inflation, GDP deflator (annual %) 0.401507 \n",
"6298 Consumer price index (2010 = 100) 0.358742 \n",
"6275 Money and quasi money growth (annual %) 0.320473 \n",
"6313 Money market rate (%) 0.313243 \n",
"10930 Medium term perspective in expenditure budgeting 0.311095 }], 'query_thoughts': 'Structured query plan generated', 'query_results': [], 'query_plan': {'indicator_code': 'FP.CPI.TOTL.ZG', 'countries': ['FRA', 'DEU'], 'start_year': 2015, 'end_year': 2025}, 'answer': \"```tool_code\\nfrom wbdata import wbdata\\nimport pandas as pd\\n\\n# Define the indicator for inflation (CPI, % change)\\nindicator = 'CPI-ACP-RD'\\n\\n# Define the country for France\\ncountry = 'FRA'\\n\\n# Define the year range\\nstart_year = 2015\\nend_year = 2025\\n\\n# Fetch the data for France\\nfrance_inflation = wbdata.get_series(indicator, country, start_year=start_year, end_year=end_year)\\n\\n# Fetch the data for Europe (aggregate region)\\neurope_inflation = wbdata.get_series(indicator, 'EUU', start_year=start_year, end_year=end_year)\\n\\n# Fetch the data for Germany\\ngermany_inflation = wbdata.get_series(indicator, 'DEU', start_year=start_year, end_year=end_year)\\n\\n\\n# Convert to pandas DataFrames for easier handling\\nfrance_inflation_df = france_inflation.to_frame(name='France')\\neurope_inflation_df = europe_inflation.to_frame(name='Europe')\\ngermany_inflation_df = germany_inflation.to_frame(name='Germany')\\n\\n# Combine the dataframes\\ncombined_df = pd.concat([france_inflation_df, europe_inflation_df, germany_inflation_df], axis=1)\\n\\n# Print the combined data\\nprint(combined_df)\\n```\\n\", 'human_feedback': {'approved': True}}\n"
]
}
],
"source": [
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df94fd32",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "studies",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}