From 9e7002c88ad17f14bd12f5cc19abad1d28fe276f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 6 Oct 2024 17:04:51 +1300 Subject: [PATCH] Use tf_keras instead of Keras 3 --- 18_reinforcement_learning.ipynb | 275 +++++++++++++++++--------------- 1 file changed, 147 insertions(+), 128 deletions(-) diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index 9d32862..3c63698 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -60,6 +60,25 @@ "assert sys.version_info >= (3, 7)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Warning**: the latest TensorFlow versions are based on Keras 3. For chapters 10-15, it wasn't too hard to update the code to support Keras 3, but unfortunately it's much harder for this chapter, in particular adding custom losses using the functional API is not implemented yet. So for this chapter I've had to revert to Keras 2. To do that, I set the `TF_USE_LEGACY_KERAS` environment variable to `\"1\"` and import the `tf_keras` package. This ensures that `tf.keras` points to `tf_keras`, which is Keras 2.*." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\" \n", + "\n", + "import tf_keras" + ] + }, { "cell_type": "markdown", "metadata": { @@ -71,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "id": "0Piq5se2pKzx" }, @@ -94,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "id": "8d4TH3NbpKzx" }, @@ -122,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "id": "PQFH5Y9PpKzy" }, @@ -151,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "id": "Ekxzo6pOpKzy" }, @@ -179,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -204,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -229,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -243,7 +262,7 @@ " '...']" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -263,7 +282,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -272,7 +291,7 @@ "EnvSpec(id='CartPole-v1', entry_point='gym.envs.classic_control.cartpole:CartPoleEnv', reward_threshold=475.0, nondeterministic=False, max_episode_steps=500, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={}, namespace=None, name='CartPole', version=1)" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -291,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -300,7 +319,7 @@ "array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32)" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -312,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -321,7 +340,7 @@ "{}" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -346,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -355,7 +374,7 @@ "(400, 600, 3)" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -367,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -404,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -413,7 +432,7 @@ "Discrete(2)" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -438,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -447,7 +466,7 @@ "array([ 0.02727336, 0.18847767, 0.03625453, -0.26141977], dtype=float32)" ] }, - "execution_count": 15, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -467,7 +486,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -504,7 +523,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -513,7 +532,7 @@ "1.0" ] }, - "execution_count": 17, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -529,33 +548,6 @@ "When the game is over, the environment returns `done=True`. In this case, it's not over yet:" ] }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "done" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some environment wrappers may want to interrupt the environment early. For example, when a time limit is reached or when an object goes out of bounds. In this case, `truncated` will be set to `True`. In this case, it's not truncated yet:" - ] - }, { "cell_type": "code", "execution_count": 19, @@ -572,6 +564,33 @@ "output_type": "execute_result" } ], + "source": [ + "done" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some environment wrappers may want to interrupt the environment early. For example, when a time limit is reached or when an object goes out of bounds. In this case, `truncated` will be set to `True`. In this case, it's not truncated yet:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "truncated" ] @@ -585,7 +604,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -594,7 +613,7 @@ "{}" ] }, - "execution_count": 20, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -612,7 +631,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -643,7 +662,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -667,7 +686,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -676,7 +695,7 @@ "(41.698, 8.389445512070509, 24.0, 63.0)" ] }, - "execution_count": 23, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -703,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -763,7 +782,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -800,7 +819,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -853,7 +872,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -885,7 +904,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -919,7 +938,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -948,7 +967,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -957,7 +976,7 @@ "array([-22, -40, -50])" ] }, - "execution_count": 30, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -975,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": { "scrolled": true }, @@ -987,7 +1006,7 @@ " array([1.26665318, 1.0727777 ])]" ] }, - "execution_count": 31, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -999,7 +1018,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -1011,7 +1030,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -1030,7 +1049,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1040,7 +1059,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -1076,7 +1095,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -1101,7 +1120,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -1169,7 +1188,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -1195,7 +1214,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -1206,7 +1225,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -1228,7 +1247,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -1239,7 +1258,7 @@ " [ -inf, 50.13365013, -inf]])" ] }, - "execution_count": 41, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -1250,7 +1269,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -1259,7 +1278,7 @@ "array([0, 0, 1])" ] }, - "execution_count": 42, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -1298,7 +1317,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -1318,7 +1337,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -1335,7 +1354,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -1348,7 +1367,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -1373,7 +1392,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -1425,7 +1444,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -1450,7 +1469,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -1471,7 +1490,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -1489,7 +1508,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -1521,7 +1540,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -1543,7 +1562,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -1563,7 +1582,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -1577,7 +1596,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -1613,7 +1632,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 57, "metadata": { "scrolled": true }, @@ -1652,7 +1671,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 58, "metadata": {}, "outputs": [ { @@ -1681,7 +1700,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -1712,7 +1731,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -1736,7 +1755,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -1753,7 +1772,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -1797,7 +1816,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 63, "metadata": {}, "outputs": [ { @@ -1846,7 +1865,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 64, "metadata": {}, "outputs": [ { @@ -1874,7 +1893,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 65, "metadata": { "scrolled": true, "tags": [] @@ -1901,7 +1920,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ @@ -1978,7 +1997,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 67, "metadata": {}, "outputs": [ { @@ -2006,7 +2025,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 68, "metadata": { "scrolled": true }, @@ -2025,7 +2044,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -2051,7 +2070,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 70, "metadata": {}, "outputs": [ { @@ -2103,7 +2122,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 71, "metadata": {}, "outputs": [ { @@ -2131,7 +2150,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 72, "metadata": { "scrolled": true }, @@ -2150,7 +2169,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -2209,7 +2228,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -2225,7 +2244,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 75, "metadata": {}, "outputs": [ { @@ -2236,7 +2255,7 @@ " 1. ], (8,), float32)" ] }, - "execution_count": 74, + "execution_count": 75, "metadata": {}, "output_type": "execute_result" } @@ -2247,7 +2266,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 76, "metadata": {}, "outputs": [ { @@ -2257,7 +2276,7 @@ " -0.05269805, 0. , 0. ], dtype=float32)" ] }, - "execution_count": 75, + "execution_count": 76, "metadata": {}, "output_type": "execute_result" } @@ -2287,7 +2306,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 77, "metadata": {}, "outputs": [ { @@ -2296,7 +2315,7 @@ "Discrete(4)" ] }, - "execution_count": 76, + "execution_count": 77, "metadata": {}, "output_type": "execute_result" } @@ -2325,7 +2344,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ @@ -2359,7 +2378,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ @@ -2401,7 +2420,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -2430,7 +2449,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -2449,7 +2468,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -2466,7 +2485,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 83, "metadata": {}, "outputs": [ { @@ -2510,7 +2529,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 84, "metadata": {}, "outputs": [ { @@ -2543,7 +2562,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -2567,7 +2586,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ @@ -2626,7 +2645,7 @@ "hash": "95c485e91159f3a8b550e08492cb4ed2557284663e79130c96242e7ff9e65ae1" }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -2640,7 +2659,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.9.10" } }, "nbformat": 4,