From 9825a36344d5f7114e027e70691ecfa8c6f13fe1 Mon Sep 17 00:00:00 2001 From: franzi Date: Fri, 20 Aug 2021 15:54:59 +0200 Subject: [PATCH] include rl notebook --- README.md | 1 + course_description.pdf | Bin 35648 -> 35632 bytes exercises/8_rl_gridmove.ipynb | 275 ++++++++++++++++++++++++++++++++++ 3 files changed, 276 insertions(+) create mode 100644 exercises/8_rl_gridmove.ipynb diff --git a/README.md b/README.md index a33de23..700ffb6 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ Have fun! ##### Block 5.1: - [ ] Read the whole chapter: ["ML Algorithms: Reinforcement Learning"](https://franziskahorn.de/mlbook/_ml_algorithms_reinforcement_learning.html) +- [ ] Work through [Notebook 8: RL gridmove](/exercises/8_rl_gridmove.ipynb) ##### Block 5.2: - [ ] Answer [Quiz 5](https://forms.gle/fr7PYmP9Exx4Vvrc8) diff --git a/course_description.pdf b/course_description.pdf index d4c522fdf98716759a016669137bc410ba879775..3a5eaf220bf6dee181efc192001606f8dcfa7b84 100644 GIT binary patch delta 4811 zcmai$Wl$6h*M{kohNUDPLK-&Mr4d+4y1SPpWK~j{C8cvIX%QAgmJUhj4ub|s0hI=k z6uxKXn{VcOf4yh!KiAyX%$YgAPER&L-M?9;K<~X+9-(tZjRskyaa`dBU}x=dUIY8y zIohY!1k`SXzpdHPiATn@{z{)I4KCpKf2SrCKV+!Y3LD3y8Q3UFo-GGXUw++{5WU-8 z>(?*-=qZ2aF@hlB;ylyA3QKHd^crTLtc7ua+kNlHk7JiNDjwHMV@?jHn}Tk>9`A_; zp#cY~9fktG)PoHrUb~7GtNo2mWGy&Bk(iTu64$PVLUJzN8v6QkQPEGid|puSqvcSa z-e_ddJwBDbI{7o*|9F-9J~CRIcH)cH6IKb%@Bu9>zr=fsN|OZ9sVW{OVKaVhOUt1P zG}qnjMsLS@u2>{@V4mr5cH~Kub0T-02|6>7NrLn4nsp98=iV)eR(*i%uHzy?E4skF zE^Wj6RSDl&sS->^UZ*-JxR>}-y$mY~%BoSvFidh|=o_hd)ThF%_ovjjm5FyDu5s|0 zoWH*Ax|2kExY!qStmjs}i-<=24*^~I(S!h_@|C!XYJ96`CN#%)9O2p%KmH}5;Pou{>e`lX3s;SH$ zr1ij)HZ&~ti5hj2TW-S$kpY9LCFdWEVfPbpDQ`HVM`IVC;cGI1=7@QfRijXT6Wh!H zS_`$&Z&JI#b-Qo3Cu8_S3>u6ifoKs{G=1YHouz7U?nI#J>B-t}S02h$<}q}NY%S?O z?7Mf#*+BjiVmbGyTf!kWm4Q0$kj2?v+;YUu_a|!mdBzhP*8bt#ui<~Tk?7SHFFrX( zC=qC%0LM{6<|}Fl0SD#Tyhlk)*1tgM$Usu?=EOHn7{09Q;2b%9Oj|~$nOMoD@96#} z4&BNoM6~_o8~<5^??oX``egFzzIevxxQuL6 zyvGaOw6Ue8(>EkxeiQ`t8feCuO#gKh#eI;~w*HS(H>%Ts_cgz9cY#1U7fSJo6gT1Y zP5&J^8^gxo7ZvN=*VIY$BC4!)Mfu(hSZf)I4kpX^gAZ9EWcQz(L?l@V4-h~$-;FIO zJ~8Zup2z`@kuZQoN2QvpP;_x3swXtLO_ky9QBY3I%JbjUbY_})Qt0wG%6V&1%N_Zw zj`I9e?CGFS?fM%GPxGi^j_46HH{lk~2zSdjg}J-(jaCJ_hCjR!=l-&$7x9~%JAs;m zpFLjTJujgwe|Mzx?$w;JQ(b^#y0uhVJ{R;+ zIyOoqAl(cnFYVvSi*>Fj7+xa{AxGF3rGPrD0|=#a3+NTOB^k5E53aLvCc_4;%nY-Z zo2&o2j6XG8$b1DYJ1NJgH|gd$L;`)qtJ`FwD)sEU^#UT;1JQK0MlesM#c@^U*3C4V zgp5{^oT5^$Ujbj9ycuum2tOdJ#^ate?PWJc7|O>Fk!0Glu2cJ|_`4WF#+j9tnAU?? 
zbR*Y=+MvgJqb|i3t?i7F9^*1!$aX9NDh8v^xZ{Afm^weeS9iqn^g08HA;6mOByTf= z?s_OVtu-B^gx+F*r#Uach_ri0o6xX8l92Y1f+5`Q4QKxLd<(#_q@bEEf~>dY6*5mV zeE(~ydOqq~`T4F1eMgTC*ZbUSlKhEA!DQ(W@OqkvW%N1##qRkLpEvj+QAQ#TUnoN* zR)Ngyh%VhEi5Rq$Hif7D;;nz00|{`f$_!UbByX0(fX+T8i~8D~Xz}Z5BSX_w|G*l_ z&>8ZI1MfR`1{m?C#ML-u-t*$9Yn{PlSm?1@sb8kM*sCc+pdx>+EF-pbo?CFP66dPc z4aMJ`8SwHewd+E)914P(&ojTvkDJn|gu%V%mF}MJl-yHiA7%Uyi@59+U0qKJ`Re6e zNN3d%f<6!rDT5*q^tSLiY$*L?6njRiRZ&R!*{Rh|bGNj02=ayBe;dO|JfIq~0^S>$=1PzJ%(Y0sNvQsZt_Keg zb!NMM#$XEs)Z)!aNrPpM=)RHhZKTt>!n`iJXUs}RwjN%1;$PZ0z6bRgHI}@dlya8EZT{hu)vAsH z9J+kytk7M^`$k%0vK(Uj6s+$XNPD|HYLxt`g(bUn=?E-$V4JW_qa3T6PAorywZWk+ z>jLgId;5&I6Q5nx&9nivcmsf*^^1;)Sm zv+K1v7M|=jObLoRFtImyeJL$HKFEruf%d|XK2(87y)$3C#5F8qx_vsc0&zJ7N@a9u zB`TDn+n??}_}BRjq3QgO^Wj9!kE=~JyJx*2j8;XSg5&$-5n}NAptf0M0L45!3XYTA zNlm}CXJ4-|zFC@N44ev1$g*;`t7qSTSo!kx&;H%A6yaaD6Kd;ABy9caQ_5&&HVEg# zxpXl(OMHtZ6RisC_Dd?sM2e&Wo#>JiKu(aOnz{cTO3!t+;fN=(;8*s>&NIWY%+cl# z^|=s*?DovWg8+=9Y)Ad&%TA&2lzc0eS^CYgN&2-5ED0d?)K0zp+gLCdN2pWx&8;`U zi(q#nx|Fic==zI%hDAjF-6}d{?c_DcqoxP!xfQ0p6(v60QD!jZVbY#Ny>+gQ!M}!Z zGzTrlYvpk62$HI%cMs|w)+{XXXrEQQz$F_fe~lQyFJEwz-=&BOUz0R6M7m>PB|P@J zgAfH0Y~kZtS~iBOp9i+e_``}%gQev&)7P9#slL*V@(mFR=&c`cXrmd*{xStx+d6dn zilM%>{x(6bCv(WmTAYA+>L~jZhV4UIfkT1aEe_(HbmROSMokz~VXZ)C)U^-HV5V>> zC1M9eXWB`%hroX9+`sVqreyW5kV82BS736{^nRObkbzLW`$~mHU)ORu#}u9@m2)y# zLrlAPoAT0tJE0npFZ!uoSS(}^tG9g=G`|=@`C@%iYuRHSvKNdoCA^r^n9zzfI_bJ5 z)qU3KzUG`=C2+Vbdff#u)46$ZUw>;{rG>P9e9k0Bni6B@l!dofD8=`W_=q82!#gqB zLxUzC>QilnQ@Rz!!QZQH=DuemTsjs!9+d&Yb(y)|#7#R?X!OH+0z|KI4eMI?+Maqj zQIq*T>Yb69{FQl*A-1?1c(xtPmszJ&H|O%>5nZR^k;a0+Pl6`Z1 zqKYwg*1~F}`H%952?3wiEA)AWyQ=rpPsbzL42^m=Az~{L;Qg0_Nm|9*zlH`St)+$CWq#d`wgX-AxdkbXeXS2IdQOM25jfpO7lH9#zI1f4k* zvz*UtgR2JoAeii?(nc5hWVk49o&3TDtTKBx)b7;koXm4M4`SI!Vt} z(#o{N5|R{$hwmuK%Xr%{pM4lG3IK8fX-SMGz+VIm1OdT7C|Cpt761YTgeW9*UpT8e z`XadnRb|0IFbF99e@f#o0Ie&ZvJ^}j4uU9)!GX#uVjwsODh5-B!NfqSs%k)GX_zEL zRR+Tu2Ic#=mWBb9A!2aIzm6&p0u%#7)YZhmsvtE85CR55B$YA#VQ(nGKobz~-(9j$ zFoqW``ycFvj{E=Ic!Q27L`%NRkewud*)jTjof2x+I3%0xD5P}u=Z1#Ab+xC0QOF^t zEFkyI@5#p|WxJ7q;$6;KUSdSXujHb#x0O8}5hwJReE;l@<ypuyH#FDn{7>KBb+t--nGgSi!p zyblzv?TR1lbz;*6;Ar`griD?1(}g()(-zZQzvTER!le6KNr{Y)i^nG(x>W7FU~u~L zH5-~~@zgt(?9{{#r`+kP_iBr*02)l~wr>H)J#viNL7auY|0giv-sv+ZZ$tI`|04KQkUbY{bQi(rOE6d9F6{i8>ouf3^STdtdQaw z+MeHh_OgwIney_=<>#(-ZvJ$|AYq=(`acaN>=wCFAMuGNh zeMqG(UWSUiP_Ss!^4Vf+g?Xy+f2i- zt`WWzLX&}*&{s_8>ccEo5+ACcjHmB&{=vOTi{fa=JO3mG&vOeiX<1cm-U$nLx zZz=a9h{Q*@Uvl;yQ?hA|E%wF}p1fM2#aWR@6Wt7qZBt0X7THiBz&A<{HvJlq+XBu74G`8-KqaNwi6)@6+SHu-Ov^Ko%E+~wjLky zJXfaX%PiL5Fb`NeZJ+=5sw7LvD0gOj*RxQ!`zCuQ<9{)mu~z86U)X&midol2 z)BM<3P9Aq`YS1^R-1?nvWy>}2?y(q7R#kmj`69^Asekqk-Va1uZ%>#!}96~ zMzIx5MXAYwvn^cPW0n_Dsa*Eo7Pp!1v-RKei0nSwKK)fNUTpnW#y6nI{N9CmyV_zt zcrNBwVPQ>ig5Fj2nefNQ6^|){gb#-CL)71IQaz0-)v~VrfVyMwvP%+>QGjKLu^-4e z7dU+Hs8qkVZ|O5iZ9@HV2V*-|d9JxGznrJ}m|7;a1jOdm0dgE{K06ln;H z(2xmkCxpIh6LkcqTZ>*?wgf+KSEw~9+SxB2c^Rjlx5ca@;N6<^XhY)tzJdp-D_YqD a8~#{a#gHGvlE_H`lmS!l^Q-BrQ~VG97Whg4 delta 4808 zcmai!Ran!3qd@5n=@gM}w!y{*qhWNTv>>BFx}<(|2#7GG1tcXU6hul832`djFuJ5m z(*OVEe&4-M_wjt+c{wkqA&cNk7D2-t1+eWk_npq!V@OtP@MV~2Iu<^Z%&I|^-EN)n zp2e`3`nT7!8}W#uMPkNq{GEuc?xozB7#)jKNRufUxO^RYk{{}Hax}KdcXOurHLXZN z3V9bg*h{1!=1`UB_&8RbK*K0j@eZcDHGcd4tKt%k;vJOSgl_RprKVMpy8B0rGti^o zz~LY5*}!sIqMqk)Qv3j{kYFqJ=Iac8K~Hm#?b-IN9pn50X-p#jiJixREE!GA=HU#v zKyV1>t@m~7vaM$VprPVUBjAAboX@Tee za2DRWyWC{VmHbG419NM>fO{D^0E-3NB<;VyM7xp0027Lg*PaO2w<+e4*q{(|_>{vd zG{W;MU5`YDq2{F+E2-?HAg_g9Kb_`^Ia-lI|K7HZ`BraQCvNb_9gD0N)!rC$`vh?W6~0`%v6PM5r&6YoHl4NV0Yvs3d=y6XbsH@TxXq=CdiF>8W47`pO0~yP$tsZTk;MREukl4R<2${?@m3S=aDqg8B@Haf{#W&n$OwAd 
zCk4$O?(l9woBi}_2-{|=LImE<=kNFzVmp&;(P%sb{_{Q!NQ<2D3!HFdYRaj3{L+2;Da>(nfp6P8JyD6S-FGvH#KRc%kj7p$`8F%+9$H{5$}$* zYT{Z(_X%2f4+$crF$j0D)aQ46>>P#>r7-o7pPFsKAKXo+{^B#-=qUMLCp=sA4l z=ayt+SC?j^8xSSDf#IfM4Kv2Q(n`O^Z4AM15!CwU)0BLxR)iTEZQ7ytY8QD$gfV@| zBYFNIpb&scH$G0=qgoL5fpWFiMp3Owz*CiUi$Fm9>ZC+ zqxk1i&5KF&=!rC@jrI~c^epOTVP56meX%9=9u9(|&L&;ZB#z zQdKaEDPK;n&yHG_7=6K3Wfax9+e=eDhcV&GDBTT-@?-h}19$!yzofnzsM zFRFwQ)e<#^tKI3VmAL1|2{|3#l%kk8uyk4R#8{6mfhkw2)(eoK{R8P%map6L>%Erx z4{|c|H{2@*kw$yUiC%&_@+wmH=FA`E99t&aGlw5UQy(GE)LsR&xy>{L0~eZKdkwjd zpM7hXjt7bGHejk2wyHfh67%M(Bjw)p=Q89Q1tYXtC(r-BL%9A3eO|zLIm?)ca;1m- zq~Up->1N)6c4M_Px|UTU4?2q;de~mQY+opK{_(we(|QlyY^N8jS%S`>2-`8r;HD;5 zyUy)o9IpIq1%2>DA`7&15Ad*oS#7!7i_`ATvfrm|NzOGhb65vEJbZf(9K*uj=rPTC zZs@JfWH-e!rLex;E>s;l%Z%Ia{M<|Q)Y^bX(6W{{OS>2v8x!Qq2Z|Nq*yLaBgLp_R3_1Q9cZR0ZLsR`EGh2;V=b!-hoMcX zr}snf@xS`*bw~EfrrY7>2}z)$C9F#6mo|}#^Ky_Lxm0L`x7S0HNWY?2D2YClFTin_nfp{7y**m_-bk*&RBfaXvKMLAw zl&MboH+`q6GNn2r`?(hF#yU6h8(EQ%A>Y%y6>g(%6Kazh|Jd?LLf#uULDaAHqj?`u z7jHPtzR6bWuNGn;4!U1n_n23uda|(`j%vz>8LIv5zZ!Vp-P*9P)YiLi$8n5?1cH>m+~R^vqR2kCUQY=4-%@Cg$Yc{CvX2J93ZWhzX*9EZ7~uW*u@Cx5(9{kw;1y>$Gm~$@8;%=@u8d6H4Vj zRaXlBH&*Id7RNaX>F;#XDoB3pi{oL&)5=)^EVY<;*wFsXe!*qBkjq(-GQJ!ENjl2c zK1}0s{{29KjrP(LhJ4bil8lP`@D*IuT8^GcUk_@UkF4q-xYUj`{HJnrexZgV^|R~Y zsy9At1@h;%!N(p;O2@nVjf;mVO42OWA(h!&^2qFjq z2?|q4>UlY+J@t3w7E+UifFNKH{C||@Ip%xLd}>g2gtQCb~kf|3@OQC3EYL!nYIab=haR2(La0IA8SN+Z;iWiT!g zZz;iWQwT^FDlH3vVHg3~|HBqQ;{WF6KR^;8&=;8T`56sIWxgGu?YqW2=wu_64t`$i z5+KxeMxs!zi7(2h?hDWCx=eU+Y7P2To~#m zK&7g!WfbCSB+^s6ImXVUjZ>d~(%eW%E#4PoB`+hpMeV9=!-gMJo~>(L zz&K7r7CJ=wkyx&$TO$`&4(6NCR3sa)qAY+GR2@QFeXVY)v$vQM^YM;6GBY^gxy+}2 z&^C=&(Y#}C6d3>Q>zX^_zOzPiyDu_#!8yhUTqAg9 zvZ)&QC?Z$SyZwutR`PlBPJK@5iG?0>ja9wr?v*8JNQ}bAD%HJR(haJj>d(Od>yNJ2 z4Iviu2Gt=)ryJhZ5#ozZo5?B_Ho6-DZMf>^ybYm` z3aPFtFK-5zoA01Okui=_t;drw&>CqOeIIv1Gt+1eQKW?T2J~v={G}BGm}ETC%RSPa zPv0mACuLXbzPsqV(`0<}K%z;@~y8f9_y=^P~6P&kX zW?l2=)glVp$z&l}1PBRS{2QM~jC57z1u=eh32e2B=-v~$H#b7YB)m-9>z66}>o~!F zz>01B(%&}--d@CP?V#FG*p1hT^_*4LbD2aG>Qa7t?ea+tS0OZ;#T}oIa zw-#+mmcE6hkj*cexAZ8Oi%s=vP{vETZP=!{>xlR%w)k^Z+V&5(#@yMO%LV0{hu`Yi zGWpzcV@Q>Als|x~Bte^-zblLbsm zGAk4o*(??)T3t)!F4B>s>e7fxZMn#;3WCqzpukYhEypKjxCD5M-J@j)Zt|W4>JaloX`b{ zz;4*TU%)kUh1wDt#*C-e5*{iKu*puZVcZT@?T5JD*w6v+#oh{oQ15BN8Q5alR&DZw zmyF$UK`cI$zT9tbkUO7xud}8)Uj|Cb{PuLvX(xl{AE;hqM=E8VU?vhQf$QtT z!>f&%6*;yhPeRqZb*g8_xU-|Axg#T;FkO0OdijuI_+ycF#sNBo;m!Htj=BQd0^du&8pQ+}gOi=5 zKeYu+yCI4f#UVwC^+hlAFKbVDl)5VwDMLi|2JpjBH7nF^+3K%bp7_k%2(El!HSi7b z;Io3w1`Wb{27`=J@4Q-~t27|_;dv>-?^_e^NCT#X*fD#FT=-xNRnik>I0ONPLuEiH zH53F4M#xA>L7*TRxGEF@0fC?>umbY`F)%erQUw2*8e0X5sxi zOuNogBb5bj$Q+X+)chOIP_Yn%oS`w-p|dIHM|^V0tC6B*8ZFsw|F-Lpu3rmYO&y^j zA!mEnNTOnG9`W$;arVMQm;1%sO0-a^SmzR1cbO4_*q^+cbI9&Xr!d8^xjX3gwaNn3 z0_UGpbBhmdK#zbKm$KVFKDG99JZ1MT8m{H3A!j%NTf!d zEU@xKgmi2xWqlO7xKNlJ@e7S0E1u@{QT+$D2o38k{btN diff --git a/exercises/8_rl_gridmove.ipynb b/exercises/8_rl_gridmove.ipynb new file mode 100644 index 0000000..97595ec --- /dev/null +++ b/exercises/8_rl_gridmove.ipynb @@ -0,0 +1,275 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reinforcement Learning with discrete states and actions\n", + "\n", + "In this notebook we demonstrate how a RL agent can learn to navigate the grid world environment shown in the book using Q-learning." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# tabular Q-learning is so simple that we don't need an additional library\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Grid Environment\n",
+    "\n",
+    "The following class contains a simulation model of the small grid world environment you've seen in the book."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Environment(object):\n",
+    "    \n",
+    "    def __init__(self):\n",
+    "        # episode ends if the agent dies or finds the money\n",
+    "        self.terminal_states = {(1, 2), (3, 1), (3, 5)}\n",
+    "        # immediate reward for each state (incl. unreachable states: 0)\n",
+    "        self.rewards = [[-1, 0, 0, -1, -1, -1],\n",
+    "                        [-1, 0, -100000, -1, 0, -1],\n",
+    "                        [-1, -1, -1, -1, 0, -1],\n",
+    "                        [-1, -100000, 0, -1, -1, 100]]\n",
+    "        # filter all states that can actually be reached\n",
+    "        self.possible_states = [(i, j) for i in range(len(self.rewards)) \n",
+    "                                for j in range(len(self.rewards[i])) if self.rewards[i][j]]\n",
+    "        # state transitions via actions (walking into a wall = staying in the same place)\n",
+    "        self.possible_actions = [\"right\", \"left\", \"up\", \"down\"]\n",
+    "        self.transitions = {\n",
+    "            \"right\": [[(0, 0), (0, 1), (0, 2), (0, 4), (0, 5), (0, 5)],\n",
+    "                      [(1, 0), (1, 1), (1, 3), (1, 3), (1, 4), (1, 5)],\n",
+    "                      [(2, 1), (2, 2), (2, 3), (2, 3), (2, 4), (2, 5)],\n",
+    "                      [(3, 1), (3, 1), (3, 2), (3, 4), (3, 5), (3, 5)]],\n",
+    "            \"left\": [[(0, 0), (0, 1), (0, 2), (0, 3), (0, 3), (0, 4)],\n",
+    "                     [(1, 0), (1, 1), (1, 2), (1, 2), (1, 4), (1, 5)],\n",
+    "                     [(2, 0), (2, 0), (2, 1), (2, 2), (2, 4), (2, 5)],\n",
+    "                     [(3, 0), (3, 0), (3, 2), (3, 3), (3, 3), (3, 4)]],\n",
+    "            \"up\": [[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5)],\n",
+    "                   [(0, 0), (1, 1), (1, 2), (0, 3), (1, 4), (0, 5)],\n",
+    "                   [(1, 0), (2, 1), (1, 2), (1, 3), (2, 4), (1, 5)],\n",
+    "                   [(2, 0), (2, 1), (3, 2), (2, 3), (3, 4), (2, 5)]],\n",
+    "            \"down\": [[(1, 0), (0, 1), (0, 2), (1, 3), (0, 4), (1, 5)],\n",
+    "                     [(2, 0), (1, 1), (2, 2), (2, 3), (1, 4), (2, 5)],\n",
+    "                     [(3, 0), (3, 1), (2, 2), (3, 3), (2, 4), (3, 5)],\n",
+    "                     [(3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5)]],\n",
+    "        }\n",
+    "        # check which actions per state actually make sense, \n",
+    "        # i.e., we don't want to let our agent run into walls (this just wastes time)\n",
+    "        self.possible_actions_in_state = []\n",
+    "        for (i, j) in self.possible_states:\n",
+    "            acts = []\n",
+    "            for a in self.possible_actions:\n",
+    "                if self.transitions[a][i][j] != (i, j):\n",
+    "                    acts.append(a)\n",
+    "            self.possible_actions_in_state.append(acts)\n",
+    "        # get ready for the first episode\n",
+    "        self.episode = 0\n",
+    "        self.reset()\n",
+    "    \n",
+    "    def reset(self):\n",
+    "        # at the beginning of each episode, the agent always starts in the upper left corner\n",
+    "        self.current_state = (0, 0)\n",
+    "        self.episode += 1\n",
+    "    \n",
+    "    def step(self, action):\n",
+    "        \"\"\"\n",
+    "        This is the main function that runs in each time step.\n",
+    "        \n",
+    "        Inputs:\n",
+    "            - action [str]: action the agent took; must be one of self.possible_actions\n",
+    "        Returns:\n",
+    "            - reward [int]: immediate reward received for reaching the next state\n",
+    "            - next state [tuple(int, int)]: coordinates of the next state\n",
+    "            - done [bool]: whether the episode terminated and the environment was reset\n",
+    "        \"\"\"\n",
+    "        # see where this action leads us\n",
+    "        self.current_state = self.transitions[action][self.current_state[0]][self.current_state[1]]\n",
+    "        # get the reward for the new state\n",
+    "        reward = self.rewards[self.current_state[0]][self.current_state[1]]\n",
+    "        # check if the episode has ended\n",
+    "        if self.current_state in self.terminal_states:\n",
+    "            self.reset()\n",
+    "            done = True\n",
+    "        else:\n",
+    "            done = False\n",
+    "        return reward, self.current_state, done"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## (Tabular) Q-Learning with epsilon-greedy policy\n",
+    "\n",
+    "Learn the Q-table for this environment. The updates to `Q(s, a)` are made with a more efficient, iterative approach called Q-learning (somewhat similar to gradient descent, except that the target value changes in each iteration)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def epsilon_greedy_policy(Q, state, epsilon, env):\n",
+    "    \"\"\"\n",
+    "    Choose an action based on the epsilon-greedy strategy\n",
+    "    \n",
+    "    Inputs:\n",
+    "        - Q: current Q-table \n",
+    "        - state: current state\n",
+    "        - epsilon: current epsilon value (probability of choosing a random action)\n",
+    "        - env: simulation model that knows which actions are possible\n",
+    "    Returns:\n",
+    "        - action index (to be used to access env.possible_actions to pick an action)\n",
+    "    \"\"\"\n",
+    "    # exploitation: best action\n",
+    "    if random.uniform(0, 1) > epsilon:\n",
+    "        return np.argmax(Q[state])\n",
+    "    # exploration: random action\n",
+    "    else:\n",
+    "        return env.possible_actions.index(random.choice(env.possible_actions_in_state[state]))\n",
+    "    \n",
+    "def learn_Q(max_steps=25000, # number of sampling steps\n",
+    "            learning_rate=0.01, # learning rate for Q update\n",
+    "            gamma=0.99, # discounting rate for Q next state\n",
+    "            max_epsilon=1., # exploration probability at start\n",
+    "            min_epsilon=0.001, # minimum exploration probability \n",
+    "            decay_rate=0.01, # exponential decay rate for exploration prob\n",
+    "            seed=15):\n",
+    "    # set seed for reproducible results\n",
+    "    random.seed(seed)\n",
+    "    np.random.seed(seed)\n",
+    "    # initialize environment\n",
+    "    env = Environment()\n",
+    "    # initialize the Q-table of size (possible_states x possible_actions)\n",
+    "    Q = np.zeros((len(env.possible_states), len(env.possible_actions)))\n",
+    "    # reset exploration rate\n",
+    "    epsilon = 1.\n",
+    "    # we want to keep track of the cumulative rewards received in each episode\n",
+    "    cum_rewards = []\n",
+    "    total_reward = 0\n",
+    "    # actually learn Q\n",
+    "    for s in range(1, max_steps+1):\n",
+    "        if not s % 5000:\n",
+    "            print(\"Simulation step: %i\" % s, end=\"\\r\")\n",
+    "        # get the index of the current state (to index Q)\n",
+    "        state = env.possible_states.index(env.current_state)\n",
+    "        # select action based on policy\n",
+    "        action = epsilon_greedy_policy(Q, state, epsilon, env)\n",
+    "        # take the action (a) and observe the reward (r) and resulting state (s')\n",
+    "        reward, new_state, done = env.step(env.possible_actions[action])\n",
+    "        total_reward += reward\n",
+    "        if not done:\n",
+    "            # map new_state to index\n",
+    "            new_state = env.possible_states.index(new_state)\n",
+    "            # update Q(s,a) := Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)]\n",
+    "            Q[state, action] = Q[state, action] + learning_rate * (reward + gamma * np.max(Q[new_state]) - Q[state, action]) \n",
+    "        else:\n",
+    "            # we terminated, there is no new state to take into account when updating Q\n",
+    "            Q[state, action] = Q[state, action] + learning_rate * (reward - Q[state, action])\n",
+    "            # reduce epsilon (because we need less and less exploration over time)\n",
+    "            epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*env.episode)\n",
+    "            # save the return we got for this episode\n",
+    "            cum_rewards.append(total_reward)\n",
+    "            total_reward = 0\n",
+    "    # visualize what we have learned\n",
+    "    vis_Q(Q, env)\n",
+    "    # plot the cumulative rewards we got for each episode (--> how fast did we learn?)\n",
+    "    plt.figure(figsize=(15, 5))\n",
+    "    plt.plot(list(range(len(cum_rewards))), cum_rewards)\n",
+    "    plt.xlabel(\"episode\")\n",
+    "    plt.ylabel(\"cumulative reward\")\n",
+    "    plt.ylim(-100, 100)\n",
+    "    return Q, cum_rewards\n",
+    "\n",
+    "def vis_Q(Q, env):\n",
+    "    # see which state-action values we have learned\n",
+    "    plt.figure(figsize=(4, 7))\n",
+    "    plt.imshow(Q)\n",
+    "    plt.xticks(list(range(len(env.possible_actions))), env.possible_actions)\n",
+    "    plt.yticks(list(range(len(env.possible_states))), env.possible_states)\n",
+    "    plt.title(\"Q-Table\")\n",
+    "    plt.clim(-100, 100)\n",
+    "    plt.colorbar();\n",
+    "    # plot value of each state\n",
+    "    value = np.zeros((len(env.rewards), len(env.rewards[0])))\n",
+    "    for i in range(len(env.rewards)):\n",
+    "        for j in range(len(env.rewards[i])):\n",
+    "            if (i, j) in env.possible_states:\n",
+    "                value[i, j] = np.max(Q[env.possible_states.index((i, j))])\n",
+    "    plt.figure()\n",
+    "    plt.imshow(value)\n",
+    "    plt.xticks(list(range(value.shape[1])), list(range(1, value.shape[1]+1)))\n",
+    "    plt.yticks(list(range(value.shape[0])), list(range(1, value.shape[0]+1)))\n",
+    "    plt.title(\"value of states\")\n",
+    "    plt.colorbar();"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# learn Q with default parameters\n",
+    "# -> finds the best path quite quickly\n",
+    "Q, cum_rewards = learn_Q()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# much more exploration (and more steps to do it)\n",
+    "# -> also finds the second path\n",
+    "Q, cum_rewards = learn_Q(max_steps=250000, decay_rate=0.00001)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
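
The notebook above visualizes the learned Q-table and the per-state values, but it never executes the resulting policy. As an optional sanity check, one could roll out the greedy policy implied by the learned Q-table in the same environment. The sketch below is not part of the patch; it assumes the `Environment` class and the `Q` array returned by `learn_Q()` from the notebook, and the helper `greedy_rollout` is hypothetical:

```python
# Hypothetical helper (not in the notebook): follow the greedy policy implied
# by a learned Q-table and report the visited states and the collected return.
# Assumes the Environment class and a Q array from learn_Q() defined above.
import numpy as np

def greedy_rollout(Q, max_steps=20):
    env = Environment()
    visited = [env.current_state]  # the agent always starts in the upper left corner
    total_reward = 0
    for _ in range(max_steps):
        state_idx = env.possible_states.index(env.current_state)
        # always pick the action with the highest learned Q-value
        action = env.possible_actions[int(np.argmax(Q[state_idx]))]
        reward, next_state, done = env.step(action)
        total_reward += reward
        if done:
            # env.step() resets the environment on termination, so next_state
            # is already the start state again -> stop here
            break
        visited.append(next_state)
    return visited, total_reward

visited, ret = greedy_rollout(Q)
print("visited states:", visited)
print("return:", ret)
```

Calling `epsilon_greedy_policy` with `epsilon=0` makes essentially the same greedy choice; the helper above just wraps the rollout loop around it.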