From 661d591b040773bfc5c5a9c2ce2332ac1a648b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 27 May 2021 17:05:34 +1200 Subject: [PATCH] Clarify the Decision Tree instability section, fixes #422 --- 06_decision_trees.ipynb | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/06_decision_trees.ipynb b/06_decision_trees.ipynb index 5994c42..ab7f00a 100644 --- a/06_decision_trees.ipynb +++ b/06_decision_trees.ipynb @@ -206,7 +206,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Sensitivity to training set details" + "# High Variance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've seen that small changes in the dataset (such as a rotation) may produce a very different Decision Tree.\n", + "Now let's show that training the same model on the same data may produce a very different model every time, since the CART training algorithm used by Scikit-Learn is stochastic. To show this, we will set `random_state` to a different value than earlier:" ] }, { @@ -215,7 +223,8 @@ "metadata": {}, "outputs": [], "source": [ - "X[(X[:, 1]==X[:, 1][y==1].max()) & (y==1)] # widest Iris versicolor flower" + "tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)\n", + "tree_clf_tweaked.fit(X, y)" ] }, { @@ -223,23 +232,9 @@ "execution_count": 8, "metadata": {}, "outputs": [], - "source": [ - "not_widest_versicolor = (X[:, 1]!=1.8) | (y==2)\n", - "X_tweaked = X[not_widest_versicolor]\n", - "y_tweaked = y[not_widest_versicolor]\n", - "\n", - "tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)\n", - "tree_clf_tweaked.fit(X_tweaked, y_tweaked)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], "source": [ "plt.figure(figsize=(8, 4))\n", - "plot_decision_boundary(tree_clf_tweaked, X_tweaked, y_tweaked, legend=False)\n", + "plot_decision_boundary(tree_clf_tweaked, X, y, legend=False)\n", "plt.plot([0, 7.5], [0.8, 0.8], \"k-\", linewidth=2)\n", "plt.plot([0, 7.5], [1.75, 1.75], \"k--\", linewidth=2)\n", "plt.text(1.0, 0.9, \"Depth=0\", fontsize=15)\n",