From deba082cb4dcd36d58b497c7e770c7948085ea2c Mon Sep 17 00:00:00 2001 From: krfricke Date: Tue, 14 Jul 2020 08:16:05 +0200 Subject: [PATCH] [tune] PyTorch CIFAR10 example (#9338) Co-authored-by: Richard Liaw Co-authored-by: Kai Fricke --- ci/jenkins_tests/run_tune_tests.sh | 6 +- doc/source/images/pytorch_logo.png | Bin 0 -> 19688 bytes doc/source/tune/_tutorials/overview.rst | 6 + .../tune/_tutorials/tune-pytorch-cifar.rst | 278 ++++++++++++++++++ python/ray/tune/examples/cifar10_pytorch.py | 236 +++++++++++++++ .../tune/examples/mnist_pytorch_lightning.py | 20 +- 6 files changed, 539 insertions(+), 7 deletions(-) create mode 100644 doc/source/images/pytorch_logo.png create mode 100644 doc/source/tune/_tutorials/tune-pytorch-cifar.rst create mode 100644 python/ray/tune/examples/cifar10_pytorch.py diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index aa97a9bc8..8b8fe0160 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -90,6 +90,9 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} python /ray/python/ray/tune/examples/bayesopt_example.py \ --smoke-test +$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ + python /ray/python/ray/tune/examples/cifar10_pytorch.py --smoke-test + $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ python /ray/python/ray/tune/examples/hyperopt_example.py \ --smoke-test @@ -116,8 +119,7 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ - python /ray/python/ray/tune/examples/mnist_pytorch_lightning.py \ - --smoke-test + python /ray/python/ray/tune/examples/mnist_pytorch_lightning.py --smoke-test 
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ diff --git a/doc/source/images/pytorch_logo.png b/doc/source/images/pytorch_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..7992605b01f44b7002c7daf8a88fd5b148139a9a GIT binary patch literal 19688 zcmZs@1yq!67d3njp&*EKC^;e`(%m6LN{I|1-69Q=5+foE4I-U_)PSUPjUXkZk}{N} zgrFcH`CUBE^Zx7m*80aKG3%b0tL}aF*=HXjbhK2iUb%h+f}pEvst8>OB7lLP2Z=6$ zuYD^HAHlavR`*p9(D}u`+}5H*@C>n=s*xvn#l*#b*geh+U-0DRr)ru?mlp`}$VJ(> z7rY`Mh!s*pDCqf4ZO-@vFpf3Koec}+Cvh8X(W9NBxd)?d^hK543-gPhZiukGA-LC} z*+|@2!Y;Vir?SP{5Yh_x?uMmvYQcC>E+@fr{6++UwF2;2bdYj>TfxT1spw ziAy|T1a-aw0z6*Y8I3DIvX|#|C?v#*Li{Pd?Y_(@USc4H(d1wCoZ@}{!?ntpuBeP0 z4C~A17Dk5d@9>=p_>Ka!7fmEUEH8{8h8~>6{!ILLkZO<>9_rs8Qwdl(pnrcv|6fn5 zLGb@Q4X*hAuKxdjx|j$lK=6?Ox>{dYSa=66SVcdiggCm|Ihn@%evF^==g*%rA0NRX zbkE`*K0f|s0m@spFE^St7%su#7ZWCixG2Lc0~8T$I^3)I7=BHvLI*s^33mzQ1)1Sj zUc#1Qn#TE~YoTO&!L$My8JUV2j{Z3;7W)QsL~!$gt}Nlaf7o*{hkwS;Nk>%^JO#-k zaxoCp$!~$%mqeO!!A;>`83;b)w{2J0P5v|8e=aU2O(cMONTU{!G;qBBY7+`SJtWIJ z495R`bMrI&&Woi-XyB1kK?qw_1S1rS4iHW3vvI-&X&_u;yR3X}QARjNUo1MsRgt#W z6yQ3=P6D?)%xtQ9ftfqwYGHD6@-;#3i?{j64hg{I^=8+Rdm*eu+11<|clxVWAo8%6 z7sL5?8$P9?19T8uHQXCY{nbO3i2rM#CyB5TrKN*9XE{3~{Y(lruyp^fUEYq^LqxU@ zw|_>H2!ZW*;8&_Y*6*`5{q5V-lw{U-)_-49TE6l5SaM`05p^fv$IAko8@r8EiWwcH!$ zmTnme7w?$&8AFd3nGSJcsZr7B0IQti6q$aak!ID%UYWnrxtMXNxc_Y9F-6JM5DONF zO35Y|$`9Y@kL;y$!e#S2e0@tInaY^|Mg*FtEAZb5v*fm=p`@SyhqfyNf^t)0j3>QZ zW#by$$D>76-Pll6kg$9%s!X@5vvA?4Dk`PiA|^SLcfyO0<a@n*!vc;vz!z6WtiUR-_+Gbvt6L&vKa}ZrNriwk ziZwyi>F!u!y|a|X#e9*X<;3!hx1p8yrttW4zByuu!bHaM%h&Cd4d0EGx@gQ?V=eyY z;T9F+ic7!t+K%?3-nYthUoRx@d+yb?J-2#%HX%J+gRH%Z^?`f|k9xm3j+gy)RnGqR zaOHN~hOY1v$`csG8Bgxjwuo-AdgtS*tmrJ=3_RXGB6LjKiKkZe%|5Um-hHIEM-3go zdJp(2&Fh?xt&}tVdyhzM$SG6kR1Y@PdusEmQM0D7zNYSu_UOTBxb$#_efkO;M7#R; zu)Ws#m@_Qxe=pL6Yy$6W-myirW3LwlY;0|rt^3V7A+v}r{Ibc6@#in~S3jY*yf{c! 
zlqfDiSfz*EB2h6j-S<54m|Cru3&&@R05rlQdXc@u)X()h`P&A2!Fo|pBtm-b#1KSp zU@S-cqCGw=4;xgGLsen_^WmF6`}_ON<tL%2W)O2b$n zMy*F(O3Q6Ya_{X1p0kYz%Mr{kqr)+@6ueK=E^b*3JRRFxAk65~v|Rs9UUHOw*7p1z zD(s$*&w(!T6oqM{{;hNSQL>ltP4)jS?_!4_i$60nH^?o&i@cP+Wv~<|C6?#2bbR&% zDMNwX_);$L*GKsu6iy|`jZ}k9cJbZsA6EO{HO9MIo1M*Uju^s=k9A8ms&i)Pc>KS4 zN`QH;Q5@YTGV5#1wfc7W?s&R)T)y*z*Gv&&spkQ&7?fkn9&xeaY7i!s0D?A~S*tYh zGp|}0kagAJb@yD{y0i+N1#jMbSAov1e_7KKkWL0&PMtVXkGyS-EMmY=`Rqh9y9Zn(5|QUq{f8@_j-`wVXz(Yz&M#S59&TvM4sl+*@C&ZQiYOK0oFKru0OF}7Y z18L?nI_+Eh52G&;C_oR*nV(Ho(=}2-G$q@x)Ek1tZtD9@b~uiU{o{ZYdr(5I;`JXI z&s;Md16go*G`EL$%_^ObE1=kc+A>V6ql=4daZyqDi1Eereq0r2LJ^}PRUz`E;*C!8 zI1kcJm4CO3ALbUpSw8q|Cdv%iM+`Z;#yMv)<#f_2|9$T`Q7#49+~a`B4iHh&)Q5$I zHG$YB5KWL2b?0GWbQB0>(F70N7oVck)uDmd9G*x={)QCav)OY(aEqljwB%7i+c>Xq zd968dsk5%%QIzd=j`;zGwNH>a&bVQGarmlti9h{KRblCc-@bbwm{6{`gq3pA9*@WIW<39`|?V`F0x2zXV@m(!w>hegcvQC(zL zLvm*iPS5X-uD*Z&KKhf5pfmg{k;N-?G@edk>x2t_B@#J-?(N4kO&^@{*iIJq6pgGo zD3FW87i@5zLZ~lLH&l$TLjslPJ@LsPL)g?{wxCK{Y$;^+U`zx`##fDI5bEC8++?(H zBtzcbE79XB$2=tMq%Xx-BfZKj>Fvk~q@c}D-N7-Kaz+6w1m8B>M|Sy_YZ*d;*_hc& z*ny4bN)?zx0{PBzoRYh;DC>ERr9N(mII6kax-?&UjVO_z*IQpRlu=C{@yvB)APvda zezNlc?vaNM;7n3yN(__;LCf#ONCmb1ovT@|3Fg3ijOP}Lxo5-gnxkGI7`xb#i_tD! 
ze<&g%630#%5ko_1HB#q%8T-cdNXQInN#AV^?IaUaa+fM24gv+=MkA)I#44Ak4BhlF5^De46ay=Uk@WIqpn(top=5Qh?H z8Sg@q)wcIP)%|pVw4-_fXOV3S_`74^YvAWd1tPdkC%@K%F>mN{k#0!|6fLtY(0uMU zOdmyru<`M*yM?lcq)RcuJAuIriv8?RQi*0!EfaI`?5pjg6(ZCo5h|D+I&-f$HJ!^` zyV1>K;?HNyqU2(a(T?n$P^5n6;_NJw+uYoo@__x596(UgPH|1-uz8K$phK8tH0Pfv zx4PXqUbwk>fv!V;HS;C6{#YidEwg6tpRpcjM``AYt?lg`m-Ll45E*hz5jFtvSAPZv zb6*(BA#`g)M-P9uJ-9k^6M9$^Jajcg-WeA}i~Wg2BHQ^^=%&ZV$B7ll15A*^24go9 zRul5K|CD_i-zG(#nn-4aIpfqQ8ZTqB$~R;$%isOYJTf^c4yTREx@19%dI)Rbv&onq z*{b_XO!3?q*N<4bra(TubR2YTdlSHI+7GGIcxBBZy5-~VMr%>fndcSk4NB;l_o$0$ z@RUN&w&!`Z{kJQ3bSvhBXIP2qM{+Qtu*p((l&#SmRl?nS0RuwWtAV*4n6BQluRijz z%lf|A_Q)sUbEs`p^;A7~`|};Kqz4tnGD>)4EQZ8Sdyx9o}(i8 zFb1e>@Jd!2o}?hD!rwm zLu>rDLTMV-WWlGw6fW!7;~J)9b$?>%I|w`oV^?*MRl}&QfanE_PDxJumez=GyyS6teq;`fxder z*>+~;>sP@fA0mL*h%wf^8=Qr!d#C1Zd%jvQ-wFW!iB_l%@XjwPf21_<>Sd)RSzKKX z;G#C;^SBGt9}*tC1+|5KXJ9yH>*U{JT2K0e!{MxB7fqUM`;*yG8%iJVBpLIe5}|1a zqjAChYF^Gn8`NZ7V4n5TtTn7=t1)hQo$ zF&aM>bDfvPRcr$+YGdz@O#WDYai!uO`G<6zuj9K6lT^Ax4+24!y}}=+)|{K<)7<} zi$W}!L6_hoxdiJ{92u3b({p8a8idbMRzv*fjYjbniA)Jz<2J?;d1{h_fdt zc%y;0BuSZ@pMR4FF81=IEb8uC-j3+XM~i;5_)gtj6$J&h-N`)0$&tNDkn-y<1~>5M z>0%e9r>Cb|TyKp#TU#$qCfrn!Wrr8hY2KdZXXgH6FCOa|;*j1iApY}>NuoRpSxD|k zby?n7)su^$Q|D~1k3`l|zqosr64{42YH-Gd|GM*m9Z(U?<09V6pR?-^f@2|cIyC95 z+r-tLFI1dobFbEXjGw8kGHPPb!O3a5J@f9H+S;)X^IUU{O|ni*Uq6FPumj|Lu+O@6 z!@^^Wq@BKPjg5_Y-hVDqVt}gG@`+ZAw=JE)k$Hq&jB>`Ml&rsiW({CtCWIDxRoR)g zM!!WI$4_Xme35^p=epV;=I>|->@(3jsh<`qO6^(oLwLVR)tLdY-8lGG6$hy-Ejw`U z^l5r7^wbpSqGO*G^T>jPJl&CD!pi`ys**}(=;gI{`%@$H*R?-Ny@6XLnDwhpU0sel zOMSI<&gT`FSv2`W8THTXo&3xon(Zd(w1Ygp>iP5Mgh~zNor;{FUD&|><~!le0{36m zSk>W+Y3`oEH6asL_%NRM`;Us!W=J}z*qUoOi~!gLkcAW{w1023`5qs3HlMxwYAyNu z&2mLCK@saxhm-3NzK!0Q+!n38gjgnH_}=#K{YDv^#=mU+ftXC!D!4}w+9U3;ZLU+_ zB(qQ>O|ZPKZeIY(!Bke5@xa^bxJ#$$%9D7E>rTq)C%idF{h`=JjVmGZMx-u&hH9c5 ziGA(VBg8+o`0a|m#c(Hzi*HgMa8X|fl zi;fQ4h`l1Frx;86I7Dj%iG}Mln>m5F2+BinR1GDiw3GsQ4DRnVz;#&& z6z+TDL_wsy(%{pK4!~8HQM7FbUnM+3`lnXlSl)=&^AFni-0t1R83ADl8AyN<37uP} 
zfo?#X1eHd}`Uq7X=pnKvQkC?@2*s6F+O8Lzc@752wj%fqH?U7%=OpsJ4fOKgZov@j10kpzx+SABPHjqmJ;aR= zXQ3Sf=X21!B{I+u$+ES!<|pyz<=kf-0|SEtHVO$CHUc8MD~~&iTok`)GI2WLp-vzy zCYA$-&wW~Cn70QCQfG!|{;u^kW1_t#vsPB|h^nVP?${2Puk5OEBno`FEwU zl1&MQ_k)G5dkxx&Heqn>slO2F`iSrbq@gp04+)p9e}{Xx$u1m=9>kOu0^~aXlyuj8 zf$Ak#JdQxAB}$4(Kk?Bz4T~VBg}m~2(d))B#tnb6V``?#!&uiPiLHgMAJd{L1TlZi zuHU%PgGW?bgCHl!=nQfpCz!Y%_Pc;wN$$(%&%bJG9VfOpQzS1!F}{TvI#28XB&{_x zbSVxag{^#9T5``%wi1~kUj%Y>VsvymQI%cyulVDA$s_CiYd|=;G;c-GD39E&BQL@P z&DQ<=33yMw1|@WSJ0S^ zDYKa+D>wHL_UO^0!b=Lj#8G{IWXrr1jeJBQ#*p;N@mX@^Wx!AAP&D7<;(nfO|MqQ> z1x)-d{C4qrAGBaeNgZY>tw-uSTwRt1_qcY+X9;a3g_6UpTR}`fXOl;>n<^?Q&h#eS z)P}jyUMfN_Iof(64*@!(Pzmsi~<{ zts*%NI7?>+Z74|yt1DkwvTY1mh=pY~71V|~lTK2fse3YpX}hPd@6G;UVUuODPDT>} zsPz2w=KIvv#tGStrcSx0ukM*bdVBhz<~fcjXl|mA`SWVNvjml8JE{mDScBu#dk6FT zMI+`P<~iGY`yiwY&OP&m=|o_HcFNbzm$3ULbFsFota6_1F4yZTH|&)E$S@0MW#NZ? z`b+K<>sw5LRG?A8`}}O~?bS|VV{b6mptt+Hf^s%rqWa6OcW4kEl1M3c`AF;?9L(im z{!S?$$Re}8H*LLCY(2DAQBj%lnz^A8$2`LYPqA_{^OiMP6dww|-}s8D=dX0f&ktuU zXhsl+BLLB_tS3hKe4Uw@**HE!X0I}Fz{4&01=3-aUk*;&SJpPt;T|L%r~YT&{z^;c z9GdCEg9oQ`)RM4Y-qORyTTNRB0duazA@gn&jSTWfGG@wAmR#62OcSdH9;?*g5)R{! 
z+GwIXY&+MH1w6FwBF&N>KKA#Hl$WJ_APfcM#L;rV9vjZ8#{`hh7$=pF@$siVv)^1O zxJ$6JT-@`Z*IqPxQZ|gD;J}yST64u-6X3U+4`iiiqh`oh2YwAxG4?ocg&)?})k#9J zWDi_T$SZpRE3Cx~^>@A*FJ12e`qt`kkd=7*LwoyGFqR2@1B3dKbxl8Un_A4QgxC5g z{{5<0Ky;dNFlY9$T}QBhncWtEZh_V0tIJJWQQ50m_VnGB@Ez+V2bdJDF8Gp#H#6B- z1!gW?*x?m8W80FFl3%-Ya=@VmnUmS$^unw2!O+`CablgTSkm7+!3o>PgXJ4%9}iBw z6y5X6H@pFyvXIQD_XZk`NxXaaE+(LZpIBb;ac9|~3KY|_T70j)X*=YjsIqW40^3m_ zhIhTThF@LfEw82UnhbMgPpAg7?*ln0ph7+zlj;ZvASYN&Q9K#>hYgth7Nux>@vhz% z1T|3<-|BarsLPiw6&|;pA6p#-MIfy5w~uoM+Q^jy6H*_iuXu5O;K5o|O~%2Zhu)xF z(^nF*S5K)qPhwR;CLOWSB-j5iL-raQt!rS0{lG|=B_j~^lDc+w&h5a7X$8P_Cre8T zD(|_G$mtj`7yM?l=@-V|PcqBJ%k)5*w6pvY5IdIrX6<1o2BSWG8XONDjLQUx^_InVeukSP~7L!J(+_=`p` z<}m4-3v%~MayLu$&6*X(f&FJ01A~7*I`#Da{rfM_CtTpwGpPdBvG)Pr&^?^JYfv80 zx7(s^tD<)zVH|gm|J1uy2Ws!P5gcsWHb;_@^!R7{NT6^xnOCfq#IJ+^JuqH$z)?#6 z`QlAx4ju)wAchiEYdXH4p+8R0;STE=H_oAawe8PpG9=#<5q=Eg>rGqc0u7VM^6SQr z9@!;(obdEZ5@TU(kse&~zsWoK%T?aXe7dZ+HmP?05zKe zNGr+>9gPXT!zR8~0_b!n`Q<4fhLqR4a*j%S{Ct6?X3U(IUaf7=pLifFnwLmr( zFvO^|4+}yp}5`KT)#8qc(t;@e09LG{i`UK@ObZ<$(!o}?}A|u1M zJ~%QuUleUEoWEtZBgs0U5-@gLTvy*iuQ`>tD6R>o^nTs@X=`h1HXy4X2z7%*37fiO zH5bu6+3F-CJ-r)BrCX5lq?Z3}qIe87aD5syLdb3$2dQbLZT$G58pkYiWQ`Vu&l~o* zYw-u}M6z8=L->RGuvR;aYJ&ROq?JbkRKJ^Z3Yku(B^JPld(qdRVEFG$rM z^_o%9+spneb|1+B1Bc|Ev0SP1breH_E>V1QH}|70fpotH^d*+w2ecz;lN0fh2?=k- zWOUkNf3;&}l7~I}wgkyn=(?uB2gZT%YK} zy0HoX-@qKPD#*PxBKH%4Pu8HgQ0F=qhJ)aCLQ!M^w5RU|281mr(c;7Bdv+~8i9VH(@o znJ&T~0Yle}%Jqr}Les9eP3SoPMr`v+rJS0>-ycYsOz=dl+PJ#91_MS%f=!PPPX#4V zr2u8OTFefapEaau3r3?^&iWgb%pBW4GnD`r2EW;b*k+lWR`h zU_NYR{vch7r?yv??l9-zHtCQk{aZ#U5XA7CyvRKTngTVy^=I_cp5oDaKwOR53Dj!a z_JN3M7O`IdwvODta|XVVq4E`0%HmDfKev_Y}5CScCo^pjU=pG^>6E{7hA(%yp!ul7!<&UizhGZNp`FEd2u9Soh5 z%+QVhz$H^9LVw82E5;tcos}ra+#TP9HG_1G zO{#(}dGv;&WFUoy_NRAGO%EQ_&3~T&xi49$R&o&_JBILtNqrZ0chiGAkQ}K|fCmye zc3b^V!OF7>>52FEYR3m~x}S_)Vq$bo{XX+%-i$oK`c>c3lz^DA}YA&AC5RoHT4S86^9+ zmaN-a=8#~HV!%Zce)gOVnqa3TXT4_j3Qle%^O>JrLEU1*O;L;DgYA@O$5v5XD=U)m z`JWfychTwAhJQ@LKs)Pth{bcpeop=UQ$s{KKKU}+QOovmFyAO~xk=-WNT{(xK>J{Z 
zSjxvwAChge2paxitoz#6L{`F%MkXl4q2+dm()oYQ5pZk=AS`Y7@;2XMRFZo zzEs7(Z!}%Y7-{=`gO-t4=L5*3u49cTSf&lA&uG~#4es+r5aKCaq*K@a{Q1)|pl?Eu zb-nYD4!U_S`E=*it^0K;&!VM=pXnnN2al0s!zrH*s1tG;%5kRadr77_dAu54wkoci z+oZ4aD>uRfh_OJm=oo)^EsN-f!h=}L!es&F>Rd_q>zUx)O$Sj$LqDMHaG>~90v`jU zsrA za^nIBu?BxYB{j)b2YCrK{F3`w{YDet&X7D?=&6F^=u%&zBZDjljC0tVVeJKltqk4%*sr*H{N0qbRBjH{a*AS4EnP z1O3l@eQ!y8$Q>#STZ?5qfye7zfbT2I?Qo!!TMz)SYB5MwK28#kCsVGf6nrkrToRc# zE|3)MD{bbA+kb~S8kA<-dSfFfVRLunvI0K~n|`|fm{ksC{H`8E(b!f~gx$V9ETa_k`DW#ggdNc^Ifu&6_^&=+%n4Qk%MjrntFI zuW+#}>^mfRvVEMi<824G&5&3U(;gVNa@EPVxPXzftJL%P4`yyBRlF#7bn?=hf+)X_g{FG zY2dxC0%xHL$gb4b5okwnT@k@+TxHdTvHX!0wr)ea@fZa@m%<~*W$zF$0_nQV3jhir zMMbbh79y7FA_L7{8V~0&V+Y`;>JzsT`;wM*#6WE`!j0seGst&>W@$JVTEFkbFJl;F zE(XDuvW{3^dJ>j3_J$&e9ST)L3626W+LLkht08wXbwULm|L&jowh6Bp%Wvr}U*D=7 zzb9N*FCdjYh}P9dV{}d+CX?^! zpX7m_cvj~e?9-pj(~9xGiVEEI^<}{pg>N)_J$?EXHqDs#BekiiNhq=;8<0i7#u~lj zy*DS_>5Toph|AjS%RrDx-d%_tAyDqD1DP}>iiLG4a(Uazl$@Z5 z?0_#7b0ig-dY=cbByvk0A>AHNKudfAlzh^+X4x_a{ctY}7w9XANc7I~J? zoTMXL%!PZ9R%%70zGNl*UiKg+U|$;LWJ?|1T`pXQR8p-V6mVmgPe_;Wv>ThMb>uWtmfi4X4%{7OaZQ@ZEHCe?6}@G-HOf01UWD?uH%zj|$EGm6$s?AG(|O zIP`it^A6pw>KdEAS3iNWKs^7EF^ZWumvc)A`!3B{bBi*cZUUmU?9AZUq1>y zVnO)zkULpZhxe*Fbq+oD8f0BD02Z*+r;^U8g+@$KUCMS+G5If8`IXr1=#^zqhU5f$ zkaCo27CKsYEKupK0&?zhuH|65NaBN3^Miwx=hK%hvNxKPYRBv)v(6N`X^(HAemA}* z37LnJY9$ec{k;oHBBpK%SV2`h0)=0siU)BiKY$J3BV)yjn~}>WS38L&3t>>Ots@60 zmVa7|9O5L4i6EiG9oy9t5=qP`U*#vR^pV{ZAM*T0wn+pv9~1&TWP#7-2SVo70Fe1&g z%FszdQ3dhrg(Kc;G%MApvOl(T%Z56FB1HglSoBy!xBN;BwzNE zJuQ2JuhY}}@*=A685&PAz=5Ihn0rS6|C$7(wP;r0IMUIgu~iLzy(4&d|ISqyQ$Pt3 zi>pXO5I#&^8BvZYFPAMWDq3B(lKldbfdV2i(!H_^Z0z(hrb*aUeEJ>`Lo^@^V(D`6(*J+GA1lt*Hw-Ko-a-zN<_v;g_RrnFS(JA zR)EcW$Gd}|YW{F61+gR!JWYH_=H3XKAhgv@kX|=-y0-zY5;`61e7~vuSlIaSgS6|_ z(^tZ_cv?)3<1zkFXR>*rL)eoM%vAr63xkX;hqLfC6A1Pj;V{4EG8ZESPfFdh*Rd`Rk%?YS)|)d@3w4nP)5AM4c-LO0#(9X zIH$kO{rmmVZ=@$K92C&XvLQDO!3RwzjM72JQ;S{?J|TbCCgPHnOah`1i(ut)@hRa% z2-xN1CG_=zk)tUgmz3Qf6Fkv~{Ce%ijhZxqVnzoj7=#X@2`{-RARrho@KL;>B^t01 
zq$6AQOE2`}^)MITX}chh-)y3W=s93JE6e2WjD3~Mr?QfwHjN_x&cd!c#&PP>2l-IN zNr1AqVC0C1y-P-wteu?6tvparl2#GV@b1tWcyNdF0r{Z@@|gA3MDRWaLh9A;B`eZ2 ztVDRcM24XqZJ}ELdW7;1XgBV)_>K!W`Dx=tlN!Z7B#uCKHshsNpq4TYOW;_ccVynVOo)0I{H7zue@nE<@1p#=Y0yd;q(t30qjY09#09 z)zHZcI#XN8N;eABO6H)CLSvGh(U4oS9gsJ(Kxvz!Ae4!z?_bt>EqxlN& zB{%I4ciFSrR1qFch#DN8N2r$jmAwvOU@mOqZ&d>}504gMu;Ju^L}N`R8!+g}lP8mf z5sxd}+2x57zO{N5Xq;OBRf$?!J11^b!>^i0y38)=1&G9QJa9PyT# zE9gAfu^%I-ju~X&Q^tTu-y21-ICye>xduDrxoIR`kwZIY&Zc$9r3*fr=?3g6OtpAj zmqKlHeB7JYsM_XP!DSAjkiUQn!6#N9NM^~hV1wz_TQ1_Yyw-7gR4+^a(<$XfdWseO z{>oNo<2jl<4QwxsQbzaqJL7NV7hAfsX?UZYEuk9f5sZ9ko9?0U8llMwQ^6xOzucgW zJceyjC$zB=T{vJdm(%HMNq+ucx?s+w(Nnk$aIh+!af(|cugr9N*->Z! z={r;_6WT(|aR#?3Bd9wM!s^U292t-^L;rmG%3y{g%q^7UpF8kL2!`(R08Hu7T!6zl zU8L3FLhPg;#A~%FW_2=#8Uek+3@Ss?z|xvmcyka0muC6LogQv6kmU48`SIt7H; zrQ^_l8Ged&)gDA!EgAKwnD7i_`(~F>YA7)d*!sOPkVVL!U5 z{$w|B$P+W{Vdo|ks-!;&2!=8DsM3xO5N&H}tEs4@v%U3*@3SdaL7NvKmKN@1#yO;~ z!Vmr~iXFDKN7Whysi6REcoep!$#k!?n8dwRkT~8Ra~-1z3LTZ2R2RTyVjZV)hd0ub z{$q#Xh28<@zCzS%9Su*4y2!OJSv)}Jt%?H5XVvvPvk0%43$P0hOY?g|X#@1i+c6=m z;9XT?qg?XFNiVAx~eE(Zgn~{FPI}Nh&XHDN9(fxpq~#}C%-r)>YsPaz1L50U4&SN zhUAZvay4|#<%3G+L0hF?2l_>}J#|TDL-{c)3dv^1y(4*x6@;eH=?UHMA*ce&hpG1z z$Tci~%r2nFZO8d9@WP46DE7Z$9?$0Z(#U^wURt7xwh4Nd!z8EcVNM(vS0t7lPe>vd z&7PDdk8mOSl9uIQr0?05?Fs6mv9@*KeDhlDC4D4$IQjF&HH$y9qeA2o{0ISF;mmV? 
zLz==&OM0h~3tUa}dA1eni@H`O_cJ%2zNf`ZHAU&!j~9H5ddY%m0|8JYN5e<{Xp-IE ze?hdNUU<1Ze(Qd+TdL=s>1S81ZF}k7DJySl%8-niR3w^~Z?tvvF1Koit2+VVE$Pe6 z3(oOZ!Fge&mxiy~fWGhKA4iY8EbtskeHC=Oc2jK`c=J|mq2MdO@9o@y>_&&_5;Exs zgh^jKgsFp4nWYLIW5UH#y3F% zgI#)zLe~wr@57cJQt)*e)2ngWTX#n{8X(NBJ`nzj_q@;{n}oHYwVXZR|iza0TjxYs~P*Tqj7dx`a8!HnJRcAOqn6=}O>isZ*s zk@Cf?vn(CG`pvP?-Ku-a57w*rb!mEW!hxmOq3Y9UmX3y}m2M7wR8* z=>rAnj9h3$mpOf12YJG(VzSh$e|KZ@@MQ#&{3$|OBZ>8pA)i5O&f&LWZiDbw1!^&M z`&A$+F9YAq-(DGai9C>KFIvi0{F_t8GX0P%dsK^wt3`TbUmb%9Mkg{f1&iN7oMc;0 zUQtAz(^mgbp_c)^-Q3N-V3nZHL|W?^Gv1f1(>9u=5Lf>a=JA-+-6h|pi{?OvY;C#F zVH4RR!_q0^wTk%gSQ}pxhoDICl5;gUsMhHmkTMAyc+lYB;AIHCYPh)vNaU+`6+d1< zJ#MDZ#YYTW)L9iEzLoEmyEO*d*#fU9W}XGsOby|?IlX;>o^2S;qnqGQ*Hd7^WnSpX zoghCJpRmtBXXaR^(X?&L4L#!e#x9?_=6dfVQ9ehp4(*AS@X zd3+(g%WFN$E?uuhw+&sn`r?9rLep+>auTgy7mF~SiXUUtPlsAxLBt|D1%JCP7 zmAw*n!RMD}% zvLdYvu!RNED_pE(CJQAAtxfLHuC!NvP|ezCot^St*4|jfIgEfM+BX$?V$%0m_m-fh z&;V@!6~d+>jN5*-bEy9}p~pX-rl~vMhoUW~sjB};bt6xg^^knZTaLJ=w-VJ{t{EP( zV>{xN9QBgq*2%#_F1wiUloNi=3RjFF)0qCS2b-PkmT<0t1qgW*ku8?MpYeII-_xe!*Ui9g9$eP|!Q z*in4ZS&>9<6t7CiWj-+Sead<)?b%+YugJ|~?Tyusaj1nlX%it_$n$}Ap*+1W2iysN z*&?&Pf=4A?FyBN8k1&`t{Y%=ZwRk?EU*^g)-shy_{Bu94jrdY^St|-y`Ef@^CZ0ekw{mInAe(=76Pb=du{w#PmIRDBvIZ4E3>6Kuy|o>r037#v(Nz`E4eB(*%N0ejd@lawKM(HO zi^dig6JX2Mequfd;Z6P}&OccVRl1*JNpkB*Wp$qc;tTsRYIZ`#Cwdzu&e~%Q)Jl(y zQTlFIh%}TtD0)#mMNdApjW}0ibv1Q8rnR*-WAam2GR6#_mB^^YZ>cxI;J-ah92$b@ zt=)!g4wy4T_A%uwwZ5Gt<3E6Rd?12_1ucug2L1IzEjIqH&^t%HOZ>r;g}ZT&9X1pL zY+SCpbDpN;Zi-w<3tX$V9L-~ktrSyB+6Nx6Pp^Ot!zQY?z9Uo_5j03E`lmWSi8;az zT>)$_^+K=yIZhdFu2yM}|Ge9r{rn+}na}QLj@VH87yUXuc9c^l)44L_M=sw+$(%6& zt$v#SG=_K|j=^9uV1HH(%Yp!b9}sJujJ^nQi__CU@*IA|*%=#Sc>=%pTR?&Qgw|G^ zC%3g#23Xs?@KO9XkGFnHyGj^O@b78s1z&&D01TVwM*6lV@roBa@ZLSD`Rp=w&@YIz zi|xbQ#bMKT`90;o{)Xz$CtY4MbPT6E1(=kZ9piDBBmDkqE#4GgP}Wbz z1WDL%mw)F(<&ld}2u)O%Z3{;Kdr$i)SY=q!*U@K2_amJQw_*@7yn}hD4RXdVWXK9NGm9z|57&ZK%hf<0OJ6B z;qkSHh^Q0qFcgRI#u%_cR zD#@#XZE|A2_uTmv1za*YIPu$OZ*(IfvZ*SzK%h-~t|5jZx$=ux+hwZh=er3 
zi?Vbyy7wHQa>g=wB5==&){z2K{E_7^KtkmcxqeIjLr@kir(MaLIM>XXeGOonFWR8# z`qE~WfvcMX@FgeuVv>vEgc|!6FHp~i+b&ah<;_RsI`O>L2^{YD$^OKr%;NP57-PJ( zfEHnK83J2T^RO*%Qj9312#l#!udUu0XODSu8AZ%?tx87f6Gx}Mo(ZtBJT1V;rT~jY z60pP0eg-bLNOw0z`&WzNVhmxw>5T~ilK$cC=Fyra>T1XPj2Arm1>^5C=jmK}WY-m3 z7o>H9esflS(c!h8G2;CJbdp6`N2a_v6h&3NJd!4Ci576{C~i)1n`D2`!0eEkvRfi? zc3QsK@(9D2-e<#0SShjK4*cC0)o&GdSC+>%KB9&wMxO!S#EN%@RUOv5cVw9WM0fJb z#ylRu^PcmO(>FIi4|T?ks4pLj0bUHesmGXg%sL^o@8IB&_9N%EX6k-!Ul@}k_oR%IMW(=S z0M*{YeH4IiN8`(1!HOi?c13qFL~Uy|y_X zfwp3$30oU+=MAAe+azwoU)g;TUY^FvUGKYxSStxCmE0dg0fvH;bkO^Ge5TpY#{T0P zgher{)KFfrAic?l79C^50%*E}-+*K8z!fyg@k8cAo!{*ud!v$V1MLS=1=_cOWsH7k zThw!900ZoPL83$TnRy$DdC-C(BXIBWtIhn?)Hm{ zoZ%E|Q5=+?{2)3!`h1M*2u=GB-$&f<@n0h!ca!P{%92 zd9lcaPw)fRC|&Em|7Q=ig`NLA=!{3osdMdvDyK3YeH>(sPjF5jS3+M`m3Hta|EMJ_ zT+*jfi9(4ESU&)KJ^RI+;N$bRtv&WZ`Hz&s3dvX@Byx56yW35rrB&lPr@vx7?m^R^ zKf9R%iz~8Kq`%MZTwul4gdwPIjEuFb03f(3k3=tf*7%@(30as2$KT&01#sQ(gxo?Ei-HN zw+B)s@c*A&l~-iG7HBBSPXhFW5ff=Gp2)GT?2jVSf{(WGh z0+n-5#L@%EoB5h^RPQj%1XNUjt@CDLpY-_1NMcHt6;J>5tMd(;Bwr=abwB1Zoaquq zp85Mre!Hix4b0+>xcvzDPBCUHXSGus_KRsmxE-{0kid69ks*#y;nzn{lDo%OCjp8^ zQ=lv*O3JYqt2?#`RFY9fv!l7P+3GU=xdSRE3J~@`o2CCkwJTL%jWV(wkITu1dPEqK z>M4s$pS$me9;fiy3Zt*&1HZ$p9@m~Ee?zxQw#j9;QL?MRA-V^|yI9Slk!$$q)xW>S zpdpRXwzapnDvCO<#HtzmL9rE z+F1;mii4rc^kD)<$Qv+2fp}FQ@sOHc{9OP{`ow;wDIMyfXC1~_l(9-fgeu>5O%+;X z`!a7w%Dbd;{&ZKKN49-Cq#*OInw~sxnSu4!gxZpe6zwj&00{*$JmzgDG(GNic)tA2 zlKu+J@8}uJK7>Qq+$SgCf_p*hN;aTed)=O1vD$AWjFm{Xg92J+ zOVi*mHa3=f-DktBw1q1Q^Ic8GhcOb(f1LR#bu2ggi_WM1c;?&$OC^biz3N1acouF< z@b!jNUO2NN%kdFQRv5Q;k%g3*8dTF#}henWq}K_Jc1 zD$>>&dy~zjdu=`tazcWG(g|Y>{jf(E_T-J3HZsl;AwyrEJD)e4s)i+MsZdU zoT$4lB$FP#NB(z6Px)DsB;RutU@N}XPKgFt(%GL=T?LizG=lh)?|mlZ@w0;$wy)#1 z;A(ULOOlYS&3B;aCbVtyS(9+eT{zwx02h5h4#LDV-{)(9EB2NT>zm09NX^RG0iBOf0+Yp5B1PembFTIO?Uu zWy{EkO@WnheKQEQNZ^tc4|1Of@8(nIrsfS*)zd>Pr4x0rm{(c7?BUeyr7L;~dZ$N_ zdlQhlyKn?vgHV|ylBS` z>pcrLsScdZD>sjQR?v<2loy2qjP&`#ZQ6uw1_n?m^cqa7Y}gns)683Vt-+CkRxKXu 
zM<$Wr>9Sv>SpjqB6QB-?e@JubK$g!NhEh@>XZ-Fr`|aYWM2M(0Hs+Js=N^$Po>G~r zmS+r}Pvnn~-y~GpRpV&&aE$CF)``Gru?tUg-;W2fHs){n9G#p_)IcFM5*^@ARKJ5W z#=)j7(^rt7c*N@%Gpkt_6YvCdY|McU7RNBlOoLEo|9Vo?G0X(Gn_qCAJ|G3~sN*ML z@j*xm#tl_kjFhb{Qhx}C4d^9ULyUM>#=M8$pn0_ifsBQ+vip^`e4l7ch)Pn^3jH!L zynVGmcr9gUt>%D+dK=?=-+dDy{P9lI(v8!Zx8eR36i|u`U;&*V3BeQp#vkpex2FN! z3}1VVhYalIkIn!d3ZRYDa`XBp1T8h(d_Yd%Q1^M-mdeEZJ@i zgvAS9a)j^rk`piimso$w%}@0gg;3D&QUft&IFAgG2mTR+v?1RP;X%3o)5Ot-Go8o3 zW)g~ix!7stTB5Vd%*c_Kx`*!S!gU!#p=7R`nWZ%^+0cuLxa-v}Hcu|)r(;qSVb05R zR+Pluu(*4+qTx=W7u$VV_dNG|zW@Bb&*%GnK5yUO@B4W@%H>)pFi)Uok4bPL$__-N zD|O2wRzR)3<=!ZS;jng>Q>8ug%eOTWt{=CV`62qo7h~ARq@FdI(GA#}M%yJ+OuWT1 z&Bx-vH~(Zb=kB(1RM<(+5DxFpZ?A`G*$%zs_?CdQINQ;6Ag76~Ny1Z2q+_$7>S&98$>$qSF94UuEq8G4g zh_zr0QT~83UxDn7b%3$@PGTpUhN3(e`xaBOf5+Z@k%T^*OSF10uoFBPky>Q@Kc zJQRPf_TX8H@@FF!jd9g1msl;p<=&=|PL^ERW<3#pWc@_2La{n%~a9k6hAv9+gF{|o!d+0#P)Nl z7(dY+BU4Q$9#@WUb_+;C$LnigTBj|dL(yiIqL!HqT^j1EGS9wUK|xV&9v&&2e$F1G z+M7ZuN0|uw5IuzMV~yua8@!GTsS5a@g^fXO`64gQ3{S>!`PxnHm4w zCqUXYeR?Z+T8!4mx`8x6qo97~75!M)T=1?0EsFANr^ddr@;cKNJu zrik(8+5EF>LB>juAu1G!ct zI+htHN>O<{(r=jq%WM&(Kc75S0g+xnN0m>eYwt-UQ^fy3h8nm*v8rEU@`hI<8@2%e z+dX)b-P|}4e(_N%n!C-Jg&Ld%RCsTm;3&RP;o+Y4o*%$)%Z4us`!{i&C~KB31tnNe z5`=zVbPWWGz0+Rr|3FtxW7_OQp1@)gw-V(gj<)4k+v#B)>>V6@t&lxk#a;5rofcrL z3E0Zt<<~YsR*7uPvzx7L^(^>K@w!&<-K8y;gvyJ0RMu*|g(+&Bk+hr=_t5JTA3Zvc zYBie>_5)n0a^boc_JPyca(P2XVH(@jCVkH_JrfQ%;|6UI7Cz=xOM83!pDVG8(lcZ# z5~|Rue9C2E6cqX*_dDez|F#M8cH3gHm>-^{U;a#*X~+@?f(jNOFGBSg8?eIwr?OgR z-I0CV$cWeQir!zjfC04hPca#1t}+!em+5_^=~5@c+HQIF_R>dtpdsVs z;_l8Jxv)1$*_0j%{}lJ-?|1pVq}`vV_Pw*RvRYgH?IM3dqIKH&hM&qjJ|=maOzDbp z1-=JWwhfi<7<3Ve^8MO;FPHs}ATWRTH^H5CgyDNT(|#*B)jE<5NtL#bj03RxD?0Dl zDl6DxDuOWj7OnT8=(gA*2$e%ln4G(tq&AP(05WoAEg4{S6_!@u_B(;6j%FlGaek`qmrg{|Xm@0$`pFH%V@#$0kHXo> z-)q#wI?`ji<-cCWQ_7s&9Mvt6#9O|&7&P$HGD7(E-i_rRwVzBN+d z-`1^CtHl=__1gUssi~<64ek?Fj#qZdzjm4CU)@SA&JpxlT3X5;ihcRe8S&y)C#3Ly cPe@<0fm>LTQR`6+eqnltf1qFG7tvY&0&zqr0{{R3 literal 0 HcmV?d00001 diff --git a/doc/source/tune/_tutorials/overview.rst 
b/doc/source/tune/_tutorials/overview.rst index c03593114..450b8d6d8 100644 --- a/doc/source/tune/_tutorials/overview.rst +++ b/doc/source/tune/_tutorials/overview.rst @@ -70,6 +70,11 @@ These pages will demonstrate the various features and configurations of Tune. :figure: /images/pytorch_lightning_small.png :description: :doc:`Tuning PyTorch Lightning modules <tune-pytorch-lightning>` +.. customgalleryitem:: + :tooltip: How to use Tune with PyTorch + :figure: /images/pytorch_logo.png + :description: :doc:`How to use Tune with PyTorch <tune-pytorch-cifar>` + .. customgalleryitem:: :tooltip: Tuning XGBoost parameters. :figure: /images/xgboost_logo.png @@ -87,6 +92,7 @@ These pages will demonstrate the various features and configurations of Tune. tune-advanced-tutorial.rst tune-distributed.rst tune-sklearn.rst + tune-pytorch-cifar.rst tune-pytorch-lightning.rst tune-xgboost.rst diff --git a/doc/source/tune/_tutorials/tune-pytorch-cifar.rst b/doc/source/tune/_tutorials/tune-pytorch-cifar.rst new file mode 100644 index 000000000..6d1b0363c --- /dev/null +++ b/doc/source/tune/_tutorials/tune-pytorch-cifar.rst @@ -0,0 +1,278 @@ +.. _tune-pytorch-cifar: + +How to use Tune with PyTorch +============================ + +In this walkthrough, we will show you how to integrate Tune into your PyTorch +training workflow. We will follow `this tutorial from the PyTorch documentation +<https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ for training +a CIFAR10 image classifier. + +.. image:: /images/pytorch_logo.png + :align: center + +Hyperparameter tuning can make the difference between an average model and a highly +accurate one. Often simple things like choosing a different learning rate or changing +a network layer size can have a dramatic impact on your model performance. Fortunately, +Tune makes exploring these optimal parameter combinations easy - and works nicely +together with PyTorch. + +As you will see, we only need to add some slight modifications. In particular, we +need to + +1. wrap data loading and training in functions, +2. 
make some network parameters configurable, +3. add checkpointing (optional), +4. and define the search space for the model tuning. + +.. note:: + + To run this example, you will need to install the following: + + .. code-block:: bash + + $ pip install ray torch torchvision + +.. contents:: + :local: + :backlinks: none + +Setup / Imports +--------------- +Let's start with the imports: + +.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py + :language: python + :start-after: __import_begin__ + :end-before: __import_end__ + +Most of the imports are needed for building the PyTorch model. Only the last three +imports are for Ray Tune. + +Data loaders +------------ +We wrap the data loaders in their own function and pass a global data directory. +This way we can share a data directory between different trials. + +.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py + :language: python + :start-after: __load_data_begin__ + :end-before: __load_data_end__ + +Configurable neural network +--------------------------- +We can only tune those parameters that are configurable. In this example, we can specify +the layer sizes of the fully connected layers: + +.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py + :language: python + :start-after: __net_begin__ + :end-before: __net_end__ + +The train function +------------------ +Now it gets interesting, because we introduce some changes to the example `from the PyTorch +documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_. + +We wrap the training script in a function ``train_cifar(config, checkpoint=None, data_dir=None)``. As you +can guess, the ``config`` parameter will receive the hyperparameters we would like to +train with. The ``checkpoint`` parameter is used to restore checkpoints, and ``data_dir`` is the directory the data is loaded from. + +.. code-block:: python + + net = Net(config["l1"], config["l2"]) + + if checkpoint: + net.load_state_dict(torch.load(checkpoint)) + +The learning rate of the optimizer is made configurable, too: + +.. 
code-block:: python + + optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9) + +We also split the training data into a training and validation subset. We thus train on +80% of the data and calculate the validation loss on the remaining 20%. The batch sizes +with which we iterate through the training and validation sets are configurable as well. + +Adding (multi) GPU support with DataParallel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Image classification benefits greatly from GPUs. Luckily, we can continue to use +PyTorch's abstractions in Ray Tune. Thus, we can wrap our model in ``nn.DataParallel`` +to support data parallel training on multiple GPUs: + +.. code-block:: python + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda:0" + if torch.cuda.device_count() > 1: + net = nn.DataParallel(net) + net.to(device) + +By using a ``device`` variable we make sure that training also works when we have +no GPUs available. PyTorch requires us to send our data to the GPU memory explicitly, +like this: + +.. code-block:: python + + for i, data in enumerate(trainloader, 0): + inputs, labels = data + inputs, labels = inputs.to(device), labels.to(device) + +The code now supports training on CPUs, on a single GPU, and on multiple GPUs. Notably, Ray +also supports :doc:`fractional GPUs ` +so we can share GPUs among trials, as long as the model still fits in the GPU memory. We'll come back +to that later. + +Communicating with Ray Tune +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The most interesting part is the communication with Tune: + +.. code-block:: python + + checkpoint_dir = tune.make_checkpoint_dir(epoch) + path = os.path.join(checkpoint_dir, "checkpoint") + torch.save((net.state_dict(), optimizer.state_dict()), path) + tune.save_checkpoint(path) + + tune.report(loss=(val_loss / val_steps), accuracy=correct / total) + +Here we first save a checkpoint and then report some metrics back to Tune. 
Specifically, +we send the validation loss and accuracy back to Tune. Tune can then use these metrics +to decide which hyperparameter configuration leads to the best results. These metrics +can also be used to stop poorly performing trials early in order to avoid wasting +resources on those trials. + +Saving the checkpoint is optional; however, it is necessary if we want to use advanced +schedulers like `Population Based Training `_. +Also, by saving the checkpoint we can later load the trained models and validate them +on a test set. + +Full training function +~~~~~~~~~~~~~~~~~~~~~~ + +The full code example looks like this: + +.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py + :language: python + :start-after: __train_begin__ + :end-before: __train_end__ + :emphasize-lines: 2,4-9,12,14-18,28,33,43,70,81-84,86 + +As you can see, most of the code is adapted directly from the example. + +Test set accuracy +----------------- +Commonly the performance of a machine learning model is tested on a hold-out test +set with data that has not been used for training the model. We also wrap this in a +function: + +.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py + :language: python + :start-after: __test_acc_begin__ + :end-before: __test_acc_end__ + +As you can see, the function also expects a ``device`` parameter, so we can do the +test set validation on a GPU. + +Configuring the search space +---------------------------- +Lastly, we need to define Tune's search space. Here is an example: + +.. code-block:: python + + config = { + "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)), + "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)), + "lr": tune.loguniform(1e-4, 1e-1), + "batch_size": tune.choice([2, 4, 8, 16]), + "data_dir": data_dir + } + +The ``tune.sample_from()`` function makes it possible to define your own sample +methods to obtain hyperparameters. 
In this example, the ``l1`` and ``l2`` parameters +should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256. +The ``lr`` (learning rate) is sampled log-uniformly between 0.0001 and 0.1. Lastly, +the batch size is a choice between 2, 4, 8, and 16. + +At each trial, Tune will now randomly sample a combination of parameters from these +search spaces. It will then train a number of models in parallel and find the best +performing one among these. We also use the ``ASHAScheduler`` which will terminate poorly +performing trials early. + +We wrap the ``train_cifar`` function with ``functools.partial`` to set the constant +``data_dir`` parameter. We can also tell Ray Tune what resources should be +available for each trial: + +.. code-block:: python + + gpus_per_trial = 2 + # ... + result = tune.run( + partial(train_cifar, data_dir=data_dir), + resources_per_trial={"cpu": 8, "gpu": gpus_per_trial}, + config=config, + num_samples=num_samples, + scheduler=scheduler, + progress_reporter=reporter, + checkpoint_at_end=True) + +You can specify the number of CPUs, which are then available e.g. +to increase the ``num_workers`` of the PyTorch ``DataLoader`` instances. The selected +number of GPUs are made visible to PyTorch in each trial. Trials do not have access to +GPUs that haven't been requested for them - so you don't have to worry about two trials +using the same set of resources. + +Here we can also specify fractional GPUs, so something like ``gpus_per_trial=0.5`` is +completely valid. The trials will then share GPUs among each other. +You just have to make sure that the models still fit in the GPU memory. + +After training the models, we will find the best performing one and load the trained +network from the checkpoint file. We then obtain the test set accuracy and report +everything by printing. + +The full main function looks like this: + +.. 
literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py + :language: python + :start-after: __main_begin__ + :end-before: __main_end__ + +If you run the code, an example output could look like this: + +.. code-block:: + :emphasize-lines: 7 + + Number of trials: 10 (10 TERMINATED) + +-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+ + | Trial name | status | loc | l1 | l2 | lr | batch_size | loss | accuracy | training_iteration | + |-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------| + | train_cifar_87d1f_00000 | TERMINATED | | 64 | 4 | 0.00011629 | 2 | 1.87273 | 0.244 | 2 | + | train_cifar_87d1f_00001 | TERMINATED | | 32 | 64 | 0.000339763 | 8 | 1.23603 | 0.567 | 8 | + | train_cifar_87d1f_00002 | TERMINATED | | 8 | 16 | 0.00276249 | 16 | 1.1815 | 0.5836 | 10 | + | train_cifar_87d1f_00003 | TERMINATED | | 4 | 64 | 0.000648721 | 4 | 1.31131 | 0.5224 | 8 | + | train_cifar_87d1f_00004 | TERMINATED | | 32 | 16 | 0.000340753 | 8 | 1.26454 | 0.5444 | 8 | + | train_cifar_87d1f_00005 | TERMINATED | | 8 | 4 | 0.000699775 | 8 | 1.99594 | 0.1983 | 2 | + | train_cifar_87d1f_00006 | TERMINATED | | 256 | 8 | 0.0839654 | 16 | 2.3119 | 0.0993 | 1 | + | train_cifar_87d1f_00007 | TERMINATED | | 16 | 128 | 0.0758154 | 16 | 2.33575 | 0.1327 | 1 | + | train_cifar_87d1f_00008 | TERMINATED | | 16 | 8 | 0.0763312 | 16 | 2.31129 | 0.1042 | 4 | + | train_cifar_87d1f_00009 | TERMINATED | | 128 | 16 | 0.000124903 | 4 | 2.26917 | 0.1945 | 1 | + +-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+ + + + Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.0027624906698231976, 'batch_size': 16, 'data_dir': '...'} + Best trial final validation loss: 1.1815014744281769 + Best trial final validation accuracy: 0.5836 + Best trial test set 
# flake8: noqa
# yapf: disable

# __import_begin__
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
# __import_end__


# __load_data_begin__
def load_data(data_dir="./data"):
    """Return the CIFAR10 ``(trainset, testset)`` pair stored under data_dir.

    Both splits are requested with ``download=True`` and share one transform:
    conversion to a tensor followed by per-channel normalization with mean
    0.5 and standard deviation 0.5.

    Args:
        data_dir: Root directory handed to ``torchvision.datasets.CIFAR10``.

    Returns:
        Tuple ``(trainset, testset)`` of torchvision datasets.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
# __load_data_end__


# __net_begin__
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Args:
        l1: Output width of the first fully connected layer.
        l2: Output width of the second fully connected layer.
    """

    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels of 5x5 is what the conv stack yields for 32x32 inputs.
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class scores."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. `view(-1, 16 * 5 * 5)` would silently fold a
        # wrongly sized input into a different batch size; flattening from
        # dim 1 keeps the batch dimension and lets fc1 raise on a mismatch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# __net_end__


# __train_begin__
def train_cifar(config, checkpoint=None, data_dir=None):
    """Train ``Net`` on CIFAR10 as a single Tune trial.

    Every epoch trains on 80% of the training set, evaluates on the other
    20%, writes a checkpoint and reports validation loss/accuracy to Tune.

    Args:
        config: Trial hyperparameters; reads the keys ``l1``, ``l2``, ``lr``
            and ``batch_size``.
        checkpoint: Optional path to a file containing a saved
            ``(model_state, optimizer_state)`` tuple to resume from.
        data_dir: Directory with the CIFAR10 data. ``None`` falls back to
            the default directory of ``load_data``.
    """
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint:
        print("loading checkpoint {}".format(checkpoint))
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # Passing None through would hand `root=None` to the dataset, so only
    # forward data_dir when the caller actually supplied one.
    if data_dir is None:
        trainset, _ = load_data()
    else:
        trainset, _ = load_data(data_dir)

    train_size = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [train_size, len(trainset) - train_size])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)
    # Order is irrelevant for aggregate validation metrics: no shuffling.
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=False,
        num_workers=8)

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                # Reset the sum AND the step count together; resetting only
                # the sum made every later average too small.
                running_loss = 0.0
                epoch_steps = 0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # .item() keeps the reported metric a plain Python float,
                # matching how the training loss is accumulated above.
                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        checkpoint_dir = tune.make_checkpoint_dir(epoch)
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save((net.state_dict(), optimizer.state_dict()), path)
        tune.save_checkpoint(path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
# __train_end__


# __test_acc_begin__
def test_accuracy(net, device="cpu", data_dir="./data"):
    """Return the fraction of CIFAR10 test images that ``net`` classifies
    correctly.

    Args:
        net: Model to evaluate; called on batches of test images.
        device: Device the image and label batches are moved to.
        data_dir: Directory with the CIFAR10 data.

    Returns:
        Accuracy on the test split as a float in ``[0, 1]``.
    """
    _, testset = load_data(data_dir)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total
# __test_acc_end__


# __main_begin__
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    """Run the hyperparameter search and evaluate the best trial.

    Args:
        num_samples: Number of trials Tune samples from the search space.
        max_num_epochs: ``max_t`` of the ASHA scheduler.
        gpus_per_trial: GPUs requested for each trial.
    """
    data_dir = os.path.abspath("./data")
    # Load once before the search starts; the result is not needed here.
    load_data(data_dir)
    config = {
        # Layer widths are powers of two from 2**2 up to 2**8.
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16])
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["loss", "accuracy", "training_iteration"])
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        checkpoint_at_end=True)

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        # Mirror the DataParallel wrapping of train_cifar so the state-dict
        # keys line up. NOTE(review): assumes each trial saw exactly
        # gpus_per_trial devices -- confirm.
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    # map_location lets a checkpoint written on a GPU be evaluated on a
    # driver that only has a CPU.
    model_state, optimizer_state = torch.load(
        best_trial.checkpoint.value, map_location=device)
    best_trained_model.load_state_dict(model_state)

    # Reuse the absolute data directory prepared above.
    test_acc = test_accuracy(best_trained_model, device, data_dir)
    print("Best trial test set accuracy: {}".format(test_acc))
# __main_end__


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        main(num_samples=1, max_num_epochs=1, gpus_per_trial=0)
    else:
        # Change this to activate training on GPUs
        main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
def tune_mnist_asha(): scheduler = ASHAScheduler( metric="loss", mode="min", - max_t=10, + max_t=max_num_epochs, grace_period=1, reduction_factor=2) reporter = CLIReporter( @@ -208,7 +208,7 @@ def tune_mnist_asha(): train_mnist_tune, resources_per_trial={"cpu": 1}, config=config, - num_samples=10, + num_samples=num_samples, scheduler=scheduler, progress_reporter=reporter) shutil.rmtree(data_dir) @@ -250,5 +250,15 @@ def tune_mnist_pbt(): if __name__ == "__main__": - # tune_mnist_asha() # ASHA scheduler - tune_mnist_pbt() # population based training + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + + if args.smoke_test: + tune_mnist_asha(1, 1) + else: + tune_mnist_asha() # ASHA scheduler + tune_mnist_pbt() # population based training