From bd1ee43dc5b022a2f770f4e20d480aa0362d6894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 6 Oct 2024 14:41:39 +1300 Subject: [PATCH] Use Keras 2 (via tf_keras) for this chapter, fixes #163 --- 16_nlp_with_rnns_and_attention.ipynb | 475 ++++++++++++++------------- 1 file changed, 247 insertions(+), 228 deletions(-) diff --git a/16_nlp_with_rnns_and_attention.ipynb b/16_nlp_with_rnns_and_attention.ipynb index 011c886..390fc66 100644 --- a/16_nlp_with_rnns_and_attention.ipynb +++ b/16_nlp_with_rnns_and_attention.ipynb @@ -60,6 +60,25 @@ "assert sys.version_info >= (3, 7)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Warning**: the latest TensorFlow versions are based on Keras 3. For previous chapters, it wasn't too hard to update the code to support Keras 3, but unfortunately it's much harder for this chapter: for example, stateful RNNs work very differently, ragged tensors are no longer supported, TensorFlow Hub models are no longer supported, and more. So for this chapter I've had to revert to Keras 2. To do that, I set the `TF_USE_LEGACY_KERAS` environment variable to `\"1\"` and import the `tf_keras` package. This ensures that `tf.keras` points to `tf_keras`, which is Keras 2.x. Note that the environment variable must be set before TensorFlow is imported, otherwise it has no effect." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\" \n", + "\n", + "import tf_keras" + ] + }, { "cell_type": "markdown", "metadata": { @@ -71,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "id": "0Piq5se2pKzx" }, @@ -94,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "id": "8d4TH3NbpKzx" }, @@ -120,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "id": "PQFH5Y9PpKzy" }, @@ -149,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "id": "Ekxzo6pOpKzy" }, @@ -187,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -211,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -233,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -242,7 +261,7 @@ "\"\\n !$&',-.3:;?abcdefghijklmnopqrstuvwxyz\"" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -254,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -277,7 +296,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -286,7 +305,7 @@ "39" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -297,7 +316,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -306,7 
+325,7 @@ "1115394" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -317,7 +336,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -333,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -343,7 +362,7 @@ " )]" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -357,7 +376,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -392,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -543,7 +562,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -563,7 +582,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -576,7 +595,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -585,7 +604,7 @@ "'e'" ] }, - "execution_count": 19, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -605,7 +624,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -614,7 +633,7 @@ "" ] }, - "execution_count": 20, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -627,7 +646,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -640,7 +659,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -652,7 +671,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ 
-661,7 +680,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -680,7 +699,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -700,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -725,7 +744,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -743,7 +762,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -757,7 +776,7 @@ " )]" ] }, - "execution_count": 28, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -776,7 +795,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -802,7 +821,7 @@ " [17, 18, 19]], dtype=int32)>)]" ] }, - "execution_count": 29, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -828,7 +847,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -843,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -854,7 +873,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -874,7 +893,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -1037,7 +1056,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1057,7 +1076,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -1066,7 +1085,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": 
{}, "outputs": [], "source": [ @@ -1075,7 +1094,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -1088,7 +1107,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -1115,7 +1134,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -1275,7 +1294,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -1301,7 +1320,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -1319,7 +1338,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -1363,7 +1382,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 44, "metadata": {}, "outputs": [ { @@ -1406,7 +1425,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -1429,7 +1448,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -1465,7 +1484,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -1474,7 +1493,7 @@ "" ] }, - "execution_count": 46, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -1488,7 +1507,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -1499,7 +1518,7 @@ " [ 11, 7, 1, 116, 217]])>" ] }, - "execution_count": 47, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -1517,7 +1536,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -1567,7 +1586,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 50, 
"metadata": {}, "outputs": [ { @@ -1602,7 +1621,7 @@ "" ] }, - "execution_count": 49, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -1633,7 +1652,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -1645,7 +1664,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -1660,7 +1679,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 53, "metadata": {}, "outputs": [ { @@ -1680,7 +1699,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -1696,7 +1715,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 55, "metadata": {}, "outputs": [ { @@ -1705,7 +1724,7 @@ "['', '[UNK]', 'the', 'i', 'to', 'you', 'tom', 'a', 'is', 'he']" ] }, - "execution_count": 54, + "execution_count": 55, "metadata": {}, "output_type": "execute_result" } @@ -1716,7 +1735,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 56, "metadata": {}, "outputs": [ { @@ -1725,7 +1744,7 @@ "['', '[UNK]', 'startofseq', 'endofseq', 'de', 'que', 'a', 'no', 'tom', 'la']" ] }, - "execution_count": 55, + "execution_count": 56, "metadata": {}, "output_type": "execute_result" } @@ -1736,7 +1755,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -1750,7 +1769,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ @@ -1761,7 +1780,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -1778,7 +1797,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -1788,7 +1807,7 @@ }, { "cell_type": "code", - 
"execution_count": 60, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -1798,7 +1817,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -1815,7 +1834,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 63, "metadata": {}, "outputs": [ { @@ -1850,7 +1869,7 @@ "" ] }, - "execution_count": 62, + "execution_count": 63, "metadata": {}, "output_type": "execute_result" } @@ -1866,7 +1885,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -1886,7 +1905,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 65, "metadata": {}, "outputs": [ { @@ -1895,7 +1914,7 @@ "'me gusta el fútbol'" ] }, - "execution_count": 64, + "execution_count": 65, "metadata": {}, "output_type": "execute_result" } @@ -1913,7 +1932,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 66, "metadata": {}, "outputs": [ { @@ -1922,7 +1941,7 @@ "'me gusta el fútbol y a veces mismo al bus'" ] }, - "execution_count": 65, + "execution_count": 66, "metadata": {}, "output_type": "execute_result" } @@ -1947,7 +1966,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ @@ -1958,7 +1977,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -1976,7 +1995,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 69, "metadata": {}, "outputs": [ { @@ -2011,7 +2030,7 @@ "" ] }, - "execution_count": 68, + "execution_count": 69, "metadata": {}, "output_type": "execute_result" } @@ -2032,7 +2051,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 70, "metadata": {}, "outputs": [ { @@ -2041,7 +2060,7 @@ "'me gusta el fútbol'" ] }, - "execution_count": 69, + "execution_count": 70, "metadata": {}, 
"output_type": "execute_result" } @@ -2068,7 +2087,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -2113,7 +2132,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 72, "metadata": {}, "outputs": [ { @@ -2122,7 +2141,7 @@ "'me [UNK] los gatos y los gatos'" ] }, - "execution_count": 71, + "execution_count": 72, "metadata": {}, "output_type": "execute_result" } @@ -2135,7 +2154,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 73, "metadata": {}, "outputs": [ { @@ -2158,7 +2177,7 @@ "'me [UNK] los gatos y los gatos'" ] }, - "execution_count": 72, + "execution_count": 73, "metadata": {}, "output_type": "execute_result" } @@ -2191,7 +2210,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -2202,7 +2221,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -2223,7 +2242,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 76, "metadata": {}, "outputs": [], "source": [ @@ -2242,7 +2261,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 77, "metadata": {}, "outputs": [ { @@ -2277,7 +2296,7 @@ "" ] }, - "execution_count": 76, + "execution_count": 77, "metadata": {}, "output_type": "execute_result" } @@ -2293,7 +2312,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 78, "metadata": {}, "outputs": [ { @@ -2302,7 +2321,7 @@ "'me gusta el fútbol y también ir a la playa'" ] }, - "execution_count": 77, + "execution_count": 78, "metadata": {}, "output_type": "execute_result" } @@ -2313,7 +2332,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 79, "metadata": {}, "outputs": [ { @@ -2340,7 +2359,7 @@ "'me gusta el fútbol y también ir a la playa'" ] }, - "execution_count": 78, + "execution_count": 79, 
"metadata": {}, "output_type": "execute_result" } @@ -2360,7 +2379,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -2383,7 +2402,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -2406,7 +2425,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -2417,7 +2436,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 83, "metadata": {}, "outputs": [ { @@ -2479,7 +2498,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -2504,7 +2523,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -2515,7 +2534,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ @@ -2547,7 +2566,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 87, "metadata": {}, "outputs": [ { @@ -2582,7 +2601,7 @@ "" ] }, - "execution_count": 86, + "execution_count": 87, "metadata": {}, "output_type": "execute_result" } @@ -2599,7 +2618,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 88, "metadata": {}, "outputs": [ { @@ -2608,7 +2627,7 @@ "'me gusta el fútbol y yo también voy a la playa'" ] }, - "execution_count": 87, + "execution_count": 88, "metadata": {}, "output_type": "execute_result" } @@ -2633,7 +2652,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -2644,7 +2663,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 90, "metadata": {}, "outputs": [ { @@ -2675,7 +2694,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 91, "metadata": {}, "outputs": [ { @@ 
-2685,7 +2704,7 @@ " {'label': 'NEGATIVE', 'score': 0.9811071157455444}]" ] }, - "execution_count": 90, + "execution_count": 91, "metadata": {}, "output_type": "execute_result" } @@ -2696,7 +2715,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 92, "metadata": {}, "outputs": [ { @@ -2716,7 +2735,7 @@ "[{'label': 'contradiction', 'score': 0.9790192246437073}]" ] }, - "execution_count": 91, + "execution_count": 92, "metadata": {}, "output_type": "execute_result" } @@ -2729,7 +2748,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 93, "metadata": {}, "outputs": [ { @@ -2751,35 +2770,6 @@ "model = TFAutoModelForSequenceClassification.from_pretrained(model_name)" ] }, - { - "cell_type": "code", - "execution_count": 93, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': , 'attention_mask': }" - ] - }, - "execution_count": 93, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "token_ids = tokenizer([\"I like soccer. [SEP] We all love soccer!\",\n", - " \"Joe lived for a very long time. [SEP] Joe is old.\"],\n", - " padding=True, return_tensors=\"tf\")\n", - "token_ids" - ] - }, { "cell_type": "code", "execution_count": 94, @@ -2803,8 +2793,8 @@ } ], "source": [ - "token_ids = tokenizer([(\"I like soccer.\", \"We all love soccer!\"),\n", - " (\"Joe lived for a very long time.\", \"Joe is old.\")],\n", + "token_ids = tokenizer([\"I like soccer. [SEP] We all love soccer!\",\n", + " \"Joe lived for a very long time. 
[SEP] Joe is old.\"],\n", " padding=True, return_tensors=\"tf\")\n", "token_ids" ] @@ -2813,6 +2803,35 @@ "cell_type": "code", "execution_count": 95, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': , 'attention_mask': }" + ] + }, + "execution_count": 95, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "token_ids = tokenizer([(\"I like soccer.\", \"We all love soccer!\"),\n", + " (\"Joe lived for a very long time.\", \"Joe is old.\")],\n", + " padding=True, return_tensors=\"tf\")\n", + "token_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, "outputs": [ { "data": { @@ -2822,7 +2841,7 @@ " [-0.01478387, 1.0962474 , -0.9919954 ]], dtype=float32)>, hidden_states=None, attentions=None)" ] }, - "execution_count": 95, + "execution_count": 96, "metadata": {}, "output_type": "execute_result" } @@ -2834,7 +2853,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 97, "metadata": {}, "outputs": [ { @@ -2845,7 +2864,7 @@ " [0.22655967, 0.6881726 , 0.0852678 ]], dtype=float32)>" ] }, - "execution_count": 96, + "execution_count": 97, "metadata": {}, "output_type": "execute_result" } @@ -2857,7 +2876,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 98, "metadata": {}, "outputs": [ { @@ -2866,7 +2885,7 @@ "" ] }, - "execution_count": 97, + "execution_count": 98, "metadata": {}, "output_type": "execute_result" } @@ -2878,7 +2897,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 99, "metadata": {}, "outputs": [ { @@ -2945,7 +2964,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -2988,7 +3007,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 101, "metadata": {}, "outputs": [ { @@ -3015,7 +3034,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 102, "metadata": {}, "outputs": [ { 
@@ -3042,7 +3061,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 103, "metadata": {}, "outputs": [], "source": [ @@ -3065,7 +3084,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 104, "metadata": {}, "outputs": [ { @@ -3092,7 +3111,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -3102,7 +3121,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 106, "metadata": {}, "outputs": [ { @@ -3111,7 +3130,7 @@ "[0, 4, 4, 4, 6, 6, 5, 5, 1, 4, 1]" ] }, - "execution_count": 105, + "execution_count": 106, "metadata": {}, "output_type": "execute_result" } @@ -3129,7 +3148,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -3151,7 +3170,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -3170,7 +3189,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 109, "metadata": {}, "outputs": [ { @@ -3181,7 +3200,7 @@ " dtype=int32)>" ] }, - "execution_count": 108, + "execution_count": 109, "metadata": {}, "output_type": "execute_result" } @@ -3199,7 +3218,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 110, "metadata": {}, "outputs": [ { @@ -3208,7 +3227,7 @@ "array([1.])" ] }, - "execution_count": 109, + "execution_count": 110, "metadata": {}, "output_type": "execute_result" } @@ -3226,7 +3245,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 111, "metadata": {}, "outputs": [ { @@ -3306,7 +3325,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 112, "metadata": {}, "outputs": [ { @@ -3356,7 +3375,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 113, "metadata": {}, "outputs": [], "source": [ @@ -3387,7 +3406,7 @@ }, { "cell_type": 
"code", - "execution_count": 113, + "execution_count": 114, "metadata": {}, "outputs": [ { @@ -3422,7 +3441,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 115, "metadata": {}, "outputs": [ { @@ -3431,7 +3450,7 @@ "' ,0123456789ADFJMNOSabceghilmnoprstuvy'" ] }, - "execution_count": 114, + "execution_count": 115, "metadata": {}, "output_type": "execute_result" } @@ -3450,7 +3469,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ @@ -3466,7 +3485,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -3474,26 +3493,6 @@ " return [chars.index(c) for c in date_str]" ] }, - { - "cell_type": "code", - "execution_count": 117, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[19, 23, 31, 34, 23, 28, 21, 23, 32, 0, 4, 2, 1, 0, 9, 2, 9, 7]" - ] - }, - "execution_count": 117, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "date_str_to_ids(x_example[0], INPUT_CHARS)" - ] - }, { "cell_type": "code", "execution_count": 118, @@ -3502,7 +3501,7 @@ { "data": { "text/plain": [ - "[7, 0, 7, 5, 10, 0, 9, 10, 2, 0]" + "[19, 23, 31, 34, 23, 28, 21, 23, 32, 0, 4, 2, 1, 0, 9, 2, 9, 7]" ] }, "execution_count": 118, @@ -3510,13 +3509,33 @@ "output_type": "execute_result" } ], + "source": [ + "date_str_to_ids(x_example[0], INPUT_CHARS)" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[7, 0, 7, 5, 10, 0, 9, 10, 2, 0]" + ] + }, + "execution_count": 119, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "date_str_to_ids(y_example[0], OUTPUT_CHARS)" ] }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -3532,7 +3551,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 
121, "metadata": {}, "outputs": [], "source": [ @@ -3545,7 +3564,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 122, "metadata": {}, "outputs": [ { @@ -3554,7 +3573,7 @@ "" ] }, - "execution_count": 121, + "execution_count": 122, "metadata": {}, "output_type": "execute_result" } @@ -3581,7 +3600,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 123, "metadata": {}, "outputs": [ { @@ -3672,7 +3691,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -3690,7 +3709,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -3699,7 +3718,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 126, "metadata": {}, "outputs": [ { @@ -3733,7 +3752,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -3742,7 +3761,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 128, "metadata": {}, "outputs": [ { @@ -3769,7 +3788,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -3789,7 +3808,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 130, "metadata": {}, "outputs": [ { @@ -3798,7 +3817,7 @@ "['2020-05-02', '1789-07-14']" ] }, - "execution_count": 129, + "execution_count": 130, "metadata": {}, "output_type": "execute_result" } @@ -3845,7 +3864,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 131, "metadata": {}, "outputs": [], "source": [ @@ -3869,7 +3888,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 132, "metadata": {}, "outputs": [ { @@ -3885,7 +3904,7 @@ " [12, 8, 9, ..., 8, 11, 3]], dtype=int32)>" ] }, - "execution_count": 131, + "execution_count": 132, "metadata": {}, "output_type": 
"execute_result" } @@ -3903,7 +3922,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 133, "metadata": {}, "outputs": [ { @@ -3984,7 +4003,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 134, "metadata": {}, "outputs": [], "source": [ @@ -4004,7 +4023,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 135, "metadata": {}, "outputs": [ { @@ -4013,7 +4032,7 @@ "['1789-07-14', '2020-05-01']" ] }, - "execution_count": 134, + "execution_count": 135, "metadata": {}, "output_type": "execute_result" } @@ -4061,7 +4080,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 136, "metadata": {}, "outputs": [ { @@ -4090,7 +4109,7 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 137, "metadata": {}, "outputs": [ { @@ -4116,7 +4135,7 @@ }, { "cell_type": "code", - "execution_count": 137, + "execution_count": 138, "metadata": {}, "outputs": [ { @@ -4125,7 +4144,7 @@ "{'input_ids': [3570, 1473], 'attention_mask': [1, 1]}" ] }, - "execution_count": 137, + "execution_count": 138, "metadata": {}, "output_type": "execute_result" } @@ -4136,7 +4155,7 @@ }, { "cell_type": "code", - "execution_count": 138, + "execution_count": 139, "metadata": {}, "outputs": [ { @@ -4147,7 +4166,7 @@ " 16187]], dtype=int32)>" ] }, - "execution_count": 138, + "execution_count": 139, "metadata": {}, "output_type": "execute_result" } @@ -4169,7 +4188,7 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": 140, "metadata": {}, "outputs": [ { @@ -4208,7 +4227,7 @@ " 504, 239, 249, 825, 512]], dtype=int32)>" ] }, - "execution_count": 139, + "execution_count": 140, "metadata": {}, "output_type": "execute_result" } @@ -4240,7 +4259,7 @@ }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 141, "metadata": {}, "outputs": [ { @@ -4305,7 +4324,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": 
"3.10.6" + "version": "3.9.10" }, "nav_menu": {}, "toc": {