From e0233faf88fdcfaf6a354f60e0b6bd08d60ece50 Mon Sep 17 00:00:00 2001 From: wassname Date: Thu, 6 Jun 2024 13:39:17 +0800 Subject: [PATCH] it learnt some acheivements, still takes a long time, and I might have made some params to small --- configs.yaml | 49 +++- dreamer.py | 14 +- envs/craftax_env.py | 4 +- image.png | Bin 0 -> 36554 bytes justfile | 2 +- nbs/02_torchinfo.ipynb | 553 +++++++++++++++++++++++++++++++++++++++++ networks.py | 8 +- poetry.lock | 13 +- pyproject.toml | 1 + 9 files changed, 628 insertions(+), 16 deletions(-) create mode 100644 image.png create mode 100644 nbs/02_torchinfo.ipynb diff --git a/configs.yaml b/configs.yaml index 03a46fa..2be2d69 100644 --- a/configs.yaml +++ b/configs.yaml @@ -62,7 +62,7 @@ defaults: initial: 'learned' # Training - batch_size: 256 + batch_size: 64 batch_length: 64 train_ratio: 512 pretrain: 100 @@ -136,18 +136,59 @@ craftax: action_repeat: 1 envs: 1 train_ratio: 512 - video_pred_log: false # FIXME + video_pred_log: false dyn_hidden: 1024 dyn_deter: 4096 units: 1024 - encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024, } - decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 5, mlp_units: 1024} + encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 4, mlp_units: 512, } + decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 4, mlp_units: 512} actor: {layers: 5, dist: 'onehot', std: 'none'} value: {layers: 5} reward_head: {layers: 5} cont_head: {layers: 5} imag_gradient: 'reinforce' +craftax_small: + task: craftax_Craftax-Symbolic-AutoReset-v1 + step: 1e6 + action_repeat: 1 + envs: 1 + train_ratio: 512 + video_pred_log: false + dyn_hidden: 512 + dyn_deter: 512 + units: 512 + encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256} + decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 3, mlp_units: 256} + actor: {layers: 3, dist: 'onehot', std: 'none'} + value: {layers: 3} + reward_head: {layers: 3} + cont_head: {layers: 3} + imag_gradient: 'reinforce' + batch_size: 128 + batch_length: 16 + + +craftax_smaller: + task: craftax_Craftax-Symbolic-AutoReset-v1 + step: 1e6 + action_repeat: 1 + envs: 1 + train_ratio: 256 + video_pred_log: false + dyn_hidden: 256 + dyn_deter: 1024 + units: 256 + encoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256, } + decoder: {cnn_keys: '$^', mlp_keys: "state", mlp_layers: 2, mlp_units: 256} + actor: {layers: 2, dist: 'onehot', std: 'none'} + value: {layers: 2} + reward_head: {layers: 2} + cont_head: {layers: 2} + imag_gradient: 'reinforce' + batch_size: 256 + batch_length: 16 + atari100k: steps: 4e5 envs: 1 diff --git a/dreamer.py b/dreamer.py index c8fa493..ed940e0 100644 --- a/dreamer.py +++ b/dreamer.py @@ -308,6 +308,7 @@ def main(config): agent.load_state_dict(checkpoint["agent_state_dict"]) tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"]) agent._should_pretrain._once = False + logger.warning(f"Loaded model from {logdir / 'latest.pt'}") # make sure eval will be executed once after config.steps with tqdm(total=config.steps + config.eval_every, unit='step') as pbar: @@ -356,13 +357,12 @@ def main(config): except Exception: pass - -if __name__ == "__main__": +def parse_args(argv=None): parser = argparse.ArgumentParser() parser.add_argument("--configs", nargs="+") - args, remaining = parser.parse_known_args() + args, remaining = parser.parse_known_args(argv[1:]) configs = yaml.safe_load( - (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text() + (pathlib.Path(argv[0]).parent / "configs.yaml").read_text() ) def recursive_update(base, update): @@ -380,4 +380,8 @@ if __name__ == "__main__": for key, value in sorted(defaults.items(), key=lambda x: x[0]): arg_type = tools.args_type(value) parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value)) - main(parser.parse_args(remaining)) + args = parser.parse_args(remaining) + return args + +if __name__ == "__main__": + main(parse_args()) diff --git a/envs/craftax_env.py b/envs/craftax_env.py index 32fc068..2ea73cf 100644 --- a/envs/craftax_env.py +++ b/envs/craftax_env.py @@ -224,6 +224,8 @@ class Craftax: def step(self, action): state, reward, done, info = self._env.step(action) + info2 = {k.replace('Ach','log_ach'):v for k,v in info.items()} + reward = np.float32(reward) obs = { "image": self.get_image(), @@ -231,7 +233,7 @@ class Craftax: "is_first": False, "is_last": done, "is_terminal": info["discount"] == 0, - **info, + **info2, } return obs, reward, done, info diff --git a/image.png b/image.png new file mode 100644 index 0000000000000000000000000000000000000000..a57bec32a42228884c7052bfa89f725f9b4fbcfe GIT binary patch literal 36554 zcmdSB^+S|x+bs+@5{ig`ARq$LjdU}JfV6Zs4BaK&DyejbbobCH(hS|*UDD0I2Jh#2 z_j`Z;z_plvUUJS~p^#E*?3C(~ZTZ!C!NWNotya6yHK6h086qke0SVI^0YrlE@FNA~ zVA=0-QiA{d{-3YN$-Ql3F>zY|x$Hk5N5vy_XvLeOO#SQg|2Hmgc_Q}<7Wa;JcXy zuX#$~^vZFqkT28szw`5O*&^z10%y7FUFIVbuK(|S!!`C5PHQ;$dHf8!M{@RlMso=j~z4OU* zW#C+~x)VF@>!B;LsB=EB;&C$cWnQ_W;C^UJE?9SKI16%jbGe@3Vh^L$s)Oh+{+ysi z*uZ@Vj#UYgJni208jVVupK3Fjg z{A(W`H1aX?TB$q+werdS>7{?>D3qK;D{~L=l;=df5fu3Rq3ddxx!DtT@Ala&a~*&F zGbvy)G?PpPo;>ZZ+6A#Be@Nj&K=Kr6>_nZhLkrh868WKN|JMMK19K6Z^pf=9%KwS< zs2Ct_O$(6!9nb%7T;B2|D;oWw=}{pq-@R@Oj@J{dDI?7094FWXZ?|Nz11K>64n*1u z(4cu}T8eebE9c823P)hHzO>9(H(S1OS$)CkcD5Ubxz$h)%)P*I{6O4%qgzh!N0Yxx z5sVDniVr3atV477d_LzK8=$Z?4Aq#V8#{2Z--rYD&#WL}T~u%y6jjr04+`l|;*Q5v zN1n%PzG#;k36a>P)47z0a!y>(xxbnSI;_pX}(fo{d~BnIoKEX zt}SiCPDnBxh(T{K>n61H&hCnFSGHXYo7zv?V4bnf>M#HSWJkG(UZ3$@piW=rDIR0x zxr(N3d(aYN?P^3H1d$G+IGiYlY^;h;g{ENm$8fUoeBO;i^7g(!=Qh7_K;tt9Y%u^Cd+)$QHXbj({HFKTogv+JpUaGZDEH&4A)?PE3PE$tsnK;S*JXURA zpMP%(WCA7lw-7nc>g_ii_fyh5Ye&V3^LwA!aVZ8k(yHgd}7f{iBm3{SNiZ`GS2#qmKL}?Sl@7ii4Tg|&8D6$oD zWCHa~vh`*xtJ%#9%S6);MIK8id{H-2uXD)Tnk=KB3NRNDZFq&0db?{J%BX3NHTZ@% zMxUd~YGl)&&6OeW&E_<;fO|VPgXQs$Y(A)AQl=&}Q!;6s)LYwpzmCyQo6hn1O*@_r z8w#;|>6F8WG>ox+SO4zMcdtt5xAtJX#MJwn4anx;Pe?CAH5*lujqgyl%f{VFj^LEc z&t=-sKP#P`RoN1H&$$*6?2o#s5 zS-Fw8aFab(&!*dga6I63H{yP=^!4oCIMt(yL9%=4GvCFcAEfIC={m5K1B!ybzH#lk zIf`wrv`e+^z#s3L)B8%)H>RSKtZ4X4D6JCYdl1h>+FMFbCc|+=>i9yoLN!+w0cR{`U*9NnzJfyqf*G9g7p{+N_PtHeZ=5 z*w#$0op_1FhWLPZ$pLgidwnY>H4N*8AZk>K_lS8xVYvjaS~9U~5mCahYt(d2*&17p-E@)Q-KfS8ac+j>?V0M5IeUk@%Z*n!e6AZhc7(!{X~=f>&}gOK5>u6rUtt ze5#k-F$;|3_-^x=;d#+|gnwR)==8OEmAvVe1Xp9_Qrn-3iz0>grt`*wDare4--1j* z_hIwxFv(xr)!g zl7v0R)6{hUM>=dTeI>SfCbcEARX-xb3v0OAE})dzH;ec0%{n9~jk0oE4y3Zs$@2* zQd@IBOOp6*3RXSJW3m~M=2Y`jcm$QRQ|TpFWyDi1EFm&Xx=?-g@2-M3gAuKPGYZy%d3fIE2nuEV_u#bU^obih9>CePPRt& zdyEWP*fPommdqg0r(2?bMrIIEnv4{;pd>9xJIKs$m1NUj7|;rfh>NB{6`{ zosc#1EdwV!q=El@wgUE1em=o*G~#2l2CG}Y?vvQ8;R9X#{B8xBmH4~SMaaYmY#FmiN4jzVb}{!DqK>Ur0KT3}n<&5D0g zjUt`#7RV~l3UwfTdCf#fg&vs=_=qAO)E@}cTWr~YrYh6bNL*#eUyf|%v*~H2M6U)7n@1Zu zo*xbCjUB~=jhrFmSf8qBu(5du`GMj1M&uKt@f#r1W+m5-On-)Xo1z7rC7Tj8#%Ryk z^)fhdSi{&S*di&V%L(JxDoHgYUqMTD{&W&^*6&p%&V6K1F|DGOR?Sk9X6!Dch-@8s zJBB!x`!l4;vTCZMXI1p##H-qBu1|({&(N6A@5|4aBvYEWm>QGkNlBt0WgQR4Yp z1iplTr#uX3(Wz(exHbgnwF=OG`+?)Y(-iSRI+95?P2ygiu$EfcC%@-p6FeK2hI_&} z+#8QX+L_w&^yTc<)>NduE;1*6Qy0?xngu$#HUYo#*6xClDV0>{$T?5x`-CaP9j zx<$L;+D&SBfg3?)LvqmHc}o3UjY`P3X)IaMWR@4|wMs!M$c@U=@w|JXBYA4g zcc&$LM^E+~hp>82Cv+BA#%NTVZ}+VC)=g+I_cjvvT*`@BClG0`z?mFg+axcXj|`)P zA}9-wipXQfr_E!SKTBjUThuIl8J`?IsXiMu{USrmb>Dn{-CT9{bELI0gsp~L0VOY& zk460y#!a-DE-aDWGZgd5*DJq!Btz`EdDrcZ&V6}Dc1(sv!=DZ(mv~VdBVF@7c8mw+ z6oH+aDHs(5>;^~0Rnul?Y*Je})CSZ{M9QHx4y6t`$X}t)gsuZH_zuDKSkw~xLVn0< z8T-GUGF2|k<)@lm+7QFh9X?vrc+ChYXgZ48KveQY^Zo5XqQQ6mD$DNlvlY|li2~mF zZ9Hs8*&gP0(;3-WAy4eDwsK~edGyt;Pku2_{yi<`%@(TB zUgO2SbMp+Yl+h56#~z)7%=II~K~wI)v#Bi-oapvFQuK}Rb!Q5fN?Xld35tT5KF_l$ z<3t*TY}Ii%LbbEY?v5?R{3s00`J!dlO>t)S8%^KTkix~Elo^D6_h?jaslm4y%6VVj%gD-V++W#SmBo2Fu{BXTn~&!qA*4Z(Y0loZjbWS@ z_57UV(^JF~b%xQKKn3K+@uiiukUbfeO#rED;cUQ?MB8uHu-fci+}7aIhw*#KlK%p z)L0qu=#i%S1_6`ewY^Q~v{_!bD@?pFzBQaoOcWRV=!eyoks$&f=_e9^SfVz@6v{_z z7Lw~lti?#@0yukkzqS|Iwk*$moi&w)qYZ4OPD|6qvOBKrwpWVROuG2eXVIcc#h1)N z2|g}``9B!mqyaDZ>AX-7oFA0?olz{H?KzUsTW@VT1xrIZN&FLn4y`K3>C(}04eQ1) zYtkK7iNCxe00}mEBzxKev_Do$>4+eMp!Ef@vhsuzA)Z> z6WPdiAvpI0n`47ZpiZ~3StqY^^Mj z|G-H{PO2x*0KJ{8+s(xv$b5rK?M9>u@<_HcE0DQhbeFRu0KO@XXV!baf4Lg(lj1`m z8M7KMQf%4`K)UtY^M&ty1})x*qIktFs2@u%5Bona1>i}|^KccY6LDH;DvPfr^k<4k z_5u9M9!3`Y+hfSwe!cG>~?>}?iDLQY#vZzQ9s@c-x#aHWF!Tg#Kj zYr-54!pxvPh{thzRG%Wd=I`(5Aop>dSI)3OB$`KX{R9i^^Hx@@1og|e4@(LGM2mf> z=2b>tVK--o_pjB02VX6@l>Yv}EG`oJ1U$tf(E3M6@s@=@m%jHTFN>58zHYvwqQBvr zlqkSvw|#gq0-_G;MPYMX?QvE#Db>GL))P=%Z2`hpjfX2g@%k`o)s4Q1kNbP05AqA3 zMS8RBxkEn~;7x{59o`%daJs#qA%1uc8!2MyDPAv2VyTck`CqZ&uW|%VI|bx0sCKe~ zKK$bV%s76xs-gerN*5nmjBne^`=S^uTYVp|SItjRbA#m982m+9Hj!q)yq{~IW=Y4QQkNk|l z@b=$|CV>dT3M0j(042ezVyVn2~!TzoAT z0b>`HloTI;+IDWboDIQJ1sD2~xJNMT0_Ff9XV$n~M9i6nE0ek~g{2Cpw%i*@_I<89 zMaE5l0%M%`8BKxEqr+uAeuFChxc~LsxSH41W#IZMYeyUdC!t|Sl8YU0UdD3e)048! z!$)Bi^XKz6Ww(TP41)MiVznHDrwnp_!g)N0rh2SmO63-9L!|X6X)32!*eoYYpoAXn zT(A_q@j&H~|4v_fY+Nfxi>n9_!C2_tyRxaFGd&Lb$U94IQYj}jlY3phOF5^C{vBzqjsi}ya+NXnM z8L1l@>}~c*8?5UJ)6R2_vGD>Pb!xy);DJS`_E#W;3#VV!)8Z2VIz9oXjxUY(q8t{Q zJj}aD)uQqgatL4ztMOti;$~0RAdT+PI$0^~<}qsaF!srbYMLMOO^#MOSb3LNNR#X7C2++fw09D!X-A}W-av~Vhj};*dl-I7vw-QX);Ep zESimB<$BD5MOr_bI9>Jq`Hb&^MqjL|gfuzgugw=#8D~Z5y4A%|JIC5pCUq0@aH@y- ziza^*%AKws1(|j2f6*Q`5KKyiae&kj)R+^04_rNXC+OTtB{}9e zaQPVfTM%|)YbWO~Q;=7EC63_hV)+EQBa`)pakk~7?aLFY3Nc1Yd^YvGVW{~;6}7dN zOZ}}NtrwnPn2~1F>6Q7R&?DR>4gyBmVfWW6cshT4Eh$ZTmTZpKVHMX!r@H}K>P|va z%b!)uXCg?$!qMvsXH;tXu9D9#pX9t46K&HQzU$AJHOp|}38{~AXg)2n#{Xv*&Xw|B zi;KRDTHBv2GuSI%{7URO`JhGFye2wibv@bq_(DZhFJL$oli|`RQ<GZe^yb0Z=pkp{2im`>)uQwDMFdc zo|>nu&tj=e>llfEUi^9;k^K@qXDX44!*uGgQLto%`b>;P^s4}uart#K+hGBnrIYk& z6E3OY1sSh3xqV#^ifAfnu8?}RO4O$h&?69y@6T`Sz&5N>TTpT1*>qK3Nz;`XtKKhv z)wW_=dE7`=x8$HDDaKH;G;Dp_F*G2Z-0-Coxh7*G{gbGCqCjkkpDYKoddSf*9-jFa zOLx5(JZbnC5%oeAynq;CPc^r7E$cqWQK!A&d2Q*I;e0SZgm}n6F`&BcPBA9 ztQM2ZJFTYiOUiB5fQa1(D!-8uiSG@x`13d6{yFU-0zMBw`;G6AfpNL|*%04u8-Y4W zRKihj6Kdo4bMv2UhlSZvLnlN*f&?P@gn`5|YFHW&3k5x^OUqQ(UW-u$E-q>FyuE^+ z|3(3N?Y#l&IUD*6A*itaB(TF2QolR%h~_;4(wi3_%?I0GJ`*nC$0sWK?gv&w^GC5g ziiiwz);0{5j&-FEbH)!ben=*^meEN-T8+T$>PZWxlOYghVtjT3^z#U-EB#n4%2rJ~ z=p!;B>Z%ZsyPoG~q9LM2g@NahwKRPqU%6_a-1EYC(-e1cl{H4tQNi@WYIm=0rx^Efj!_2` z^@eot1{C-6q$KR}Cg2^sSG@Nr+~fd8&By#9hj(WoH2afUDWYeuiuEML^UAcMaz$eO zU}C;$HZSq2(l_YpO$iww*xTMztBP&?~1e8cf(%SmaO3moL1O$8KOP}jP{FM4qmmrCY1n@Nhbi7 znSTX6M+LW_AgKq@tjL7n*V~OxQYO0FA(?Pz0d-imFofitjxs7^RC4$<>#sA4Kf^`_ zKq$b#-qd4%zfF(L(WgY*7FbI=x>_G~X@u)M0e;|kw2OukXrE^IGb8~x;MAMFoIUzM zgv9WX#AGu0w9PtB=+U^P(33UBa`gJhfma6ZY!bcFxLZKu)zL|rmI0I&@B%e8sP>YQ@MVlM8OwA~%a3w|f4l$yf=Vp1PJYUXeYDv2Jpm{Ra3T@8 z0AT{3XfsNZxwQgdg%;FK2LgkDBTiy5Su6V(hemkFrIsaT>+>c#deS@B`#qIJAeoQ{ zzJK+XW`!g4ejjO#-ak?Tnx9o*{vK$=Gm^oeFUf7Y0+9U?EU|@=}{@RdN}XTIMZO_zG165*~m_`qUG~q@dx9A zj<&mYE=>X`*3?E3O~Gm5f3dlRVJ|7LH<833e3_#%ai4MCfb6?7e=r~<=KP1BrC_zx z99J%55+cE*5ZbIRryZ~N-6yM;A$*2W0=Fk4XF!1>8qMa12;GF|JEH;f?Aljbz45^V zSHVW{YZeM0AN#~mc-L7KH#U=EsN$$Jvp1& zn838Zv)>TG7XFA@iHwNO+q@HxQLu3O_2eP^76kzgdcX4n3!FxSsyo@Pk}=clG^JGi z8-=zt`Ovf8vU+ZO^A{W^9|n3KJW^`v)%jAj{IJsqKYfpkTFHfoUcdS*t?zo@_j*q; z-MqNALz@eg?J43J$C^TRoXo32cm&5*`L%Q*_&Ze|imQ z#_{GQjpOuc<-e2c`?ZU<1?;ft7};5UmZqZGV-~(g)_{b-|v;} zqR|oO6g(xwQo*+OV%kI>}G)Y7UP$LLCpd)2Xrf zDNUMWoMd6?_!> zA`eh$)=wEg*_RVQtEsS3y|!PcvuADq9fJU@^OaMXe+{CKA&7Z5!q)Dm=GXMy>;0sp zi0eV&=TJPjo(T>Ibu_bakL>z>u^2q;Wy8iQ2n5?GL$g8pZ~5&8Lg%-cAp82X?EP-* zeerO21*D|$jHSYOfCGid;RimytAnNutL1d%q}%EIVEN$wT>VCZSR;#aP_x69JedKv ztZ4Mn!1DY92espM-fz&fW0G-&-7Is)8ik8k5TJxjwr}Ql!s1zTNF8l%EuX*Avdc+k z>nUvP$Y+QCrH%sCr2LN#8jfXVSPp8gaL6`#6tx?lpK3c~WXL(Hl{5f#qrf&_L6C$n z2~0}a)L#@6mFz#1^c( zGW>r&R&=-Av2PH!3nuAIc?CE243!z22H+&3Z(DcBCRkh&l@9c0Q75dX9vvYStU?Z0 z%wE&=J_M-kl}oZhqwIhAqm^}N0Yp$?{RiL0Xr4kSMEj_Ztd?y{xK52pXH6v#$_cQ* z2{x9i5BpKg+-+!H2S}xg?rkNC_29-cV@To)n zfd9ym;-evuz_BK6FYwIvHP{LCPJb?oN~%WOKrJNpyQxyy#E3mA`K1*SojOAajgc{4 zt76f-$Uq0pvUo4fXB4gtCwSqD*~#zAdx_=0SLViULAFOJ692s{ES2uM8ieX0WJJz zeCYMN=5x;oqKy16#_NeI0_gRb$=nh$QGGfvO}_NIP4#?=an?6bS9Gyw=mi_PpcygY zcW9kSiI)YSajgx13$vO+2cF?h>|ajPmwW5PBf_;$zU033g9kK>-NN;8Amr5=)eljX zD44FS1n2YGiS3Hw>Z_+KjHbP$uR&OtHcUggVOGAIjwm$2K40}zoRvB#7iDx|!tw18 z^`(yAu?wZn8OUsmjVJOMPf>i7DSu?!&-Em2Gq{m)Rrh$qF=tV=>voV1xpC|SXVWde zHC%Gif`xWrQWN`#5`KQ=)D8_4b%6M0*f@X^YeBDJA++&L8q}PID<&)Z zKs2$|J&0!Rl%Havu>@F8q|X9ZQqSR2dpSP+qq-6_`%NkQYo4O7TqV47c^nDX^OSy9 zVz`HK24U~uH-mRx{QW)-h78oT79<54Gk)28k@SO zrLEL_KtPtbI%mWKM_{0VHfs-O58}sJh!f)WmjwWT@B^UtkPI4*+<<1W*qs+e)FhIs z;qP{r=5(DZk=Pa01)L;}st9hoWncfU(?KbKmNgDR^fhA2EgH7gAKqGC`F7PC90odK zIb`0oE3BjyukQ~aRLsTVJDy=&2;~gn+-GdM&lHbOz z88Pr~U5YFYl7*ctg2D*eNR>S30eZXkB zK~9IJmE(LUy3z;*2~GFfs2KIvKj)PZmPA2T5g{gHM%JmeM@BZ zsZmQys-UXwl}kH7^uRfB6xZGl2Jx0dpqLO)NqCNOplL(uXReC;@PR z1_hh*AsTB=u=1$M(&R(WJXLS zGX@^Wo~|6Rw$+pu=PX7~9l8@`Yo=Mw^A-n`5RTBH#yU=GWQL0bDEab{Q5j1h_x;_D z_wiYI1WF;iFTah>tk7_>5qcymdCNxFy+mVF0@YuWtXn8n&FgxGG=<%j-=@A= zzO_r^9-*&He-u$7wS~(+sGMCMs}`ed)QW6+Z+2da@8AY%)<8=VV=B*H)AE9YtO_$0bmAzG^viil~h}) zc~sv+OKc}Z64Y{sLj1;jQZ#@mlDLr$x}{dxEq*EhRhyBUwNaq9l8s z^*hD&{eKr;mmdOaj>{)YsRSs&v>M@Lv2^44nc>6EYcS1R%Yo(-tL@es<8KrqE&nqG zw+)uLHZ>$(zaT^147VTwdj_B`y+KlQUfc-}sQClxk(X85{Pnb-_+ixsTW8Grz!$+| zL?F6l+Y_9clkpZYvETqR61z*C$xx4ro%`qu-t(2>B_(eJ0qoCPW3rI_AimLG#@bye z#K3K=*6oRq=DDtLvQWm?Aa{}@-VyS zNH0-qVlUPctrSUds!2YfBR*r!wa@o>>VnRZIcJ?Cp9&=EcAOMVl0_Os=Lm99As54o|7%iHi6w34t?!AR#jKt@vNsSX&4?B+%0WW&&0iio_HNdr!w0T_3ENj$P2V$SP!lqrfXhDMRf*t9YLMvT1`$ zc^MZ%yt;{8MB84X7<1TY!_z~jsXx4FKZ@t~P(eHmI+)?dCE6*tJKw~UzyWAC(;Jx^?1Xa=(jq&}X+$iPJ*W!07 z>b6_mrI7_V9cxR5xEp2%=Qf_E+oG3`LzfO&v165)W$^^*#SWb>wBxcQh4)3}y#r>~Y>pmg`1gaYwAeNunKhsFTwHPR)f z%@bz6x9Zkb#YHMRpGFg({~;9Cg?m>=bLBXW0Z}(TT19)RFOjq4-7hwmVm{$Lx=DOv zeTAa7d4Z}jqWj86vT|Fa+bY3*&~+un2G<$W#y?8^x&-3LctQFr!UBLb_>-h8|7le1 z*|a$O$^^aLm%G9D8tSj8R)(`ASFA*?RGv25 zEd8qbtea?pm zQtxcdYCc~04fFFuJry0;#bkPAPP}mY8HVyH8dGmpsnouK2_CG|43_6JHLyotxkjSgB{ zkpYHdzrL{AxaW^qW?f5UY?u}JM~2<(9X;)dS6=V(FZhgn{OJ}n^C3F-%t;+TW7@ly zN?bj9nJQdGPm{p*vi*`uT#&K&$l+X|`n(CUh`f^tNhbHgcLoxRRYV$lKx>U;pfp)h zomKO#ll@N79_ROLsj$TqESui>lf@Od>~|ppPF}F~*wI*r7u6eGPxDn@<4b#7pB25z z9^%yby!3d6h8J)!&wNm;EOd>{F`p>T`%r?bUAT&41(=#4v&SXStc7N;`|r1@C2^_2 zD(^mRX7`S`eDzY1q@#7$JHfoT z%v0xTICgqYa^+{6fX(HLLHb$1Qef9vOI1z{T6GNwR&;X9I=Fuv>rgo~jIP2}!_@tBP%nPUrK zEAJMSOD#fqZXTz;JNE+D+L{Zv!ZG6-Rd9lCg8Pw?%MB|U*)x``eNR&8bJ=il-B!h-vKm4!UbIc8)8pK^vVPyOs|wd8Ns?`AoguF6TOw)a;X!F&86 z^{O9=+{gy9=+s3s;$!)ayDHWz*HtUDn|hcy5;3M*3TxLBvL#~~UDA6PEMcGX?Dx%y z`jUAkteno2Je79gp4fLMK&t7 zDOeU@mQHA#Wcy_(pdb{C&s(pS%Bb4GLRnO#%+?Z7P70hR0E&CuJJIrMAJbAs9@*mr zXOilo)vRRn4&w8kO&a9AMExw1hCUFg(xUGldKIQBH8|VQOCoF)K(;iC=Dy_Z*QegN zjKW}#mllT3yua-&Habq;GMqAA=$tCC$HUIERUV_js5HL#)y(`BG6ZDwZEOxhOFJaB z0m!9)L3HoF0zbQ7A&f>OVSOy1K&|1&r}mG@ zwPJdz7USUz7&XEGHUh>+Eln*(hkB#yAr*$9V$*aUH0yWMrVw{w@!+P57i?u2@?|f`0vOyaP()6zc`Ws5jHP=xrLCXZXCZuBqwlpvC_Eq zbrB|;rY5UUXavp~pLROH=b46Lk!!1=y|x*Q?VC2YQ-`SSpwyE5!1tj)BQ>H#M-w8U zLQ;bsE68wDP@#$AB5k4Km>z;C>@UZ8T6~A1ic`?$_&~Ew{QQqjG1$i6uYHpfo8V(s zm-JNY&{KoNu!%pM#1TiW8PQ~_-;Rm+=>H1SL;&fPPH54HdgywGM6$rSF0c70Sj06i zaEKb9OQCONo3ycPm?o9wsI`@`-dB_T$Y%Yq=2`!gsbTOh6#quHm?3Z~x6XBUA%Vjfol-ceH$z z-h+A63>vPSC_JCBO?v_>wE&pbU`$a`*CU^|Xlz%tJ!D=OiZtbqfnE`OB>9f_@lmkw zI0Lwafz0Z61fxK}Akwptx9`)Q_^>^aQi@pS06igV$-#UgEG>#@f{uuxV2bW+SCs0# zqpGB#oOJ40e?>v{7}ZcLO>TkO@v>Znx4y2Ahhm%3Jl!)#MOx+R6Ay;spPw;5%9cP} zQdw}ab|XsCI6hemMcf3r<9SxIyyw3u?RZf{+g3*=x2iyt_Mi zy~{Q&NN^8bobY3soTMv(s$sd+d!*1w=L!E;G!e!i&H4d5K=2H@=<~g-Lx}Jf z{1+`Xy{P2fhO%k&L~=aHQxgm{xRNa%Eqqo?LfY4R{Q}n(aK#9(v}{d~2d;4J?4E(XiY(G6&k zo^qi4PzzR8IYsFovO}O*wN`rjp_zC+geP zab95-m^zerp20Zu@2} z@fRexplQ^P6Pw(FK6I`efGb6coM?;G^5vkhYd&`oS0KTU5JFYCJHVXEKC+N%>6o0< z%VaY8_}^!gVPI4WJ&6UJc9_q>79iHMl|P-5r*zN39^_v`WI~amXFR@6q_9 z*d*bYW=9(`1HncoUMCnECl`+L!YSYqRH14W$c&`BkgDN$eVX(;OAF!^QqM9W^R#FP>$`BH%r4Ks^2^#M3Ora zy=>naTTHpOX^=SorB`vmfWLx%MnC*rg3N+kG(dcXs!e21pl47jGVB-fPMtch<3LPr zhWth0j7d$uuM2N&K7RXyfg{^;%6}jUoHh0691Qs*FG2U2L4wEfjsI8h5GaW>=ZuIm zHzHa}?2Sw2gQP4#idp@Hyk9U#BS{YDWfi)oX^%3p3n7EB+^J-vVzKGNa$BogvqA*< zN)n`(+oKfJ*vRP}o%EV3lgY-56C&AyG&@Q8{9fR_@p{nByuoUygH#3D*sq+&@L#+& zMzZRb%A#=#{B*Lbg-hH7!zbz%PJZX%rT(C)ehRQh!hln*FQ#J7W~>D(PRBJ8;fG+fWr34u zhFp9F*q>dFa0ovJaaz&5AKav~JL`9R@|i}tMC|>AFbUr`tg$UvKAVcROV(de9l^%} zZS``$Us+yI;TXP`N~FQ&0CNNc@FlJn*YC5y?YmPzUEct{A~Uj2ds77rO8P=AOh~qV!Jg4brf~Pt ztF}I^w3mFxeXLI7#X$Hxl(SbtIXPB3zv~ny!-+AyYAHe+V@A$Kbhm<~#2Ah_8HRGV zTerF>syH%I9$Jk@VIEm_hz)=S|BfF*?+&3ziaY5dC)~-)>g-d+>AbHA_v&&G!zMid z(%B2`hjR(wRA(YU#C!Wv_-9@_3<^O`Hf5XGvByGkq!V%1k^^ zZP%==MOD=w^^grXSl8SNCEnb*9T@h8cVs`F60L!{ zlgO9z+4n(@&Zed|he<>m9S^SEa=XqXv@uN_k;as_DtS_s83c_94K%Gg7m!r%Cx>JQ2d(gD$K{S&W!j^zEumbS}n?Yg=8BCNn0$Ns=8tbdWIu zH}VQpk^ULCsDCi>8pQ!t;i0~$7(YC~eMx!8?|Q<_?{;QEb$8x$&cAaGKRrIj&qwkJ zhLx=P3Dc!TN1y7gR|_C7pbQzw!-aOwYYDDsdToPUnVEew3z~@10ldEMW~(0 zRMaf|9734Dl?fqa=~n>!7%XJ5#V4S5+wIH#Xz#P@99B=lTNv*`Etw!Fh};G^z5gF$ zsP|uc{wrsn7!Lj4sBqWoHNHQ*#F+$G(Vf<<02R*ZGHVuaC+xCF?bVK$CJ2AUV6We= zE0EoT;*{65Nu15P%p|OKUne@h&X!kzVxqyN?733fQ`8kTh55##*P~ zoo)S-gNc&VT0lpcxpDNC$P!kZrtt2x1NrtEa%?X!y(CIb*h1%p-L2V%xMh8x&&G;r z!8EZ&GUx~~wMD{d2GTQFi)`)i)Xi|Xgi3ILb+)T2qk+BV(-)^%6AY1!1Gv+ycys){ zX9q(!-_4?jS-RZlyM1;vF~D4x7TN+xG1fqm9yCGfYhtfM?h*_`uw)IBIvp@!N9Z5s+G>@BT3i=w1KCYJi5q=|lezKrG~#DVHms;?S3>Y~*b{ zl0+*@xYUjcQx5H;jkQH`LTy#pY>z5H+a9?=T&h_wn79&9KnZ~el9P#f9Jh46QO@%N z9~LfD?8xU0b-hT~XnfMqg&wxbNRGR5++J9a5mbfo`-%s*R90aTO07GY4Y0ew4^;p* z#dMyO&}F!vjh%eQ6~&+MtwU%$OQ|MVO-RNi+>Od&g-(tq+rrR2Dp{xfcr&D}kQ3iN zPQv}x7H$j$vtpZ2JW7w>DcdzFsJ>S?J0jItL9jj|Nje`7SjndW9lBiRF$YsY&Y`-u(GbKj^nd*Mz>;82NC zqGj3R=ZwyM61I;=y3LfCFP=5&qlFXxrI`Tard*uyBj{25Dd&Nz-AK7~lZQ(^)Fd3>;@|_Qi!>TGP{J;gGNHFxa3md5$?kZ;9 zkCE>df0-%r*lKmrpQ%eD`UOh<#qfLn zm=SzwH6mB6ukSBMBfqk~Um5NMjuL(pTU-uilEL#C1X?EzIwIgJ5`Xnph(ruY1YKg& z<*JplDb9fANID_D-u<%5eOt|Qh7kjzywE_`gBnVQa))Wcf%H#d8Qx}(jYy+-m(Y2X zXS5*&7XIJ+j)N z|EImLj*6=L-vtCk5J3Tz5)_dR=?;F#bNl#qB*oMy8Djh zcXw?S%kXJpULAVw$2mdt@LSyQPcMT&tp^q>Ji<%0x;KZ&$a0(8`7R`cW)x?}6u}@e z2L}6KYVmNb%&o)O4V<{S;&rnMmi3=}H&s4R1`45@Q@;kn`(a4TljYkaQ&?~GGNyc9 zQ=;GY105M+?>Q4u7!A5Pg<{Kj4RbpCq7}xR+_gbLXrvbTEpI8_NsJ@~Llsp^IC&QR548b6C7*?rYHCg& z4M>gs7*DqkSHHPF$lw&2r&Xzs9IeHQo@wIi3ujif+{N6QR3SIcLGiH((&j=j{vA+G zLkXnOD7}X(RE7$8z zCkW^%6DJ2W)Qr)@yWxh>Z@jPf8)LtnPSXNHCbL$~5=}8MTm^1lU;U)w*aBY;d3=tS z720%5g46&45s+M5!As%O(a-?@$KiE-i=pKR?YNa_e-|H|ClC7U(eJzODM1X5laV1k*=eues37<3L`7cBrVWqMprGN>a_WM!zrCRYD4VKPNPwdJ+o#~vg|8h zCZSU%dGUPe_4q@|p%pWpPu!oznAAy zXyd5YR*+b`4kMbj1L7A+v|K?0+|N>ux4IONb)j%Cdw+#X^E<=wa`TUP8#2Kqa-9M48zh9iM`F`~U8gU5tBMRR%$;lZnz!e6 zaNdPZ#1xdx?W~u)cxNuf>?FjLsQMudb8RP{(<%-)VQB??8uq=4vzOJU-Jt@=1A@w#_NUg4mXS7i?xQbhSWN8mS zthyH=7|UI6Ro%*~+Vkni?`Qqq{$2iPF|9E$;@|>+EGrazlDw)bCKB>+OVWzw zTu;6)>2jQ+P(x`JE3GZUNOLbY09G`x%D~Fz?vv|6u5S*D2r2P5-PmuqI=+`ahuqcm8Z-ZRpY zPP-3dvg)UFmuCx4lFByl)j$b=mYYf zM-!C3y=UE;GX(I#TfFO6y@@gT0m{kI@=%%x(|$S3`nn#Gq4-rOmD0NPiQDFdx)fp3 z%{JZHAMtFWm5)@|o*QYUJoN2h?2L3%vbFgnZC4>74P+>fsQdD@xuCLC1t#&)&7K86 zl?jl<$N|=1^Uw6-Gay9&bXVW4e+`JWvlm_F6M5iYFAft{fRT#fDO>M_;o}<5i*v$F z`uQq$pQphjTs?qK$6lTwMOn4;cJbr+$)v;KxMBF~0ohN+By{D$X{=N@OPbWP%LDeV z8y!Ezb;p4yKKIA@2YQ8~B%i4i1j9W}oKo}sj}jzuuydMm>jtL;=d=73tI9owj+N@+ z4qqD1S$iGg5EUJ5nQ(RPb@Rqy@m>41>dpvj>~!5|U)66vj^MhVFdH8zl6pfd>*a4z z@lBkQmM?Kt`mTo|kdI83GY%M8?I6al_2tM_+_(D~Bto%ln zR=s+TPzRPbKc*!xQ*eGH6Gw!9@5zg|&WP;HE--V<4pkCC!0Ol2W2i$K4z}aC1g~RD z=R-bs2>YQW_cA`cuPBYa?o*o{Yy`rTw?o*%wqowPOf*ilPiIjfii8;E@}#2>Fc8vC zD15{0Omu{`s+$5`=(L4zUSkg| zd$G^Sw1w@oKHS`S#MX0(jI%P6=4>tIZd_u;OxNSto%}@-diksFhZ5KD*yn16%3ouC z-(GhWHaIGmb%CU<>yNDbRL3Z!|xx`B9xm@ zR?2^Ix+^~OrhW25KO_tt-&-EeWaMkt!)*jQjd;jl!hU6dYV+Yjcmwmyl>quWV~-iP zT{+0^tBNB1aI~wqfdsy~wwYVlVBWYm45-hJd8^J96IA{{`_gBNKILu7sIf(dap$g1 z1A1DFyTe+wjFz1n?y_f-+hl1Q*QPS3BT^cPRO=9<-E&CK4p7Ga)}`Blm~|DtxDtKH zC9bS4H~foNQ*5Soa8WPJa-bj}7zHm2Fppwv0clyg+YIj(ni18nhE$mPQ5YynWF>+* z4k(#M^}7Pe3SLo1la5h2LyX(6{b@U@Z!>O)BnzU!DuR#^rS}`O8?4_;%@)Xz6M|&i zmcM7Rk}CG5!*43$^L3jj3f@0+jtuqclgS5Pub|1vdBxUd6#R6ItNS7i9T`FZHTdPj=0hySW;AcS?lwKPdQ*K{?cCU0CNjPH`VC_CAD3kh zaQSWzks6+!%xHBqEu%5B#<&})PfU_@&^`d6@~M5!AXscGl0gkj{>#pU z;u%vsRO1a|^q(=7PuASBX&2rKA7MgXS&R5sACSL+g7yGygYQSrBqZ6E*bM=lqJH38yqwkoKBLTm*%Bbi%UxU%`Vfil) z!mzzZ^*se*?4<~xH7FPFdDxdH9P7x7d}5UmS^WYK@?dya|2Q<7H^1>FouLLoyK5?{ z#w7s%W^)1e$O5{z`UA-Btx{Tv_EJ~Re(O>sj?d8o3-mG2py!zb%Wf3r-s))KfYk=H z1bYeNpocRI%ihlL&(nS!fE4gB4$7ut2N>M##?ZgS2md_N;wbYymBgKjdN*os7=#mm zXDXx#00mHDfE-&2dNi#b>9+JmaPnHl>Wtl`_AG1sxhySyUhmL<0*xhZx;S8rZx;M~>6U+vJy4Ejxa4X&NtBJvF&b>w252BSJ^rSa)ABJq+ zUq;1vrtFv3m#VGcee`mS5HmgxoXB9V1~ZLt6Cc{K-bj|H;H$ZY_;goW8&`)ZcfSuZ zsgg9|Y%~#_d_9el<2j2C3qRbuOXKTR89&7w;tFip!JBb~WFOV$HkwAmG}b-RtDL6;BWQbsn0W>=wDXw;GBh%_-jNn%#xGwcS5&I)OA;U!L0BjUn>awaIG-V9)!4uz`t0^$G@iYM?Rg#)OF z>s8&hKQJpyCE~RFM8y62c`c%PE`Yu(PpecGfUIlvkdYbm@aS^$Lo>dN)qCl)XG2t= z!Z`9&np7?^2jNuLqs5qRAg3@C4<GNb5(jFqy zD%Q!2L?-C@?ow~1?{n-s45oKqjOsd-d%!mILRW!4MWFw~rKB5AQtbKGudke^Z?EGk zx6+ILissxf)B%Qsloj`n57sSW^c?~Cny}y{62&_Wo==XeRJn-xr zg;*sY_p|>=gK;`)Z6nHZ9?rP>U**+8a~?rrc!6r;WnW9$s(RN4s5)=o&kceH{$PPHH%{+4K+zC$MXfBB%Fr;Z81&+-JdQQ)4>rqR1t?7>9|u z8Q@Sg2!fEYH1~qfV^3}V*!cn(Fr(F!-|EZ#VU5bp)S+7+TMMXSVFxrz(Jv9TcUg1T zb#vv5qaWrMw-L{#A&sPy7q8a=BS*$w6cps8{35?x_mb<)7PQ0@=u}U}?W+RB2g^@w zL9T7wmn>pjKCE8Kh7q?IVUx*CT;7IU>YH}Z-4cV-?*DmgE&(nyD~+p((Gq@Kn--rF zZ}2UZF{u0{ivL+b-P&O_|3AjB!%^aF&lxw# zWU7wlhV@0uBH9)j6G`Fcz`YbQlY~QYaNu3$3XSn8NVd-TF$_XaQP7EHo35QL;;odL z>()AQfOO{ZJ$ge>vbb8-txc2;=-DKh`!;>7c!m(=#hgY?__CBS~SkrNU$LY`=ohOQtt zvrJqo{(c`$GV|L%7Fg`mQy-~?lvJH;HqJnipL!42T&LwksvB4}1U`$VqA ze!!=OU9grtAd-h{R*n!FjphY_X3809vo6>p6nc=B{s~y~GRW*deD@B$?2$PTHy$Jw zKmUrRjkZ3X!0&8JD~*3>Z4vp}Z2$TFUiY{X>R~5oJ7BD27T&{<$Lx~dbxbd>(7H16 za9_3R%W|yX{*X#HkPXA5&g0ljPTQ*J(d{Z?q0JEvoWD3S6Kn%`z&4Zrbq`p^ zn8ZtoXG^sey02SQq@L^-dQKDD+d>e~I_8gyKe=LDyK7uc(!;YwX3r;IH$~)QSzAHL zv28hCp>oEf)*00Exw)AI3BUpt4xl7XAT&O( z_1CxlHWXH5s7mx`3D^%Aod5nVG@;t!2{^(d`c|^c=OtATO^F{QqK=NTSpC&ij^s<9 z9YCV2IS$#hIKG~X0fxx7u*NLqfYSGgll%Zuukb&*U3{T9=#o5dS z<5F*^EuN-Xdequ1ip#6cG5QGl$=Yr%fo0!@U9rL_Ezj1KZ0>1TWD-`R48+8RNN};-h8tq}l!`U^$yPgfK-7t}H?3tMcsJXN$-9g^lMmJiq{?8;EJGPKXdP?NXJw zGnudFZad0*CVCvNz=0pKWi8_w{aLnAI^=%BzXFdu0Df&o!2FTYbuL7HW*it2O&T-W zc9g&kfp0W(oUY?|Ya2Pcsp8~XA6qZ~OOv_`LrV#iKfVC=FBBnbwKkGhmTjTrjXHt| z?&vSg=fLbY7Wh52x-HV{n?-C^FD0>kJ|4z9;GRC>Zky8+x3Y6mCt$#Vllelmo8m1j5wV1o}{1->vyVu%jEYl*UHu^#VmNkC2 zE@2{UA$o8vQEx{?RQf#FTy*BB$Ej=FJZ}6(3JNbH+{3m*;=J{S z9IKR~s-01fViNQs3vjX{AYWbPL(<4}9E|Z+o6ub+n^JRy*~1o)8+9>zY-5)~L|=_o z;<7%5!il;9K<^*tq!c^If92i@1vQ**ASRo<@hVa0j>B;buRG+BZS@<_CHk3$qgf*x zcScV{R?M7ZHvDucKYuP_IJH&~?$vt@muIU6kXrfG`WuZYy5%O>rp}T2`7SKEOM+&x z7Q;?ZESosohZ21C&5ons3J3YVMuK(zf#u?~p0B%jd2!jXOfJrzyfQhASMs9 z4bjVqnR`))TSs}0uj{+cw5kjsk+5@F=&zfphFWk6ZRNm3z9B>f^aQvVIDS^DoO`XsN3 zNY}zu&+czIsiQRW{HF|lGghtm$B(|3xv+F1CuYBsv^Lu<-mz{n$0+tjSo2ZJsN1l< zO%2DXCB})BSlMWq@0r0_Iw8gcd1gL3FSE$iDG*+~n!GXWpI*68h$_2Mm^>()VVs34 z`1_Y?H)j>Wu7c;W%H-{C)el2>_vQ<(ng)lIfWpkgM|%O;xTt;uIxz;ntn{f~p^Lid z&#CTwz}Ok4eH(T)qgi!`cJxYS#7(Ygdlgsxc-a2)l?OvVKa1CZfarTiwy28zy_TP@ zn^t_akm=ua#De_&p%0o9Cp4Ko1)TlnuRd|2at)Z7MgFc4WiY&Fwwa1}TFS!fpeAo* zzq5K6)38Tnkzn4Y&`C(VIHWXWDXe%78Hfc?I;%0;#zjTi%<7#RnTv`zLd(0pO{t|q z0sj3++6_;ei*^N%7hIpNP5OVcs4h;@C99ZAAZQfdP`l<5U9KaR=yt?MsbEvKpy8{8 zNX7^jqW$_yHtB`{5hgLgx2Qlbgjh@vPxsYYEnXac^Xr9I!+(OhfXwK5zj|SurCz99P zfx&9AtM6i&r_Ny-C5RelDg@tcWRvG;9YVbVtgdL;9PxT)HV)mjmwA$oJ$<{o4p)T1 zL&S~a;Z&ituy`{>#lWF(h}@%iti`FxF-wZJER0Pg49i%Y(lSXj%B=E+`TN8|Mzv&^`n8v2IB|KiSt85BlkoS znbv*74060t7)AQ>QWAb6_LLbzBfh4IYWT>PIw9(EZ*t5HVP;l5;3@A?OO$jVO-wK+ z37319omF3j(LZTd93sGvS-IqmnJU!Ot*y3z|w31$?tV3dDIqeAGy5q)~y*g&E8gFRs%S{`YPRwMh8S-jU zjY$MYJXtP0OxZJJNn+dbxe_ZiAWf3d{2ebEBlL|1nr_XA6G}^6wH9Pd zBN{B_j*k3VM{fG}IQR1v1yA$(A*qO2W~1eVpi)rgopdzI<8}Qipr-7hT&L_6SKJ>I zLXduQyT7aOhFCX+%onpr{X|c{eq|HJz<$#eC34T3U)PtgM@1Dww(1`NyXig)Dal*F z3kvFmGnCuN9_R?HFUAx;w42qwI~bL^Fnz6*8&$R_MUS_~Lx{tPz#~{G`b817QC;uC zD_{qJliW932|;o^T=7fiy()zUQ&FWjl!Op|({~;sT=!1pIh;QD)T21%D=1-d*#gAUm%2yjlwV%8WoN1}D z*uE;f4J^+_9wrOa8u4Z{lrO6wQ6;>rD6fX}fb$%App50Ee2RoI>Ck{8Og$xdC75Xs zY}iNJOH6{6W%m{?`^->1_!W2_(5&_rIR6F$&qL^hA5$eS#sV7@t&PDeM)U~KdBSjS zJZ4y9Q;DRjLiNlv1WJBj~o zulS9-^;3=1M1o{_y2%%T-Rs$LNS!t?grih6J|onSshqN1{myF@7%CLMKVI=_r_C>a zlzhHG&!+`qQ~Y{@8th{p;iAQv1Clt6PoP5dli2aD&gc-|cUaA+z9}hl-rw z*W_0?++9Rm)&Zto>XHv~T$`_}2PRs^)0k41m|6)11;(=Rn&H^Wl~QNACI-XDIG1>U zFsZlrGRW7fCYm5yZ4^M>$4f`K224gY&kg^SD1L+sr$YYA(dCjbB+wBSCme~UW0L4W z&fHno)b(jw53%a#tm z3nL*b8gxZo!qU)Bx-Bsowca8{9tA=D>jqys!ReIuk4h05T+xD0p~1+5k63Ft#Oc#zQ0pU;AW{s;yQjI2$8Yaa54 z|APV90O@)F&;jhLhx}4YcvfF_gB0gg7ntjMUXQvzwsJLW2QMaBJvoC~I6z*{ve<8cH|lte^TpYr zRSh66aDhkm3BZ0j*~xGF0&3ZEIU=pkVwlM27FvBYM(ahow_ay-=BsRA|CU$$Mf4rg zY!wNGfGn>30rBn%;WG?LBtc7=mto`*#Q>94wTYKN1&yU#YWY92%P4%clsCr zN!NGz4S(R-0SGo$DOEf+T_(o(v3C&Pa5q%!9US6Y0MQIbvp-HJ1MY1f;B?Suu{{L- zSeDfI>U9Pfp4j(UtjCu5-~uNm*iVoXc7x2$7$n>Lc4G$bVu3qwtmnnaOoea$IBjrO zXS+h9=aw(|*hiPMi&kcjep1j>#TX+y^;$pls$iKV(r9m=>E(ezq)8QBMZ)Gh?Ww0f-8djUG~qKE1Ixgbe!m2D#jR zzqV_};nKXQwp)T_mKV!b}o-(AU_aHq+&-^OxeM=KjsG=PBz|1w! zVn+XDV$_n@mKCwrfXU=e1k{rG*hK5L!1KH2jB)-I01iiJPS;F?EXrDBC`;nLw2FWV z*mC)saPWsyYg(kxNUX%2ce(sP7WVbpd_fy* zvg?On>Sp=GMu>>3$6UAdbOndMyC?e!-OG{syhNMKeuZv^_9}N!RIt(l*7jDN!enFY z8$gCwLZ*|Rc{hglm$VbW=gq>mS=S)Wn!1jS*lz92Sr6oXB+izandFjF4fsi2fPVk# zB4!eRBl%xg0U_Cu`cZU`FvairCMLP=tjE|Z`TMHmJ!w08}FMx*2a42@iwy9@M>xbbVCIKJL-um63`!J?_7eCKR96PM~c z`IA&2(D?l6CK+!F(~`Y^BJWymK}}BmUL!7O1=g^IA)`vxr&Vek3xOb(MPIL4kxkMU zRDGKQ{F0I9QT-37n<;|<#hl)13+fC7Pm7RejU)Uhas-~BG^w`EqBz~bS+aP`k`Fup z*|b{H@4QwnB)k4hTDjoWOoMppHcT(;W_&bB2!zc6s>#7Q*sI6XJ<_Q{gidF1*Mp*r zQH-+6;_OOIAow&^$g|-HJOl=Zt?6d3Y}U0mvDXVDCV_480y&nA>NWwx2ma;VeOAzt!Lpn#rH<#8aG~1OK)u|q%%8Q(#bG(j~ z&mo9|@6o)xxNmf1wE#9ED`zl=)$}OVIBEg7Q6%I@T4zr|IN{=)Tp}-r8gIkZksDz* zZax_j_A+*4r*{H&re*Zu+Hfb(Qi(*jr?q&%GRHWod`;o%3wj=hUxv@xKPj7tmKxeh zO==6QYj>JM6o}nI&$Hb1Uk?~#hIdoPwW{+0&8QloNOrK?pBk0(z61{1hp7`Wxa6~b zX4ps>ElC~|!~s*|asDq~KoyDRP)Ql*T4x+AsRAcwLc9Ts^g+TTIh4?~p4NcxW|QQ7&P7JL>rO*!cd@ z%rv3o03p>Iguw?}r5BmUxlYytX;KwAJ-P~v{FHn+a*1GVf^K^vS9F1?M^eD;=i`r= zEob+@lkMr2`DleI62sru}9w2|I`CWZi@|TeZwL`EwU@y zYTqgKkxG4$%V(vzOp*rx36iVpUEX}!%ub8B@YJgco-e7hd$3&_RR!1RJ3DIOaEgEW z^IK8Zor(hg>Zquwaa7X&d%(;cJUqRSe1bgv!kf`UTZR{r;4euzmg5iqM04(1xk_b9JSors&RaDLIMpy<-rDaEk4pB z%}y{y*3E|U!=O+v6lBf3g*En%x~8Ousu>~dOHI@B+Q0w`_X*?0e<1b0H-HZbV zs8_B&D$r_FNZ{_bM+~(;9!0G8RoN~jd~Dm+;>kD9PLr#bihGw`$glV+qM=Z_P-_2& z+4jwPru}*m=>zsI#rkyVt-VwpZDQ8d$Rq*RxF~uJn8m&-&+d+!I!A@V=&FVuj6m}7 zc*9R{k2$zXK+XT4&7Wyv7niY*eh?1ux`RG|6-t1PCyauMjwRw%rH+~ZgX1@0 z@={MBMXv;S6ZwCTFAph%ZSQ)ECu9Bh0N`(4`l#W2jXwP3|NJQ4doPrBU36k^y8oOD z@=G2Vk{hR3J8u8@RXkw|G_*%fw@lG<{&AAv0)7)xRPWON*u4dv+TXMA#>5Pp5=*|7 zh4zoj#H9G8;kEF;?SzGGFi$}t$@m?aBa7$hwuFriiw1I76^GEEj)uGINKFMoA+F>8Yd?WAbhY3+oD1e6$8{K+9b6Xb zIut)%YQpsb3R=5C70v=_dT4wMay6<7Pw*zxF#t*R1w<#7m)}D;z?)9^BRC|SV+)i= zOA8SS6(_(Tv>n-|5Pbf!E(qwSU2z;1x4-wrA3~(`-oho&hwrXj2zl!@uzXD{#0`#T zZ9p+8Sp4I8X5Zoy>jq5iRaqX6y~MQxRx+B*SZJ6n+P?r~!~xZv#zEz+3+iyE=f1Bt zioN6F=?FS11jWQ_=$wF-DBaK+;93CUy@nHO6@C4gXG8P_!y_XTz4V}PSQs1%l2dOG z3F+VdzLldLfiJe~Z8;wo6U;652ut-4X1t?oxU?o1e<`JtqW zl9Ivs@hbCRfiB;us1T@}1Xr9LEqykImlr$pX|dN&5~iBvhjt)ehy&%vA!LCQ@^^VvzaPCalCPaQ?@}*R zS%oIH^^Cbv)<$l71C%REZQvkIDYvl#0O(rU||T7g;f<%BW}eK@pp?SVK8A`2^}M$RiPXUlA?+S9vE>?V|#X(S!n; zSa%@7vMe+-cLwMi*D5w)e$}<2JShH)X&I|R1LCdglz8(BD>Drvvl+hnpPxmuEoBwv7U6`W_SFH*HMDi0ndoFy zMtsX>&vR$|>xESrzq1Z?T2;v;HaF!h_rM&Fc(&{fZ{c7jDV$$^^%AriI`A(RR4|lY z6FgeLpH%PZHWtud1hqID$ z`xS-WW}{MD>1gu4vK~f|xr}oH`a`gstd)k!twq|?zs@%Oy#_iOT5gu5-@%bkiZ^|e z>GBISOOcLG%g(BlZVyz646=}vk&_Fat^h9Xk1J@us)U!B59bW2JVRUnbgDw2TSM0& ziVO(;wM=3`?*=+6vUkz0N~8f`!FI%k|y1#ZH_G;#sp>SI>AH^v^E}y`ZIH@Rq%=>}D$QS?ZzowBgXp z>{j}M>?L`w!GshgX5_FTcvAK*{tO!s=QjP69Fc?^tvY`uYiL&y;uIzV8=OSfOmzMu zc&3XaIndCYh_(lVJ;s`pM;aR=(jF@CI-ZkJaiI~579PDO)*Ao4 zE%!GcoCAPJmvcoE7zI^tw~?${OJg}4PSEyjy66aHH`A8shDuat^?-OV4s;ZjS(Cdu zojqJczf|y|!eR7U0B8|2QB8dOG132bRTSg__;~VoOIc6PI+|HOX>1=PEhV91?%cVP z<=fn7evbjRfjf3D{cs3sOMCEB=^LTs-BET!buQK@%T)nQN)~oz*77&4p^=#_a3oUq ztbYux;Y!pD4~4z%Ng+@S{jjl+W$`(j1a(?>Bz_r6Aps;%2gV_RdwM$I$xHExBb_*c zeDX}wq*!?9igkH-`GTK#LW`V~B!W{we=;>G8Q3a~PXR#T!|L#6jmQ2^Dn52^S!ZEU zJ&im~10l_<@5Ij5W__)+x8jN9ZVQJ#({%WWo{e0T5~*TJ zmcfo6OF-723q^RV6q@aPvqr4!5*H8T`J0&Ng9u1#%S`X1V387xX$y#Bh#9|c64XD) zx+|5-t!zt~OxVD#4>|2xu)8TsjQ&1=QfXPhV!L@{M{k}A5g!hC&90l4e|((1DyJQh z;gA4&o=sZ8wnVYBvLGpXMWswL%w9r)KA6@f)FpZ78oJ!F`gCHFVYkH)Sv&lEdR0En zClQWgT5bRMU2T2G8-6UbZg+QyFFgjMUt>dakDdl-9T%MdYCpBhzREmoY@gnLx4M;M zziP!DrjL_m0QwKA(CcNaD(weYm@Dl6UV$gLpkjYvr!H_Fe?umrf@M#sWgT!7M_myA zU@9-PjsCqstX7$?y8OTvd%dZ{yY~Zbs8rWs=I20pS)nj8{SZSwe@k+kIT@D>7OmM# zQBEAC0?`leqV~PUcc#BC!vvc^HX+V3JNJRdn~~PTr_DU;y@B2T`T;EnW)grVI#Jo3 zWvIgj>Q*|P;diafJo=79e;omp1F>!V20cc2@8``K4h_-WnY|iSh19ypPjsg3WjvmT zah+MSBCtaOxJ3WVibR(nyggy|^77Hvh!XMph$d9+}`NkuE#@?%}ze(^Ndd>-V>q98-w!Udq zXqC-Z&q?ad3LW|I(zPP3hYBaw*&Q(zHc+xhzhmU;15yD-7|)Z*`BHK;%msPjm+)Th z`Jk4uqbagj$_IEj$aH{x!Usly=Z|8;c*CgYqXKGSRa;S)Tg8nfaf)6v*tWF31KucK zq%(Ye4;hEw!lx(@%M;0xUq;)T@{{kme(Qa-mpYMM!xd&2M=B@gf_T&FO_K_f=g*`o zl-J$~pJTkYIC^!(952sHLnn2KNz*kXM0f#1a_!;FGgUk&JWnP9!(;rs*y8K;7F&@h zXSL6t{PaC6M)C_waZ0q_9c0WfMH%-s31Am^S5ZLWy!W@e6a|DHBzAt_jD+u@6-#f7 zj10wP8h3HXL>IEiM4y?CoF3P5AnMepcbfzBEx$Q|yZR}_DBWPp=H_KW;0CfOx9x31 zZum+f z;GY!gXvML1hsTp@@LD7KxJQod6Eh5Y)~;J{_pKRC4zwsSDF2`lM4f|em-A=@fz$+@ z5ifqhdV64VU9n5nNJ-SEax1qr*6}6ZWm4iyNKpXviV>6`p~0X9A1!dy)qN3?s3Bwe zKBKfRVAR7>6q(addjyV{^#klsG92y4G$GMY$d#jE~o+sOLS_F0(3OLFg6^ElokA zZGhGCxSH?um$&!_?HR#(9qp%+F7~VDSJx|JRs}6D^DPQ2=-1289FUoz42AlPT_!Wl z=;2rt*NxtxoKNlgFB;~BJ|s@R^Ff}k@D23$7Eeqc<-7;~Y>`Y(Yxtt7p~017Yo;Zk zAET&UIs)*L-@N{#5&Hs;$2C_>)MUH0`xIqe*6N1F{K0n!GQD1q+T?RjCgqBQ3&K)r zqAzG=t)_1gpNm7Le?GUiuzbdoo-D^lkB|7CU8aOv0T$ryNw7Wqck)<;9X3@xS%mT5 z%n|(UJrpol#sviY`@GfJC`%=&o?MI!NH2n=;#9eb(8=9=|V&I|J>**>YxEb z6HX9j(-s86&5=k3mO4dHmAI`&5i>c{4QNspHtnq_#Dg9m@jE*`+x3{{Aff_;(IkXs z(8>|cr)P_td>?NsqU7b}v0u5iL6YsFM07c;U=Q((bttPXq_DS1p?AE6(GyU{8ZSJ% zs%)I%CnhF@dzwWq0TfJ8iZyWDyW+XF$ko$32(2CU^|!b0x_*82l`}IlYg;P5ayfyO z2!2R_Uk#k2_vVF=BeBO}OehEIla)URs&tTtd3x)I2KWShqz2xoH}t%#Qs|xC{iuaLKZ1iD-=FKraS0}db(bo&GznNVTI>_1q}`1J+8 zf{#;1qw;IWB{1xzw=8N=f%O&aoQ1K3W+aA%PL1aWU4_nxEppXIy^HGpF?ZMDTF)G9 z+{>lCG1}sIKxWP^EOg>Q!-W=LObR?Q@2r;caQ`MPeO=vY$&9kRROW23tKx{7(YcZN zQX@Y_P|$=Jwz#dYSWw9KTD^6Mlny!N0*V767k(8-nW~NB-}N zurOG1P`}!0BG3Ns3xzCdcr8mg&)v(5`14`BFt9=j7yTbNUG(s&a`WL1%ZW1DWIDS) z=Td^N$e=8{qy65(1H)MmIul_gD2lqc#Vp6DU1|Lifj|sc&OW$2>2LU&jUM%8gBih_ z__s&h@GtKNkVzQ^Sa?5jTA!%oDlacT$ZTj3ycYBI!R3nkLr>U=)56?5$mjW_<9exQ zN@r3F$K8APp)Cev{uGl{>H}eJ!y3?5rD~@2Jwer0blRm9Kzg_95=?FJphJ6GVn|C% zTZZb^9+24*`@XjruRu9L-j4-k@B?tqya0)OLRH$!L0pib7)_M`o2|K2Parh^D7SSO^x%FMaG6}j0)ICiv|HeQzm4+!_Z~Hi6$W%D-Q$0@k;@N*r(JRb z&0sP4@2y5I9Ku!Le$TqtasIv4$PZKCf%`42$@A|$3a-sQ)WYS2*52GdEu=e#7AYpAhN*4JiCLfg4@WnZs+;%aKUKN|W1=?&Qtb_\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mLogdir ../logdir/craftax_small\u001b[0m\n", + "\u001b[32m2024-06-06 13:35:50.153\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m36\u001b[0m - \u001b[1mCreate envs.\u001b[0m\n", + "\u001b[32m2024-06-06 13:36:42.176\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m57\u001b[0m - \u001b[1mAction Space Box(0.0, 1.0, (43,), float32)\u001b[0m\n", + "\u001b[32m2024-06-06 13:36:42.178\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mPrefill dataset (0 steps).\u001b[0m\n", + "\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mLogger: (128521 steps).\u001b[0m\n", + "\u001b[32m2024-06-06 13:36:42.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m94\u001b[0m - \u001b[1mSimulate agent.\u001b[0m\n" + ] + } + ], + "source": [ + "from loguru import logger\n", + "from tqdm.auto import tqdm\n", + "import pathlib\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch import distributions as torchd\n", + "\n", + "import exploration as expl\n", + "import models\n", + "import tools\n", + "import envs.wrappers as wrappers\n", + "from parallel import Parallel, Damy\n", + "\n", + "# from main\n", + "tools.set_seed_everywhere(config.seed)\n", + "if config.deterministic_run:\n", + " tools.enable_deterministic_run()\n", + "logdir = pathlib.Path(config.logdir).expanduser()\n", + "config.traindir = config.traindir or logdir / \"train_eps\"\n", + "config.evaldir = config.evaldir or logdir / \"eval_eps\"\n", + "config.steps //= config.action_repeat\n", + "config.eval_every //= config.action_repeat\n", + "config.log_every //= config.action_repeat\n", + "config.time_limit //= config.action_repeat\n", + "\n", + "logger.info(f\"Logdir {logdir}\")\n", + "logdir.mkdir(parents=True, exist_ok=True)\n", + "config.traindir.mkdir(parents=True, exist_ok=True)\n", + "config.evaldir.mkdir(parents=True, exist_ok=True)\n", + "step = count_steps(config.traindir)\n", + "# step in logger is environmental step\n", + "tlogger = tools.Logger(logdir, config.action_repeat * step)\n", + "logger.add(logdir/\"logger.log\")\n", + "\n", + "logger.info(\"Create envs.\")\n", + "if config.offline_traindir:\n", + " directory = config.offline_traindir.format(**vars(config))\n", + "else:\n", + " directory = config.traindir\n", + "train_eps = tools.load_episodes(directory, limit=config.dataset_size)\n", + "if config.offline_evaldir:\n", + " directory = config.offline_evaldir.format(**vars(config))\n", + "else:\n", + " directory = config.evaldir\n", + "eval_eps = tools.load_episodes(directory, limit=1)\n", + "make = lambda mode, id: make_env(config, mode, id)\n", + "train_envs = [make(\"train\", i) for i in range(config.envs)]\n", + "eval_envs = [make(\"eval\", i) for i in range(config.envs)]\n", + "if config.parallel:\n", + " train_envs = [Parallel(env, \"process\") for env in train_envs]\n", + " eval_envs = [Parallel(env, \"process\") for env in eval_envs]\n", + "else:\n", + " train_envs = [Damy(env) for env in train_envs]\n", + " eval_envs = [Damy(env) for env in eval_envs]\n", + "acts = train_envs[0].action_space\n", + "logger.info(f\"Action Space {acts}\" )\n", + "config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n", + "\n", + "state = None\n", + "if not config.offline_traindir:\n", + " prefill = max(0, config.prefill - count_steps(config.traindir))\n", + " logger.info(f\"Prefill dataset ({prefill} steps).\")\n", + " if hasattr(acts, \"discrete\"):\n", + " random_actor = tools.OneHotDist(\n", + " torch.zeros(config.num_actions).repeat(config.envs, 1)\n", + " )\n", + " else:\n", + " random_actor = torchd.independent.Independent(\n", + " torchd.uniform.Uniform(\n", + " torch.Tensor(acts.low).repeat(config.envs, 1),\n", + " torch.Tensor(acts.high).repeat(config.envs, 1),\n", + " ),\n", + " 1,\n", + " )\n", + "\n", + " def random_agent(o, d, s):\n", + " action = random_actor.sample()\n", + " logprob = random_actor.log_prob(action)\n", + " return {\"action\": action, \"logprob\": logprob}, None\n", + "\n", + " state = tools.simulate(\n", + " random_agent,\n", + " train_envs,\n", + " train_eps,\n", + " config.traindir,\n", + " tlogger,\n", + " limit=config.dataset_size,\n", + " steps=prefill,\n", + " )\n", + " tlogger.step += prefill * config.action_repeat\n", + " logger.info(f\"Logger: ({tlogger.step} steps).\")\n", + "\n", + "logger.info(\"Simulate agent.\")\n", + "train_dataset = make_dataset(train_eps, config)\n", + "eval_dataset = make_dataset(eval_eps, config)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m323\u001b[0m - \u001b[1mEncoder CNN shapes: {}\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:20.651\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m324\u001b[0m - \u001b[1mEncoder MLP shapes: {'state': (16536,)}\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mDecoder CNN shapes: {}\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:20.751\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mnetworks\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m391\u001b[0m - \u001b[1mDecoder MLP shapes: {'state': (16536,)}\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:20.813\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m102\u001b[0m - \u001b[1mOptimizer model_opt has 15732120 variables.\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:20.836\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m281\u001b[0m - \u001b[1mOptimizer actor_opt has 1335851 variables.\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:20.837\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmodels\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m292\u001b[0m - \u001b[1mOptimizer value_opt has 1181439 variables.\u001b[0m\n", + "\u001b[32m2024-06-06 13:38:21.032\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m\u001b[0m:\u001b[36m17\u001b[0m - \u001b[33m\u001b[1mLoaded model from ../logdir/craftax_small/latest.pt\u001b[0m\n" + ] + } + ], + "source": [ + "config = parse_args(argv)\n", + "config.num_actions = acts.n if hasattr(acts, \"n\") else acts.shape[0]\n", + "agent = Dreamer(\n", + " train_envs[0].observation_space,\n", + " train_envs[0].action_space,\n", + " config,\n", + " tlogger,\n", + " train_dataset,\n", + ").to(config.device)\n", + "# print(agent)\n", + "agent.requires_grad_(requires_grad=False)\n", + "if (logdir / \"latest.pt\").exists():\n", + " checkpoint = torch.load(logdir / \"latest.pt\")\n", + " agent.load_state_dict(checkpoint[\"agent_state_dict\"])\n", + " tools.recursively_load_optim_state_dict(agent, checkpoint[\"optims_state_dict\"])\n", + " agent._should_pretrain._once = False\n", + " logger.warning(f\"Loaded model from {logdir / 'latest.pt'}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Now lets play" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0, 0, array([ True]), array([0], dtype=int32), [None], None, [0])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "assert state is not None\n", + "import numpy as np\n", + "\n", + "state" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from tools import convert, add_to_cache\n", + "envs = train_envs\n", + "cache = train_eps\n", + "\n", + "step, episode = 0, 0\n", + "done = np.ones(len(envs), bool)\n", + "length = np.zeros(len(envs), np.int32)\n", + "obs = [None] * len(envs)\n", + "agent_state = None\n", + "reward = [0] * len(envs)\n", + "\n", + "indices = [index for index, d in enumerate(done) if d]\n", + "results = [envs[i].reset() for i in indices]\n", + "results = [r() for r in results]\n", + "for index, result in zip(indices, results):\n", + " t = result.copy()\n", + " t = {k: convert(v) for k, v in t.items()}\n", + " # action will be added to transition in add_to_cache\n", + " t[\"reward\"] = 0.0\n", + " t[\"discount\"] = 1.0\n", + " # initial state should be added to cache\n", + " add_to_cache(cache, envs[index].id, t)\n", + " # replace obs with done by initial state\n", + " obs[index] = result\n", + "# step agents" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m2024-06-06 13:38:34.000\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mtools\u001b[0m:\u001b[36mwrite\u001b[0m:\u001b[36m85\u001b[0m - \u001b[1m[128521] model_loss \u001b[31m22.2\u001b[0m\u001b[1m / model_grad_norm \u001b[31m14.4\u001b[0m\u001b[1m / state_loss \u001b[31m17.4\u001b[0m\u001b[1m / reward_loss \u001b[31m0.1\u001b[0m\u001b[1m / cont_loss \u001b[31m0.0\u001b[0m\u001b[1m / kl_free \u001b[31m1.0\u001b[0m\u001b[1m / dyn_scale \u001b[31m0.5\u001b[0m\u001b[1m / rep_scale \u001b[31m0.1\u001b[0m\u001b[1m / dyn_loss \u001b[31m7.8\u001b[0m\u001b[1m / rep_loss \u001b[31m7.8\u001b[0m\u001b[1m / kl \u001b[31m7.7\u001b[0m\u001b[1m / prior_ent \u001b[31m48.4\u001b[0m\u001b[1m / post_ent \u001b[31m40.7\u001b[0m\u001b[1m / normed_target_mean \u001b[31m0.4\u001b[0m\u001b[1m / normed_target_std \u001b[31m0.3\u001b[0m\u001b[1m / normed_target_min \u001b[31m-0.3\u001b[0m\u001b[1m / normed_target_max \u001b[31m1.8\u001b[0m\u001b[1m / EMA_005 \u001b[31m12.3\u001b[0m\u001b[1m / EMA_095 \u001b[31m26.4\u001b[0m\u001b[1m / value_mean \u001b[31m18.2\u001b[0m\u001b[1m / value_std \u001b[31m4.3\u001b[0m\u001b[1m / value_min \u001b[31m10.1\u001b[0m\u001b[1m / value_max \u001b[31m31.1\u001b[0m\u001b[1m / target_mean \u001b[31m18.4\u001b[0m\u001b[1m / target_std \u001b[31m4.7\u001b[0m\u001b[1m / target_min \u001b[31m8.4\u001b[0m\u001b[1m / target_max \u001b[31m37.8\u001b[0m\u001b[1m / imag_reward_mean \u001b[31m0.0\u001b[0m\u001b[1m / imag_reward_std \u001b[31m0.1\u001b[0m\u001b[1m / imag_reward_min \u001b[31m-0.2\u001b[0m\u001b[1m / imag_reward_max \u001b[31m1.0\u001b[0m\u001b[1m / imag_action_mean \u001b[31m10.0\u001b[0m\u001b[1m / imag_action_std \u001b[31m12.9\u001b[0m\u001b[1m / imag_action_min \u001b[31m0.0\u001b[0m\u001b[1m / imag_action_max \u001b[31m42.0\u001b[0m\u001b[1m / actor_entropy \u001b[31m0.9\u001b[0m\u001b[1m / actor_loss \u001b[31m0.1\u001b[0m\u001b[1m / actor_grad_norm \u001b[31m0.5\u001b[0m\u001b[1m / value_loss \u001b[31m1.3\u001b[0m\u001b[1m / value_grad_norm \u001b[31m0.9\u001b[0m\u001b[1m / update_count \u001b[31m1.0\u001b[0m\u001b[1m / fps \u001b[31m0.0\u001b[0m\u001b[1m\u001b[0m\n" + ] + } + ], + "source": [ + "# from tools.simulate\n", + "\n", + "# step\n", + "# step, episode, done, length, obs, agent_state, reward = state\n", + "obs = {k: np.stack([o[k] for o in obs]) for k in obs[0] if \"log_\" not in k}\n", + "action, agent_state = agent(obs, done, agent_state)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "=====================================================================================\n", + "Layer (type:depth-idx) Param #\n", + "=====================================================================================\n", + "Dreamer --\n", + "├─OptimizedModule: 1-1 --\n", + "│ └─WorldModel: 2-1 --\n", + "│ │ └─MultiEncoder: 3-1 (4,365,824)\n", + "│ │ └─RSSM: 3-2 (3,831,808)\n", + "│ │ └─ModuleDict: 3-3 (7,534,488)\n", + "├─OptimizedModule: 1-2 --\n", + "│ └─ImagBehavior: 2-2 15,732,120\n", + "│ │ └─WorldModel: 3-4 (recursive)\n", + "│ │ └─MLP: 3-5 (1,335,851)\n", + "│ │ └─MLP: 3-6 (1,181,439)\n", + "│ │ └─MLP: 3-7 (1,181,439)\n", + "├─OptimizedModule: 1-3 (recursive)\n", + "│ └─ImagBehavior: 2-3 (recursive)\n", + "│ │ └─WorldModel: 3-8 (recursive)\n", + "│ │ └─MLP: 3-9 (recursive)\n", + "│ │ └─MLP: 3-10 (recursive)\n", + "│ │ └─MLP: 3-11 (recursive)\n", + "=====================================================================================\n", + "Total params: 35,162,969\n", + "Trainable params: 0\n", + "Non-trainable params: 35,162,969\n", + "=====================================================================================" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from torchinfo import summary\n", + "\n", + "summary(agent, input=(obs, done, agent_state), depth=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine grained torchinfo" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "wm = agent._wm\n", + "data = next(agent._dataset) \n", + "# self._train()\n", + "# post, context, mets = wm._train(data)\n", + "data = wm.preprocess(data)\n", + "embed = wm.encoder(data)\n", + "post, prior = wm.dynamics.observe(\n", + " embed, data[\"action\"], data[\"is_first\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==========================================================================================\n", + "MultiEncoder [128, 16, 256] --\n", + "├─MLP: 1-1 [128, 16, 256] --\n", + "│ └─Sequential: 2-1 [128, 16, 256] --\n", + "│ │ └─Linear: 3-1 [128, 16, 256] (4,233,216)\n", + "│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-3 [128, 16, 256] --\n", + "│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-6 [128, 16, 256] --\n", + "│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-9 [128, 16, 256] --\n", + "==========================================================================================\n", + "Total params: 4,365,824\n", + "Trainable params: 0\n", + "Non-trainable params: 4,365,824\n", + "Total mult-adds (M): 558.83\n", + "==========================================================================================\n", + "Input size (MB): 487.31\n", + "Forward/backward pass size (MB): 25.17\n", + "Params size (MB): 17.46\n", + "Estimated Total Size (MB): 529.94\n", + "==========================================================================================" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary(wm.encoder, input_data=(data,), depth=3, col_names=[\"output_size\", \"num_params\", ])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "decoder\n", + "==========================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==========================================================================================\n", + "MultiDecoder -- --\n", + "├─MLP: 1-1 -- --\n", + "│ └─Sequential: 2-1 [128, 16, 256] --\n", + "│ │ └─Linear: 3-1 [128, 16, 256] (393,216)\n", + "│ │ └─LayerNorm: 3-2 [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-3 [128, 16, 256] --\n", + "│ │ └─Linear: 3-4 [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-5 [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-6 [128, 16, 256] --\n", + "│ │ └─Linear: 3-7 [128, 16, 256] (65,536)\n", + "│ │ └─LayerNorm: 3-8 [128, 16, 256] (512)\n", + "│ │ └─SiLU: 3-9 [128, 16, 256] --\n", + "│ └─ModuleDict: 2-2 -- --\n", + "│ │ └─Linear: 3-10 [128, 16, 16536] (4,249,752)\n", + "==========================================================================================\n", + "Total params: 4,775,576\n", + "Trainable params: 0\n", + "Non-trainable params: 4,775,576\n", + "Total mult-adds (M): 611.27\n", + "==========================================================================================\n", + "Input size (MB): 12.58\n", + "Forward/backward pass size (MB): 296.09\n", + "Params size (MB): 19.10\n", + "Estimated Total Size (MB): 327.78\n", + "==========================================================================================\n", + "Summary Failed for reward Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n", + "Summary Failed for cont Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 2, LayerNorm: 2, SiLU: 2, Linear: 1]\n" + ] + } + ], + "source": [ + "# heads\n", + "feat = wm.dynamics.get_feat(post)\n", + "for name, head in wm.heads.items():\n", + " try:\n", + " o = summary(head, input_data=(feat,), depth=3, col_names=[\"output_size\", \"num_params\", ])\n", + " print(name)\n", + " print(o)\n", + " except Exception as e:\n", + " print(f\"Summary Failed for {name} {e}\")\n", + " continue" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# fail as no call method\n", + "# summary(wm.dynamics, input_data=(embed, data[\"action\"], data[\"is_first\"]), depth=3, col_names=[\"output_size\", \"num_params\", ])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/networks.py b/networks.py index 6f22d0e..7661f7f 100644 --- a/networks.py +++ b/networks.py @@ -320,8 +320,8 @@ class MultiEncoder(nn.Module): for k, v in shapes.items() if len(v) in (1, 2) and re.match(mlp_keys, k) } - logger.info("Encoder CNN shapes:", self.cnn_shapes) - logger.info("Encoder MLP shapes:", self.mlp_shapes) + logger.info("Encoder CNN shapes: {}", self.cnn_shapes) + logger.info("Encoder MLP shapes: {}", self.mlp_shapes) self.outdim = 0 if self.cnn_shapes: @@ -387,8 +387,8 @@ class MultiDecoder(nn.Module): for k, v in shapes.items() if len(v) in (1, 2) and re.match(mlp_keys, k) } - logger.info("Decoder CNN shapes: %s", self.cnn_shapes) - logger.info("Decoder MLP shapes: %s", self.mlp_shapes) + logger.info("Decoder CNN shapes: {}", self.cnn_shapes) + logger.info("Decoder MLP shapes: {}", self.mlp_shapes) if self.cnn_shapes: some_shape = list(self.cnn_shapes.values())[0] diff --git a/poetry.lock b/poetry.lock index d9ceae6..7cee8d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3578,6 +3578,17 @@ type = "legacy" url = "https://download.pytorch.org/whl/cu121" reference = "pytorch" +[[package]] +name = "torchinfo" +version = "1.8.0" +description = "Model summary in PyTorch, based off of the original torchsummary." +optional = false +python-versions = ">=3.7" +files = [ + {file = "torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46"}, + {file = "torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9"}, +] + [[package]] name = "tornado" version = "6.4" @@ -3836,4 +3847,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8d04aef5b114f7ae76dc03bc61d308f1b239d390b0c71fab7d0c8f467cc95dd4" +content-hash = "0275da73363d94f6a5cdadc9662c1b254ef50310aabce3aa663552aa4802b001" diff --git a/pyproject.toml b/pyproject.toml index f378589..5af5591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ imageio = "^2.34.1" craftax = {path = "/media/wassname/SGIronWolf/projects5/2024/Craftax", develop = true } # craftax = {git = "https://github.com/wassname/Craftax" , develop = true } chex = "^0.1.86" +torchinfo = "^1.8.0" [tool.poetry.group.dev.dependencies] ipywidgets = "^8.1.3"