From 48cdca843f2d06aeb356db27d48e01c1153eba22 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 1 Mar 2020 21:22:48 -0800 Subject: [PATCH] [raysgd] Custom training operator (#7211) --- doc/source/raysgd/raysgd-custom.jpg | Bin 0 -> 30790 bytes doc/source/raysgd/raysgd_pytorch.rst | 339 +++++++++-------- doc/source/raysgd/raysgd_ref.rst | 8 + python/ray/util/sgd/pytorch/__init__.py | 5 +- python/ray/util/sgd/pytorch/constants.py | 7 + .../sgd/pytorch/distributed_pytorch_runner.py | 15 +- .../pytorch/examples/cifar_pytorch_example.py | 20 +- python/ray/util/sgd/pytorch/examples/dcgan.py | 289 ++++++++------- python/ray/util/sgd/pytorch/pytorch_runner.py | 81 +++-- .../ray/util/sgd/pytorch/pytorch_trainer.py | 145 +++++--- .../ray/util/sgd/pytorch/training_operator.py | 343 ++++++++++++++++++ python/ray/util/sgd/pytorch/utils.py | 229 ------------ python/ray/util/sgd/tests/test_pytorch.py | 163 ++++++--- .../ray/util/sgd/tests/test_pytorch_runner.py | 63 ++-- python/ray/util/sgd/utils.py | 8 + 15 files changed, 1013 insertions(+), 702 deletions(-) create mode 100644 doc/source/raysgd/raysgd-custom.jpg create mode 100644 python/ray/util/sgd/pytorch/constants.py create mode 100644 python/ray/util/sgd/pytorch/training_operator.py delete mode 100644 python/ray/util/sgd/pytorch/utils.py diff --git a/doc/source/raysgd/raysgd-custom.jpg b/doc/source/raysgd/raysgd-custom.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0145b099db7204a18bd747f0571ef398b78d4928 GIT binary patch literal 30790 zcmeFZ1yG#Lwl+FQAOs8U5F|JuxO?yb0fGleaG2n386>zvfH1f_1a}DTE`!_P?hJnC z+xwiozx|#1&pG$jU$^dm>kQSy)H~hp+iUgewR*`jPqR;}fR_p~@-hGfL;wH*{s(wk z07wCzqoAOoJbR9cii(Ez`~?O+CI&h>1_>@6Ha-O@6(t2JIr%HPHw>?6SZT@08Tpu4 z-*Rwsb5k)0i12d?zv1HM{NpAFXlQ5{=orM9n8cjaB~A^hzH__r4VA`-8_J)5C zK*o87OU)sHf~WcgmBt>Q(>EsbIjv+xJAvBx2_2V_gCE)pLLy=k(%1A1j7-ejJiL7T z0)kTSq-A8~*yMrn3|beSXwzcIlH*JxqJA34+snj4hfBo`w^dz_%jKd zm7SBDmtRm=R9RJBQ(ITx(Ad%0)!ozE*FP{ZIW;{4nVp+oTZe9JZf)=E?wy{UUtC^Y 
z!)|W>pbG(j^mnr0?|&!kFLdF+=|V(CMnXpYgDwO_S9n9hL4HQffr2Zciu%PKkA~Ct zIlg2}W<@(1EtlE}fsw=b3qm^Xwb!SAkoGsq{(FS^{Xe4YAB6pru6Y0k5(0ejkZ=HE zfUDyDyPxIKEsceASkL?Ev+tD4p8&$B59dz+NEYaZ*yt8}18C;M(K=cD1Qbw$vzf}E#U+97_)p{{Hbx|KLKI_*)twxlAzeQ zQfurlR`%83JOQe@-dp-f^|?v&@1TD+uMz#Xon_g|SqDD{8H4t4ibQWw`__ zx7f-Xvd{_pm zlNGD-z0@B@%Gu`+);-;Mo-2&ny?p|35R<3=hLR()m^$4h&OBlhRsGCui8^+~=`xY| zxzvkF+m;LP00j!$X1cYM#mP6gg2#M?IlIOi&H4F1^RGgr4b0b)<+EVNP@UoC5afiJ zCOSsRiZkd4w$uo4k`XCtrj-GvB0jx2$zDT5^b&eV`4O9@JuP`o0;yDQrWv$Wj$Pw3 zRr5N6eciG&=cBnij<~=xiM{8&Wil+?>}>2_LZxn#%F{MAcc4R|LQ_97^w&u|#2akB z&4ulq#^od;-G-2L!xun9^aOp;$4%DEq~Vfa$~D=6+4*L|501VM^MK1^OxS+sBTM9A zb<<)?{=Jv1AaNSUo|JK<;EYCj_X_f|RbhU2@%@DSrYu*-K6+b*0|)G)42uQ*m=f01 z1KtWUb+j~g&a6$xsKX&v8rontKFGnFH5LKV->G`ruF0gHT#a_>f?Pc6|m%W)F7|d&IEadZtF)M&k2{GgSF-!*73=)unna?l}CM}4YNhF z5PSjz@>sO?x);eJPZGC3EYG~j;U9>J5(m;-ZoxcH&CYcptjA>+cMosK+}PPNI1gh8 zkeGVk|6n+!gpHhn(#!6}PBu>6r0xiryd=oN7_&CB=R1d-U+``(dN>XPJc5EAsbRKV z^VRbPNgjvpr0#wKpWVXpuM_P^$V1yF?m+~tZDlH@hRq(Iz#Y4!RyP9g;_*=j5$@kR zCh)I9bIoC?8y^5p)@L z)Y5d7UAOIWm#*#+aC%wM+8Jc~+6}d#%m7pD!=y0w$!HgzWuhygF3)tQ!SCpV*te6W zhPDR0PXKmZz)`hWaOpaG7}Z>(^W(#FOC_OBI!Zm+w_X$rK3{wn@;zp1VNI)pr3x&J zHB4gOd@%%_dUDjRBX7yWosTE#n=qB3Ie|LcUGU?As=Q70&c!IrNWawE5dswL4dYn%uVbji^Y@x{^{_g}HQ5&l zs!VkuUCIWEr323yGE~noE<_wBYF5_U{EcVGl&rP3b(#ay`$4bdA13K}9|d|_92^Ge zV5tR}y9P$uIR2(}x)B{^HK2?&!F^MwTLHqYKl=Z2dJ zh}yjiV_{e3N?mGUMA|95SIS**%Y$?}jC86u`YcKlCCD-InH=qi?iPAUmrAUyU>ojT zzt9zud!LM|CDxI>n1bC?37uhcdII3>ijpNNH-j$^q)KlB6kdJ!T`cNAd?<_%X;t3M zZ3eS166`T~S&i2ibtEfn^8|qHfdLJ#uUfKUYW-S=2gjv8dr-FSttDW{3ALidD(hg$iUlLZ< zW|Hv7n;lTonfD~9+M+MG5vt1ElHZ@))G0!ej^6w%ZK0of)_@1oxz%FpQ3JRnuHvZ#%4WK4T3p-j>&w-a@mNq z6JbP};!Y>ox$VhHc*!%x$Jt}%YjbGge~1%VN_j~ZM4N9_$lhg(+38GY?dW*rwsl6= zN*J5`F^+%UY~hzcOS$g z1=PfsB;r0lkJ$V`{>Br$%>xS7Ot@VMcvEsb)JUI?2^H+LF>y~t{zCAP$-6x%aExnz zqF`5@-B(!4lOgzB2l^6c4?nvaYCL*78C^_lMzp+;&vCl0cU6BOY_eT9cuY^xr>u~s zOZ}O~cTfG_sud|FtLLNY;uK1GEw<}RWU5h;E#QeqjE-gRhDbe2`}h^I$W$Q*E=@0% 
zvM3yw)u=j$JC|H8L)fGHKu&S_>Hz3P@5uqK>?uo6Y_)R5K<%?ng$R8izU{~JNhjTk zFq_*;5On!Yv$>(}2~cLuW!+oR16Os2j9H<8|El#v+NPu)Paic3scv(g08W|)Est4L z*HLGsa!-J1hPx~u3TgZg2!CoQ+)seBx7`<1eKW(xDo+4f3`2?gh_7S+-sc&nxl*FW zBzS^C;atV#^S5z}0Xkw4gwL{K3R`V!NK;}vo~L3Dk3mm>2#BRLfCoQ|Ie}Fi$$rm! zA4&TpAhKrNo5&!|^LxXD;!icsCMr9?QQxRdmzI|az(zBWM>MavZ{oeDv?2}RjMnG2R-x^|GzS(?aqO(aNrydE^;(n*11%ddr;|c|) zXwSfk(yzLT>JTINvCSxyXQwC|pz-%n|33PE^&g4B;0d6Leb;-VnFHF3dl=X8xV{Yr zEsG7*&^`e^rXPg@obln4{8J;Id;;`!5{-dq^>`x0?!!5}QO|g){@%qD_>FGDO8}qK z3IZ~h578b?IF1#37z1iHmqTsu{hgw24;jt~YF=37N~2B7(O^>Y-m$#{jm*=A4Pk?~*)#FR?ZQeCluPH2*Kyxh+jKH`Ua(zd>bKI0rK1+8O5v?ruDZ4 z`OB{T>w=`>gUr?7gp^0O{`Xe!p98$V<(ocujb{I=i2g?g3Nrqi9(wiv2{rT|wvgn? zy6pF|GNZa6%_Y_{9HN7>@CtsU__Cr_P59=%4+k4Qe3O0xWQN}6)!Vzdja$kVE(DY_ z7dAM-BL)pv+5dY@fxBVP$_>gc)#bd(2Fo17ijYJGk&XxDc4697dv zd|aB~IE`e`~e!=P>*k@!oZJQs7s#G23yGHEwxy=a{Xu%j_Ml zu|u?-M6bajd@q^6JiBi51K;GNzfWNRXPU_B53D-%U0oK`41-ipDYI;AnC|;s0zRM+ z5RG9?k7Q~sI2tY;*la4!IV>!5!4D@pZ@H^6H7-Ov2Gzt>Vi?U-@8#US!2D#P`budr zSiVpt^$(_6zb1_NYBwI{_@I3x*Q{;V{bjg~aNSFwTAPrks+tw*RAy~5p>++0IT zLF9fPV62AOBDez=$F^gheC~u69+{)~?1dFQolk>n3t`iZLS^O0ZytyE2!)#4ej=h4 zjwrhkP%NiXCT76Wl|z!U>ZVsi6rGuce{&s&j@lAvk8m%9wU7^k^4rOA#OCD!{JbZ| z4@Xytkv{BEGV}Ep7gXUy$Xvb|J$BrTUpl+k7HXWaA|QjgBv~wUraqGd^#yGoii9^sMz~ z6-}2ptBDT(*IES3{tZ_VA+?Rd+?nFh!Ta7(6@%q8(1wMTC|T_j;JfX|=45eFrS9m6 z_3mTv1mc>U6J1REzOPpKMH4%&TaZ-s9kn0vg?5?g*_IHPgHrRYp6{%UXamm>s6#ey zepfdkHQ!A3qR-M+jg$HjyUNRC2}l@(1YIIi)g+pW^s^>|flIj1XsM1Q^{ zF7P7AVqQc=eWGYbcW|e!QgP#T7hcEA^-F^3UD4F27WP35_lFpZgJ4{`l`8y) zYN>_j_b9sI&qu?+`Yy(ixOA>b&p+)X3Oed`Q>A}nSCkIW5Yll$Mgn-P zseY>?i($4TPCPji+~8q;@dLg51v|6 zHo5~`6ov84Kj>odgiX?5WlhRWP4JDZL1_xm5NG+aN>~Pa4N zj^m8l_?$Y6BjH4)*~OgIO?klW-P9R>Q!adm1+;NCOlsQZw&DGnZgVy--|L%gO_r`O zw}){4(7|2@jLtNK$k)S+kp{ZK3z{5~dN9?j#owN$68Ba-r%shXEl8q^{jjCy^SXp# z?X2KH-e)4Aho7iyJ&KEgcKW+B^Jb|gUoT)3xwB_CQFP$|)?wdS0D$0a{%aG6ZJ%Qy zkxM#gXzHY7#?DQzYr;_c`sK6Qhl$qV?IC{TE0X;~sa-+kOknn2?ToAD41~|x($Vs3 
zkNIN8OXH>fJ`pTCZy=WxuL{q>uPs)({77&+9}w_8BQYqUzeCy)t8~L*zn#R-*#lhRT=GYxz=Z1+698JWLpR5H= z{DcMyZ)L~A{f(LBl2#88q&$nl+&#mPLrtv=EB@bWz$#YZ~o~RQv_h`bc*$FiId; zW)VkuyxP2tZt5j0x^rDgMf;AC9`*@|DCt^wZ0U4_&%;r3jsB?oIyRpmyKd1*^BaL5 zL}rq~nuJH*_6reiKYd2+`fIv1RaZo16bA=^oSKueB$SF$9hp~CN2T-Z6}rPN)VSH0 zT}Q|^rhUZ7t|fBK()piuuQYFGq9K*q`H2`6pMEOJ#!89}I3O^@_$x(q@#2kw97ahz`-C-@hhKr%mgJf}%QGVg$T-!3G7GT|L*o7Y7l@TC9V_;{xgmEw^YelTSVI#=7iXOb$XvxkyiY@^%Ef4CgusCpYtf$Xb0*I84Y{H zh$6`fWmyq^o7Ez=dKI5zMdT4)H`ZzI`do!x;Ky4HfkA*r%Bp!CRA+_IX^BPYlnyNuhMuAF9^8sq{SWlvq^L zZkJu$y1OTCB*oo~2psI(dw^*<`NI}hXT15uXUqEcBZAs#XlW!Lb7Cqq)u(XU@dqw| z=zU+CsjZ8i0KnE{&w^!T2;Yafo3bvP?;|qf{rt!2Ry73OC4z^Kxt0i)Dna12G*2-m z(;siUGO>yIQ6tfYDVO#$9zu`Yy>)79LOneT=^7&m1CZ#xVwu7KYp+<-VwwFDtfa*PUj@dpM!NwB2tX)D2NzZK8X_4HC8t!%R#5gEr?;Z`;&)jFN=K|D~OaAc_Z1+N3i&&wmd zErBwJ`6KsMsyS^%hki@?TI)~tU3xE}=X4WShb=_5kiKu-uZxdw$zK!fCla}!ttY4A z>PxNzpPBpL5mK--<0D1nkRK5*wUHbf80P6Ix|Znb>2<{ovhIYCr?A9!BD6$#G&;M- z&sX5(4jiNoIB6A2xn6BUs`pDR*}Gi$SiiCzgf1*pGPTQmHMabXjO7hPxz=hsSx#Tu z_a-j+u&DETXu%lsDoJOSv!Qng1s8SgUzWq|0rt%7T^utPbq+tluXs1JThJHbtW~QFC{< z@Q`U!LKDko_`Y`WSKuqwT)m@&cGS#%{(Z%T6%#RIoBZUNpsV-kCy}Yf!-lN-VTSDF z+hzw3{H3ELue;zUlZ@MdvRqrvesZnx4x`+9&kb5DGcp(PcjdHpZZPx+eY z+s~e5Q)+35g!FJ-UKTl%g#sIx!D^OQ{EaLVxMX2>EqFRO+Pwwxc}kA%JX1RFKANct zNbk+Ilar^|D@~PuTUhmMIzCKt)=OsDAtfth^QyuT=u=ac+p!bxB@kWtSo&Z+`>j`T zam`tu6Dm7pQqefK-Gl#2Dw>b+%jnqTg%aCki`Ar3q;$j))n^wH#1=e_Mz9JoEc9?- z)IQI$m<=rF`~9n?a6R!hS;$tE7hAE_{7g&6qSf25??E5VW*!o+zyBCbP_bW0ix-_& zPCcBwXfPOHb~%4g5i^}XXbr-TUz8|W1};;cvXc7Cvht4Q%!~(5zr)-rxOpq8>L=~Q zr^PGCTmB_yVQ)g(tU_F@Dtl6RstF?_wpUExYnpwpO}uiIuAbrk;E+P7$BN4S(3us> zIIC?*b+O)^_8Z%aXVwaaG_biu+*iMG#7+@gevcYRJ5sXOHq@SyZ|xrhY<)~J7*^=n z8+iU1`~9WpO>uyMrEZr(I;*YYi%`bW_rp}!aPPf87hZ(3VR%biUIg5<{p*$fp;pGU zJj0VdxKn+uBL2TQu5X}qTK6c?DEr_Nck|*>;a_aFGB)f~;1tvz1Lang_S_DX$2eOU zcF*yO|UzFS=EtZCN*8Iy!g2M&FQ3GKWj}$|J=Ie{hhc zRCM>SoNco%RrJE4P?O3x*V)sVBDH+XFO)Hl#*<|sLhy=gP+qtxSXNnaz(k>6efci(?CVth_>mVIGo=eIW2R371#qQM#j|l+!)GxYK_rUP1eUKW`Wd 
zjJ`Kc*(P9`^w@O(7gSH_b9D(O#2_bJ#WBond_Ke&gP&;an}B`@*LkDmdx@L|{XEbn zgN#TA@tdia;=8|zDL5wL)W@9JD2zN8tYZeFW9hC6{+2$ORS_O9udJC^cd zJ~n;!A&^z;T17Zg5>kg%UlMGe>{*nD^Gwy~Ik!-;ep4IV$OjjOTjLtr5nmV4>Q14A zIz zsaMjHCF#!}f(7QR2&fDhYAQ2{M-y8Uoexm2#>L*6sW&C7jeg5&F-x+WL-nX08|03Y z(Sqt1OKR70O?OJ)w@4H?@v7gN(jvcoZ^sAyYy;s%*4G9OtsNz{0 z!1B}*hYD?FMUgzr)7G1yUOmP~qw@LXK&Z`!X*3znq&m`wrRn(6l!>FdYOGgE8DAF{ zk~5CHt;`@i48N)>fwe%O07k-&xTv1FFuD23t_mf^bYubeb7k~$CF+?3S-jkcUV37r zdk>SXut2f4iZy0;9_enact1^}+*}V56H7|Wa#*J|;^c~2&V0bsZQLH``<}y1E)v-R z*KGl4KKG%Ddh-kqn(grxm%d`e837VM01JSWL6B{~`tWUH1Xy^nEtf%-OE%WXejSQb zkL_g?bDy;~RZ#Q@()tsQTWy1YK z?3eRbkuqTn491+X&9|iGw&i;XGj+A&OLzpQ^lK&>-rMQPtZ1lleeH$p&0HmxwIQy~ z=C!EN@5(A^{Wi>P2ULXBkG+^mcdu-n9gX_&Ka+6KaLHgR?V*|8i|5{zv&r6#il8I) zJW}J*6@MdiniszG<6uD;<2qdIjQpN*T>`sYQeflVS61}BQY;kiH{vLK)*>T2;6}lr zPKCbDl9Q@hqVx5;ks(j2IJ3t~uDA;d7OD0$$|$|Y7Vy?m8C%I=njw?9c6ii1OWYWh z!EDYFV|R6ezd}wB_#k2Fi=UqIxR6fb}e)X(Xfq#-Dfg z>^}NOOI!D@MQCXxct!}lKP&axisbS9j`ZrG+lL=d$_G2aOhb%mBtO(up~PS4*bwzOweq(b%aa-Q*H9|*vgo^S1+-2 z2>}9~vKVKpCFRtl@F;W95kEd~d`y1dtfY6xmPPgT+$){kBR!qJ$1$??qQX*Za`}_4 ztWA%nsY?1e9+*MK`Uo<)H#HnVA(t^F6k3NPwlCaGLksnP_6*lA)GZL_HlOmAdYiKSfH1SsF~7&@ zi0OGQfXwWJuwAgf!hI;#(Ku6Rk*JvkDG+&)0{Um+T6DbSzzUQo`ZT-UMm(gDsPyOEy!#*0$+JSk*Ihr zO-~JK*fw)V=cg&h!}Ll%WYK-~hz>>ExLq&Y1M_BPEz$`FGz2UmWiFEFdrL9BipnhF zo_aMsA?W}fuhsGY&{!L|M{LSMF^`<`nPTtxS5F62g~^)Jz$Q~mf@Z!9R+dU43)P$u zZ=62X>>alfLFN>i>vx6yUx)RDy1-?L7{BxuS)^EQBFIna1`O|J?Zibd6>cB*)th6_ zu3W(DQVA}5p4b_^R^`$!IFxs-u#}2<;g0`@i2tKb6CSc1AUsP_{P}>^eC7q7kxud> zQS~|jHBaYZRao{q1y>j8Lz87By`q~}T?z)O00Yt#wxQ!EHUhuSZjGxJrBru8gTV|{ zX@Z83@-^U&(+Yv5Ya(VXahYn4?NOP#Sb$Ts8-sRGcg(UpJ}@_THf80ef8$2t+6FUS zdAdVw14_cbA39=Zy~HIq{Msx^vZ^6nr~VY97HnNslfP(1?=m^asf@r@`>_Y>t)0i^ zt}fv(%PZ2uqjb`y#bt{Q?q5w!{UO(g*3G|VL_fUczbI!#_K&u+C|N>37}Pi%J955A zRzC668P3w%MHiWG)NARA;v!H(5BR!fXP?H^t$JdUQg=1qJLv38XP2qZy9{Do^1I(| zK!*T4v>4?X;=l;6Hteh(;x_FYD90TQdOg<*p0)X0%)&FmYm-atn@eBgNJxbdW2I*2Q&UOLzAj7 z>V!!lVtLr>f|oATbzr|27IR8!XI?6DTgR{zYB*qjlc2oZYM=>?GTyGcS!@GtKInGF 
z%nJ~Vhc-rkjLd*!0IxNF?GWicG-WiNlH2#Fdi2&MM-5Md1kdN0w^fqzXM;!Mxjjk~ zq7xc>LhKryz~L`yM<1%(J(kt5uwEhp7RHzuVkQcd&U8mgePqKLsC>%>{veb9oN-*t zeYKw@LZw@iu3K|Tn3Tt!?oHUjHk7dQ&7I~|!7C6UmWR`lCa7|g?~(XdOLKzI&*}zW zMW9bPEsB>oRR*0`{*JV}8y$Hl;}zD7%59DR0?ShQ-gfrVT$mUF@U%9#QL$>_ZNL$! zdV>ctZv%fQXSuxc<`Q>nI(j7rMjlRN`MoHS+o^<(w< z8TQvEknW9p>PTVso&jy`*qaOO;qGQZFW#f^Ux>>eM;`9)oA2jD9JQLI3-p#|-9~^+ z(`&Iu{pgwNl11&om(PBH&%51KkKJPP zHh>vRjVU#ZnG};2;a*neznA5Zic-PLyyFz%`NqDs3l77Nkb&jDUWrTcQWbgMwiXT7 zL|bXTapj|fZuP$*or8B_K{ae?s;jRHoL5WkM!K*L>0D}GRCxFHOCNY!IcgL@2Sj9;H3BRE;zNT!@F++3pkCD4MLqN<3 zhctlqAE}3m{C@{#)LH<_Y`etH@G9@cE|_;&l25VkzaO+=KfLhaWh@dL-k5iZaCf); zdd($;Cn!b9CxDbO>~8*Dlwpsz4a4SMs(f}?KiU2Fow9~>&kc#4|1V_(e1rpTEi|ew z#k#T0sVBhN%NdvE({9t}QS1)nxa6m2Z#eod@fWZuBh2qSWu`7X52@5EDm+7|&ozGb zXdTl=a`hc9`v}dm3|trn>$(O$Qo0!(v=^^P6f7ztTiY{&S8-UPeump@`=EY(vEwBK zg_gMGK|8^bA?`U=9PL-Z6Rujgve8`eJMO!qQlt0?bK}x(_sa~^)#VY(w6)b0Q_fKr zjLpfka_ei74<;?gI}A(`$Dnm$ zp8G4ILCUZwN5^9Q9mpe5YsK_>C2*Dvqma5@^`2Q)Qbi>*1M4dEObp;9-)#HL>UKMn zw$Xyk%u*B*us}`67=Uk#OpdMmEw51oeVy($t;#fiPn2Y;I-9(oaZXN_FEx}-T5YI~@%7&K z0f>m2@l)Pp9rq^oo9FNZeZLgiA6A&vfBy2j+=nxyzHc9958+A4K2&RIx9jkG7cbcp zkml6Ibn_$}4nkX#CdTzOhTpjlBR>zj&SP0$d0Ld!vbJ(@&uDcgOD)|35n?->VAA7y z;ywyTu22v_MOk5)qBX1DR#5qIXKkm^%a_3C?J_5wDTpK=FGdRi>y5#}h6R=eWGCIF zX}*O(*K4o|7nh7n#>4zZb7>EBlmf&NjdcrWFy6gda}&(C!0MJFektOf$f#$vUXwQ> zCf4(%V>P)6-k~1tI(IE8p+5$g;$ftvP&SzRWN?dLbtrrZF=hnxdki&m{L}XLdqdmx zpT|abV7Q+xQjdie#&U&4cD5l&73RRfaMa;m*X-uQJUlkrP^8y0kJNOd3?M3ntI5r6IhF(1OPOw*t5Pg`@(@-0eMYR41aNH7DlM^7m;zpPqH_$2e zQf+ED;R}9KeMo;as57xac>S`q_l`s0Vaww!ncVJ7sbw3^%6pi1#`0dH>HU1}Ubi*S zHhN;YwfkX5BtB%sQwn7dgF5@b*wa$Hfr>& z@1(bD*C`zxW?|L%Chv-fMo4pd(@yHW12*yzt$+W%fA|<``MCBIV8Ibya-#DP_7LZ} z;hjACI5z7QzD;RxVC?E;Rt%feNtZh5?(R;*yeLVS@wqrR`5X#DoV#a4%E)SN>P~=+ zyZLvB5=E}4Hcq~GbdR9S+-O;#!P4Lt*Cs&OK7%5bn6EFUXw^6HoD z3x{xVe)5ZMsRG#i;{==J)k4QLgw!Byq-wFcAMdS7>v+#%Wu!Zlf89rjndN8eX)uus zui%?x-ePZ83wBMO7$fafVN@5JBnALL1*M_Iv%Aifh?20mZqk|2hDZZE16AtNaz|F? 
z-KN%YkXTAFPMbp5b+o0 ztni;7GhT=i!cSW8ulDmVvZByIy(w^n?c}YhIPT~_YVBvZ|5<_$AQ*-PJxaP520aGo zzjpja4bPsZCa%E?4L_jdc9{V!jgS#{KH?p*bR0G_t>1jZv=YTm?5vb^wYG~uPPYgT z^D$kVh1UeNx(|DH<0HV&b}ui=ah71Va{f}0>)IJxt$Db{Mr+~5ZV*-&V9m3 z;!nG}#h(m%xAKp&&j)Qr4k)qQhMh;(9zNXLL>UqW-LYC4n)8uCD9zSgpaKspEk)q8 zWXlrE$=&Odo0+PzIod@nyAcn+=ey#B`vYYc$AlJ!ggh(wmdlinlaoi*Mm{u)IX19& z{hw|*Kl=%t>Z9H>t02K^8sN3i^l-A)_sgoP{Z%UJD&c{gw0nXGxD}xTO0%o93 ztH7&*j>e9BX1j3|oxs45PqK;?f580#HE`DeE5EM-X;^^I&C7E*l0w55p_tGvpLVVLDvV4ToB~Ahx z%ufb;JU-sKD?Ugb9Vu!zchO5biRZssHSwPcZCYc;+&kHSB(+xVU#lB;LU>Wd8T}^3 zXZ3ShaEtf;1bh)FPF*`A3_rqxlAoWh4I61QO#j-;;C=^BPzVJHBlVYxQCc)U9EY>M z)~7$$J{haws7<%@y!uSvf0MpNLfl(-d2!2#)Sd+cty{Db!P4K6GKQY*e^d+uy5&m- ziVR+B*NNj_wwHBmcYt$Oy7ffl549~pvoZsEKKo~Vu6`UT*W^%wT{vF?@`LVOO;4a7 zPB09}BAsPG?5ad^xN?KXztu<99Z>AU&%+(FOJI)Ol%PQBh6#3OeXJF^z-6dXO)KG@ zjpS+~b>P~VIsfXTXs3 zR@uI-bS{fk2-0X%X!6f@EZGoBL*37=%?tWS_>reW?#0#zkI`X#>R}RZ_gV?#jYz#9 zXvqcfw9h)7)D)aVAKXSO$Hgh_Szyp^Qle12593CYB}d%5uYUcqB-nOWz=>V;L2+#F zRoxvis9Sw-*mf z0LI?QLw6uB*VgvY*Zs+q8QHF%6wQovr`cjx`iHFS{HQ|cJR*7bo@qzgAq~DL@Qe4VBRs6fYk(?|Y2e37Z zh@9Kj>!ELeIgseehDh~?NED6*5R~n^Jum|zT9dv!i^w8A%#a|L)2^MSK#^xaEct+n zxwi6R#lnDTB8(x)+G^se{EiZ&Qs!o;A1|8hyO#Z}~s;OxRiAPFVOJ+!I z*gw?Y#77vKw{)%WlAGp{^$BCD-l2}v@7!{BW#vdYA$B?b*sNz=(hJ4L*cUYWaU0R{ z-4*B*(-*vEg5z6fPa?*|f7K%ctb*&9i6?;ksb(>DNXs!niP>YBCW!1Bmu#8OYT+w0UinZam`NEbi%4U-sOUxRWu ziHS-6^fd_ zrs@1nFsfS#kueZdSyE<`Z>z$GlX05@H=aHaV-4%-eMQBlrh20Vkp-4fY9qz?57~vQ z0Dh#%);u3JRcU;{pW8?HX|T6B{fKH)?VlgT@x9pK!4CLW5S3kN%0F3VpV$8g!j$|e zGz0jP|JbOobuR_4ZW3aLD=Nth%ux6b12FBu18@;05b*bPGG}VVcn0K)RYV~ABO%}S zC&12|w(BjGNBSZ!b+~s*3Vm$sg`06LKS*Xkhh()t&Rly)xAKo0qo_>Knz5V8qu>4$ z;F}5DvRT0ESR>ZNZqdshZsMK*A?0ZWW$Ay1>6?T#T#dl?Uo0Bd{#tR#uSz+(S8JK%QFpP?Ilw?zN?^r*iKzuEgU8t8xK{^w== zf06J0aG8l)qq>CyEjF6l?v?gL|X z$I9up*n`vK4=uEL#ebkbT^;PZ2NnX85#PEHOjwGts5e|dw3Q}D5T8m1_}NuZMarv#2@h3?33 z*;bBpk266qk4O%FsZp`|m7{%2#~%e?Bcobl!c19vb)?QSAr4AcbZ+yI(8u>m!>kOEu&96bphM8EgcW~9g z_0ya4QHaCsijT!=iGkxS+=L?bFtJ_SE!3NU=SS8pvbJO+=|88dp%U!Ws_mP 
zbMx9xKz%l`3V&Gl^S92!C`x=?5K@u+WS7%;v|8b+O(`cT-Jyo+hRUCv&bF+$kqQa{ z+TT!GbWeJ~kcds^#yhLh+T)71c+Z0=%_bzELXMHsgKh!y=N-4g_BvL|x?9pUX;;$j zCTb|E;y1xd;r8LNOzI^|acB2sf92yUgl1wqKZEb~PeDO__{3=L(-)bgGNF2Dh5-k%7|n_Oh|6Lqnnqly!1BaL0wOxBv*&=ab!;TbPe3a^W$RT1)8 zR(<6d&4&;Z=c~hfv;flNaU!@Do;g-&nH)f05DGJbGSi2|*R~G zYGRGyEcQkt=9zWVsnJUg!z<}Cl+{;bTlp$LZ=CXsZS-y>(8`+9+RECs3uiydX$E9k ze`htK-2IB;=P2N4p^hp*y=Zy{{IZb2mMe1dB5YU!>UL;DI@MQ5w~J%?K3YHoRX0&} z+`R6T;403P9AaxL)Ro0Q$n}iQ9Ek_AG+f+S1XJSHY#=n7Sn1Q|sz?`W=_AK$5~_lo z>JqWp64ZsiM!Xw-oMW;3O3a!?l}rL>fs5mDXriuG+hv1QxV7!R@eo5Dc*ZPnorF5x z%_K(Q*R8)QL=Q^R{J<`0R?`ykofA3j>Yb%ol-)wK0Dxk)Z-Y`$>$P6Tr}3t@Dek79 zvqG8s{|0$*>6vS@oJ+CkCCP(KuouF#9$)%-?W-Jss;qHrC^)X|A+G!2?Q;Ezn)oAI z$S zsILj>?K3(bGcUV2+vg|LU<8;fv`GCJ6yg9)o!@aw+KcUY4|A^r;`MgqwxSMNqz(+N|F8DGJF2N| z-4_G}6lH^ebWji!1O%jmh)QqLYd~oMLTHg*f~cTK4I5BE5RgvjN$8~-C?>osX_pDh+gst8qn1ySEgEz* z9Cp}XE?MImX?9y~9e#a^#E)gKH{_zLxScHF;>LQiF_iu?(|TWspj|7X1hF^U??dZv zE)}d#bzldMF<`4Dk3`S-7%vw>cV0HB(S9frV&q|`U7zqQsbiL^ruzLS&8dO}85usi z{KI@wtM5U7`FRo?1$IwZ`Fy-Lg{y-U*?m2-vBnJK#eb7)iu8Y>B#XK1V;{RPlOAy0 z0KZNA2Cy-ZXSiNwL|qDX;^yRqBmF*V0dBvvrO9NpPxW~cTJ{Q*VV&s1=6&AK{;Bpe z@AvFB3-0}CM*N7|n(tu+RWE3WdU+s7Ok5<(4q!xt ziCrgLW1Ud9N^Dy2l7n?AjLOd)?T*0L-MD)+-&z34Zcxul1Di3sHwl^}w&=k3_e2ns zw$Vb5n}KR@$|bV3NND5dU~kHVsflSFS{ou?48Mb30Y&5<-^f-8iO(;*FR;KKZe9} zswhFK@7yU(BmN~cSaVdmurI#@h>M)C{kmWK1xdA5CLsQKEVUib;NY-nS30IIQ6d(@ zPni6K1IPY)Xb#_Q1jojbL;((ViE{vGvl7l3li@Zzt#PRD@$hRx6W3~;r3NU?S!zA< zrAcGMLbOe*p&D0#(~DhYvFhhN1ftf?)6?gL3e)X+Iv9+RKB`;jjsB+Y zMKLz`cH&Vi>Geq|>5>LXQWL66TApUzDshm$sC~NPxT|?SG@-DB4lva(vc_|KxJ)Pupkl~8B+x_hklr|J7oVnr= zEX2|yzcazsBR+ZPC=arF@T_gP$1XI1;f739LPC{K_g>MB+P=I3H4fj&=$Rq4LuThs zE$8CS2a~rqW>Y!$E|(x?g%LbT@|vwYc0;AsN4kVQ9KpMBcvC0y!m5z}2L)MqFJ%fI zu4@uTD_01bQS0DtBw^U>3XPL(OxvMoDchLcjIeD#oQWkGVs{nkdSl!Y7 z>OR>nh#55xNLlF*uD|@mQ=PKeBg<)E5Mo=G@>rcrz$5sggUylSYOh}4k^4RO4O@={ z4e(v~-B_KbP>(>a^u263Xh3vJR>)h*QAz5p#_8lSA^$(f6x)!iFFR-Z0zCB+OdD%P zUNqL%)zqZ(o0uXT@7;%mPIR3Kc{g$sS{61g_N9-tD2d&}Ege>fNF=g!>bAXZ2fJ9a 
z)eKj)k6fHdoWZ6pci}4CcQRPXs}B^SV&n$$bQ;K8vMfGf#t;Oyx8K*tcinC~+c+)3 z7<84c8%{w2P@nTApeXq+=x$w4jE`}@)jV5*%0Q_k-p#Jc$!Mwt9QU2Sotm0*`DX`g z?^4xNv$Qj#v~Z#Gtrpie_0meUB6QREI^P9dgWLw(ZC0#53wjcr1NOZ=gFFYcml*7m zWX;WYyu6gV*KpY>q@BgxjRcV1ZQK+8{+pul1DiyPq6J00Q;S2r)UMF+=rM`!JZ9O3 zQ4F)kEEEapi7uTjxQy`EK0l?L z!Jp(+z+yY)rM?lx9ps%aqM+pefK4(d2U%PR(opc!aldKjVWQ0dYog%b*44Q!7XUXQ zNYfC{{aH;W`wGBL2T8Tj)6AP*gK83j8BH99&qq=0khO*|5<#ejMpRC-c|2X$QdSlz zA4_B1>Aq$GGItT=SLb6kbQi1RxFbtZFi#)9rE>UYa~ZvZm z$%vtlp(>@=DP$VP`vim}FJzYzq#!_1+@;yRNmU^(BZtr1$~^XD;)47`7Db7aQ-kB} z$>yUXqPOZ(4eAkuoA1{A$7|^v1$Yb2@R-Tm%z~bZx!GxudiLdEO{@Hae$BC((U(w5 z&05Q7!71GlUUFnkxe1McLUsY)%VlmC*I7C&W*EA~S5sjKgS!<63_I z(hnD-ER2bRO!xXGc!pTpa9%iN?c$sKo5-KJpK_b5;OG&lhv4fe_sdoLZ zV}G|J1BYs`W~Y39os{$K@~|E~(R1QFS~74S-7$B^N@_a5PEoN4-e#d7l{_nGDbMZ+ z0V@a=gI+&oiMw#^cq=7K;d%jQ*?Cr*%T?}=wf?E+JW*hF�x(HT@`2XBwfdkUmcz z)!Rwgfw}!_<1_sbk^9!Hml+hr1=aUI>k2EEE{(Y1ejY(yNl>{Lf%&WM)eHIh=Cl@Q z_`h%lOVrz17NB1wbJ$VtrceBUMGg&}alorkA2>|{og6AFoU{4uQ14mI$tiv(TRvB7 zkTuFJ8PC2PgOkO3BMp4#q;mXu{j*%F*7ZKhN!!%Rb7f^gAMYn7mWoh2*B@LmS+&_p z4>=S_uOFHi*9b10!f#hpCG?eL)75IYQtYkPLbZ$rUkaN*1&3P?OwqUs7RU7Unj}ZV zaGr2-*}nUC(Fql%3_3+Nx8+*Pn)Yw0m~VXYQ@*p^1Xx0b)!36dvX z_8ITAy3KRhNqO2+T9aTmD?jTU?s5fbsbrtSgb#k3!prxtR+asZpE(>Zi&BnyV_AvAjrEm1xMx)8gUC!Q%-AlN)m%)0}G{MA1y{k>c*IiwQxrM3+)pA~_ zbz6OqF*ej&E46Fggs2NrZgQ-EL}{_`D&6mp;fSN*9jdYLOmnDY%)GN1))=)Nd-g5> zeC{8+Vk0{?m`)FXwRGA5`A_Vu=(X67QjG0eRy^(p><;i1q~91H!q1qGrf!<9|Lg*$W_cy9z<@pCQU1y-)=m_XQh< zVDzRyp#KXejxW7Dw{_>`ce1x7d-XvHVQE%2hDi(_2*i&7>9$f%=SBfkRoO6uYRp}Q zv6QD%ebi*v8v5XucT?^ZZKnk5JarT8KH+8kR#IcOF_CNHTzVzva+gpHoTE+ zV{Iw*wSCD&_;|}%OzJ&TUoFO@&5-Tq0a0u5X(K5RC#e%$Uas?De%v%(*b(DPhn95{ zFSZoeED(B?vS=TvZZFmYER;pG^hbt5-1ZxgACL{EEvW8vJLKz;L0=2VWmNQ1C1+c- zGtbAV+!pC5?vG&-&3mlhK{)H7HSLjQ+?;atM^&A~&5TE0W{T;5*4vRqE=nQ$KV)Ro(QhlS>>xl+kq zC~EI?IvA3Us(lqikv~~A_#p!O7I&-F^H~fH#c;gUxsJOt%%>fST=^8ir~3BYS-ufM z3xqtP$e)_RE<%qS30EN8zZDOE)sgx**uh!``jFS7s7tZ3U{S!sZLZ?j@MxW7<1aV_F`eZ!E^lY*QP+==Vmm8Q(!-M~-6)Tv&LUxW1L7vnfZ 
zu%pFJM9jJW^ae@*Q#qty3^_QJ{aVu)<`mK|R~;{{?$>E$4<0IZIsqz|APXwhwNH&A z4`?QIE$AMUneceYt>LY%GU-OTeE);Yq+GC47Y{eB87_Gb{i?DyOT-OogWp-6hdHDz z`B*$=+-CG5Z|2AB`wNS_oH5Yx(z&#%e4Sm0c*D3R-6}dKp6;>4=O)|w)imC?jv-3h z$Vb94;GirUJML8dp@WeYR|k2&s~)nd+A$bRE^$KoyWUs+qg_{v0sE<&D(Vg+)xkh>8IBo zjEHQ2VWwkCyus#uP`>b{y4Iv3R_qsBgrell$#M zh)|vPjAHSiXiVTLaOLQz9d&*uIUzv zALFzC8M#&}ym6*xmj!ArG@JY&^q{I%V~xIEh2?BP?;f@6u@}tBfl>AC=iw1fNRedb z54k?%^Wmc*^%CUD^%;wLGIhE%rf&IQwGtr zn@km>B)1a@Q@hnsA%0)2562tcX~kw^3i^YrNOXrA6O1jw?#k(apM?{fPqei>+51FA zz9Q~Fa-cP&EOp(r#JZN(|GfH5E8jv@MByR2;9kPTD6JKN9?LJXlrU?56~inf+%Z35Srhed?22glg*2PQ1K<^?+%R5 zs=kL;d){Jppi1}9RRqvRBntGzG=}{LgVgr280&|#cdb2?Vw|$`C#EUyzB+yTz|Kf{ z_qhz^9&3?L%G+542KOWL>Ina=L%jcGb){E5k$pz)m189KvuIw^+F=s+SszB%>?K54 zS{<>2TRk9}VXE&!?9v2YCb07t#$XgNiBU_LU@unhxnVj zhv8Se)H|7)EUdc0{uf8gc5lYpYez-Bjz24K7x~)QSY4mW<+}5rwCcH=T|#VMgbJqd z^u4hL6y52aM@SNzcDgrKCYq2n-F{P?yu2&zo#DhttUDxWUbiWVE8v0}D;eABj~p76S!=D$%au`y z$*HF;|8U_<_;cYL{qJ8mEgR{#+ekpT{h7A+Lb0Q{1F!tE-()!Z)VZTym|(-})%+KN zLEh{L( zS)RdSlsm^!iskA1%$pblCMe`_io#~n!;n!ONLX>kMYtMp+!6=i6bJdAMNGZV?HtE^2_UTPt4XR(^#}rT#ZEk@qlKS5bf)Q8O;MI0TeN+|HA!y= z^@kvPVcu-&hf+zWUof-cKRvr|^@=l2bi_kT01RB-E?ldmhHJ z3Y@1WK0V6T7}5DWnYJ3=DUqV3_}rBHh_&hA2ux=#u@2sJ%}fXeM7P_z zdWL1mTB9Px0?kR2Chy<kKF!!^0T%9x|llg|bvUi1b)teTS%BxVu zkw$BJBhW)+E3ew@425{yES)q-GjG4+Nq~HXf*+4Tg;lXkd(ybUdlv}bgx<%8Z0_4W zd@j?afE(6R9sTwNS0H}_1>$b2o%)fKg8B-9@pX(!Ui!r6OukGdsOuAzw*R$e+qkOC zg~kDbY*kDXy`O$C=PFJ9aGT@8o2YmMY{NMAgKJf0&Xki>GPh${-RTR&&(2R0ag1R% z-mhj>yn)JDq^@S-q47a|Sb> z-!m;;dGZT+AKfz{X@fCnsrv((?3ib@-f7G`;S~3}*4s)IWw#yzqUZj~G22XlQavnI{rlxdl%ONBh%gLxN*Cg>VvP@hQ8P{?wz~UuIN9= zUXdIUw{fo?)(AuhZ@I@577rkM-<-;StLfP|ZL+7Y6Tlk-P*bsjQVt=YzV?IS&rjOS zb>yJfse2YDJL5>*a^7* zea(t+R<7z7=EkLq!6uyUHYubTKK138+gEs4dphz1uFvS$n`~nRSEoCCR)gw*0MAI> zG#}HP+MLOrvB4zwE*3sfk4={8`#wS3iv3wCQmX9!p?1khwW*nEQai-Fe$7WLYAJUM zK2$?2?3?Xfbypr5n+%G52*^o&X%fnXAM^hPA9vI{XGr`urjtuZVfQm~Iw>15PC710 z++d>S$ttKTOv_;(0oy!kxlq{aw3x(R{A9=bthqVGFakp}EcZL)bpKI{+C0}a 
zslCI(bk?z~ArTSbt*pRk!2@2V=_^t3XZLEBJQQA=>89mQVcZR!bX0D`xN!8IzOtF#oYT9kb#I{X*kr21NIbr#j9P#)(6MB?GfQnggO1b0vRmJw zPCcpDC7J7auR|wA9?P0yxzRH2Ouk^+oDhI&*tM~22VkYPeA}8kWCmtw`8>2y+%CPN zsE?9Y>9i^$DHxLwR`Q`_+Ug6T*09h^vibVIr!<_A*2;qhBYiM?gXnasC`d~2Lx-8~ z;TeoDx{B6X|2LtG&l+MC(#I7H^yVFFp(0w;?i;)fb*H@T`JL}?xw=hLQ+RE31H%4{ zOG`p`Q94@RVs_-OKE}=7@KZO>@=JW=utsUmB`$u|*!}thKH0*)R==?>X=rJK!OBKm zIRV}D>YaYnZue?VpnJv{-7foq%~rn+hJg+v|ou3r@+39>R34X|_kvk&H_lm*c5&E>u zGj767&01z1_B7z&HH<6TH~$IzU_+cSqWg_N;(gWI8n#{y{`04{DXTgtIG$p~S->5i zwE1#xSV5`Od7x`Wc)#m~sjFKP16fFw3m5k=u;+STih@2qF&-}EjjFgPu7~wl<0V_Z z7A(r1~mY_+-*n%3FmFJu1Jx3U;G-xs+!S zamw2|cEXH==j!Emie#|e4=29ZrH}zl=yw`+Rxgj1(h(6l(h(1c>{8X$D6Rb5w=?=+ z4KK|iaLM%kRt(B_J;mqKJeIxzm6YecMr~^@$z*i9ejuodqJ)eu<3&?ukH$pfT*mQL zRbh}QVQ$rbpHUwKzVj!i&-AaL%$nSZM2phSO?<}ws#2o=L~K)o?iqF|B>ewK9El6~ zTKPA+$X_JZe9KnAWk(UE$DXJSk^+-=4u_r`s(p&2*HoC+B;kFUUPT~xF!O#ypv`Yf|w4eXiP7wG0y~XP_Bkra844XTP-g>pS z6F-U$3q&ISL04|aW}&MUuA0^?7?}cek(S);{+zXQ;0{FnzjR;k^l|LRVAaz2|KfLQ z2XDFW|99Ex@4vtixRd@aUH<5ff?(vK&0nnit4&D**{j_8MazE1Oq)~-HzH^N3^kw5 zT@s+vDJ?uZ@+Aclln(cSqlZ|vZ#Mdce!)ze+}rs<@<0Y_pZoL#;OVJjX@VXUl1mo{ zYXjB4C+GiRjUWG&*Vz4kVKnS7gJuN&t2g~QfF5R>c5N35{UbmAo13|Oi(=KFy{1eR z=hBYL3=GC>_$J?Am$8w1d>ZmeawFsK-N@j7m^{($ACJ?Ff7VU^)11cp>!_`*v!s66 zz_rhv%}IynQkry;*kmw!!#`%nMEk~`c6UDa^od50!z zCVm%~^oAeLD{W=p*Jbp_!D*|=;5^em$kHiD$G|upJos!w@}oNi z5bn}Pj7n%g27axLgk4ZOn3+2}${0Hn{DbU&HOB2c3S0*g&5iJrJ@O}~Zd;@vJtBz_ z4xItMDER&y-cox0ug3pnmVbA>e|gKl2GRfR+>$oA-xEw!2!@duu~WYA-9;cTStNb8 zz*?}l%;TEsB_aPyBhLa}MUt(;|MxjdC_gJy7QRu3J^rawdZd4j0dR(g0)R~UWx|~= z*KGbaKF4rwvr;x?Wz*DKB&JdQM=a^Ak?O^{;=CewnQ5`~JILC`__ for an end to end example. It constructs two models and two optimizers and uses a custom training operator to provide a non-standard training loop. + + +Initialization Functions +------------------------ + +Use the ``initialization_hook`` parameter to initialize state on each worker process when they are started. This is useful when setting an environment variable: + +.. 
code-block:: python + + def initialization_hook(): print("NCCL DEBUG SET") # Need this for avoiding a connection restart issue os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" @@ -193,8 +299,8 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load`` trainer_2.restore(checkpoint_path) -Exporting a model for inference -------------------------------- +Retrieving the model +-------------------- The trained torch model can be extracted for use within the same Python program with ``trainer.get_model()``. This will load the state dictionary of the model(s). @@ -242,22 +348,23 @@ To specify particular parameters for ``amp.initialize``, you can use the ``apex_ } ) -Note that if using a custom training function, you will need to manage loss scaling manually. +Note that if using a custom training operator (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually. Distributed Multi-node Training ------------------------------- -You can scale out your training onto multiple nodes without making any modifications to your training code. To train across a cluster, simply make sure that the Ray cluster is started. +You can scale your training to multiple nodes without making any modifications to your training code. -You can start a Ray cluster `via the Ray cluster launcher `_ or `manually `_. +To train across a cluster, first make sure that the Ray cluster is started. You can start a Ray cluster `via the Ray cluster launcher `_ or `manually `_. -.. code-block:: bash +Then, in your program, you'll need to connect to this cluster via ``ray.init``: - ray up CLUSTER.yaml - ray submit train.py --args="--address='auto'" +.. code-block:: python -Then, within ``train.py`` you can scale up the number of workers seamlessly across multiple nodes: + ray.init(address="auto") # or a specific redis address of the form "ip-address:port" + +After connecting, you can scale up the number of workers seamlessly across multiple nodes: .. 
code-block:: python @@ -266,7 +373,10 @@ Then, within ``train.py`` you can scale up the number of workers seamlessly acro data_creator, optimizer_creator, loss_creator=nn.MSELoss, - num_replicas=100) + num_replicas=100 + ) + trainer.train() + model = trainer.get_model() Advanced: Fault Tolerance @@ -310,22 +420,37 @@ Advanced: Hyperparameter Tuning Simultaneous Multi-model Training --------------------------------- -In certain scenarios such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by allowing the ``model_creator``, ``optimizer_creator``, and ``scheduler_creator`` to return multiple values. - -If multiple models, optimizers, or schedulers are returned, you will need to provide a custom training function (and custom validation function if you plan to call ``validate``). +In certain scenarios, such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by allowing the ``model_creator``, ``optimizer_creator``, and ``scheduler_creator`` to return multiple values. Provide a custom TrainingOperator (:ref:`raysgd-custom-training`) to train across multiple models. You can see the `DCGAN script `_ for an end-to-end example. .. 
code-block:: python + from ray.util.sgd.pytorch import PyTorchTrainer, TrainingOperator + + def train(*, model=None, criterion=None, optimizer=None, dataloader=None): + model.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(dataloader): + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + return { + "accuracy": correct / total, + "train_loss": train_loss / (batch_idx + 1) + } + def model_creator(config): - netD = Discriminator() - netD.apply(weights_init) - - netG = Generator() - netG.apply(weights_init) - return netD, netG - + return Discriminator(), Generator() def optimizer_creator(models, config): net_d, net_g = models @@ -335,125 +460,27 @@ You can see the `DCGAN script `__: - Training a ResNet18 model on CIFAR10. It uses a custom training - function, a custom validation function, and custom initialization code for each worker. + Training a ResNet18 model on CIFAR10. - `DCGAN example `__: - Training a Deep Convolutional GAN on MNIST. It constructs - two models and two optimizers and uses a custom training and validation function. + Training a Deep Convolutional GAN on MNIST. It constructs two models and two optimizers and uses a custom training operator. diff --git a/doc/source/raysgd/raysgd_ref.rst b/doc/source/raysgd/raysgd_ref.rst index d7ac5020b..15e99d62c 100644 --- a/doc/source/raysgd/raysgd_ref.rst +++ b/doc/source/raysgd/raysgd_ref.rst @@ -11,6 +11,14 @@ PyTorchTrainer .. automethod:: __init__ +.. _ref-pytorch-operator: + +PyTorch TrainingOperator +------------------------ + +.. 
autoclass:: ray.util.sgd.pytorch.TrainingOperator + :members: + PyTorchTrainable ---------------- diff --git a/python/ray/util/sgd/pytorch/__init__.py b/python/ray/util/sgd/pytorch/__init__.py index dd284cb9e..bdcc100d3 100644 --- a/python/ray/util/sgd/pytorch/__init__.py +++ b/python/ray/util/sgd/pytorch/__init__.py @@ -3,6 +3,7 @@ logger = logging.getLogger(__name__) PyTorchTrainer = None PyTorchTrainable = None +TrainingOperator = None try: import torch # noqa: F401 @@ -10,6 +11,8 @@ try: from ray.util.sgd.pytorch.pytorch_trainer import (PyTorchTrainer, PyTorchTrainable) - __all__ = ["PyTorchTrainer", "PyTorchTrainable"] + from ray.util.sgd.pytorch.training_operator import TrainingOperator + + __all__ = ["PyTorchTrainer", "PyTorchTrainable", "TrainingOperator"] except ImportError: logger.warning("PyTorch not found. PyTorchTrainer will not be available") diff --git a/python/ray/util/sgd/pytorch/constants.py b/python/ray/util/sgd/pytorch/constants.py new file mode 100644 index 000000000..0d7421a42 --- /dev/null +++ b/python/ray/util/sgd/pytorch/constants.py @@ -0,0 +1,7 @@ +USE_FP16 = "__use_fp16__" +BATCH_COUNT = "batch_count" +SCHEDULER_STEP = "scheduler_step" +SCHEDULER_STEP_BATCH = "batch" +SCHEDULER_STEP_EPOCH = "epoch" + +VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH} diff --git a/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py b/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py index 8db45dfe4..f3c8a8180 100644 --- a/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py +++ b/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py @@ -95,15 +95,22 @@ class DistributedPyTorchRunner(PyTorchRunner): self.validation_loader = torch.utils.data.DataLoader( val_set, batch_size=self.batch_size, **self.dataloader_config) - def step(self): + self.training_operator = self.training_operator_cls( + self.config, + models=self.models, + optimizers=self.optimizers, + criterion=self.criterion, + schedulers=self.schedulers, + 
use_fp16=self.use_fp16) + + def train_epoch(self, **kwargs): """Runs a training epoch and updates the model parameters. Automatically sets epoch of sampler if possible. """ - logger.debug("Starting step") if hasattr(self.train_loader.sampler, "set_epoch"): - self.train_loader.sampler.set_epoch(self.epoch) - return super(DistributedPyTorchRunner, self).step() + self.train_loader.sampler.set_epoch(self.epochs) + return super(DistributedPyTorchRunner, self).train_epoch(**kwargs) def _get_model_state_dicts(self): """Fetch state from ``model.module`` instead of ``model``. diff --git a/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py index e039ec839..2007dfc80 100644 --- a/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py @@ -10,10 +10,9 @@ import torchvision.transforms as transforms import ray from ray.util.sgd.pytorch import (PyTorchTrainer, PyTorchTrainable) from ray.util.sgd.pytorch.resnet import ResNet18 -from ray.util.sgd.pytorch.utils import TEST_MODE -def initialization_hook(runner): +def initialization_hook(): print("NCCL DEBUG SET") # Need this for avoiding a connection restart issue os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" @@ -40,6 +39,11 @@ def cifar_creator(config): validation_dataset = torchvision.datasets.CIFAR10( root="~/data", train=False, download=False, transform=transform_test) + if config.get("test_mode"): + train_dataset = torch.utils.data.Subset(train_dataset, list(range(64))) + validation_dataset = torch.utils.data.Subset(validation_dataset, + list(range(64))) + return train_dataset, validation_dataset @@ -58,7 +62,6 @@ def train_example(num_replicas=1, use_gpu=False, use_fp16=False, test_mode=False): - config = {TEST_MODE: test_mode} trainer1 = PyTorchTrainer( ResNet18, cifar_creator, @@ -67,7 +70,10 @@ def train_example(num_replicas=1, scheduler_creator=scheduler_creator, 
initialization_hook=initialization_hook, num_replicas=num_replicas, - config=config, + config={ + "lr": 0.01, + "test_mode": test_mode + }, use_gpu=use_gpu, batch_size=16 if test_mode else 512, backend="nccl" if use_gpu else "gloo", @@ -88,14 +94,14 @@ def tune_example(num_replicas=1, use_gpu=False, test_mode=False): "model_creator": ResNet18, "data_creator": cifar_creator, "optimizer_creator": optimizer_creator, - "loss_creator": lambda config: nn.CrossEntropyLoss(), + "loss_creator": nn.CrossEntropyLoss, "num_replicas": num_replicas, "initialization_hook": initialization_hook, "use_gpu": use_gpu, "batch_size": 16 if test_mode else 512, "config": { - "lr": tune.choice([1e-4, 1e-3, 5e-3, 1e-2]), - TEST_MODE: test_mode + "lr": tune.choice([1e-4, 1e-3]), + "test_mode": test_mode }, "backend": "nccl" if use_gpu else "gloo" } diff --git a/python/ray/util/sgd/pytorch/examples/dcgan.py b/python/ray/util/sgd/pytorch/examples/dcgan.py index 110697d10..ad627f533 100644 --- a/python/ray/util/sgd/pytorch/examples/dcgan.py +++ b/python/ray/util/sgd/pytorch/examples/dcgan.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.optim as optim import torch.utils.data -import torchvision.datasets as dset +import torchvision.datasets as datasets import torchvision.transforms as transforms import numpy as np @@ -16,25 +16,12 @@ from scipy.stats import entropy import ray from ray.util.sgd import PyTorchTrainer -from ray.util.sgd.pytorch.utils import TEST_MODE - -# Training parameters -TRAIN_BATCHES = 5 -# Number of channels in the training images. For color images this is 3 -num_channels = 1 - -# Size of z latent vector (i.e. 
size of generator input) -latent_vector_size = 100 - -# Size of feature maps in generator -features_g = 32 - -# Size of feature maps in discriminator -features_d = 32 +from ray.util.sgd.utils import override +from ray.util.sgd.pytorch import TrainingOperator def data_creator(config): - return dset.MNIST( + dataset = datasets.MNIST( root="~/mnist/", download=True, transform=transforms.Compose([ @@ -42,62 +29,56 @@ def data_creator(config): transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, )), ])) - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) + if config.get("test_mode"): + dataset = torch.utils.data.Subset(dataset, list(range(64))) + return dataset class Generator(nn.Module): - def __init__(self): + def __init__(self, latent_vector_size, features=32, num_channels=1): super(Generator, self).__init__() + self.latent_vector_size = latent_vector_size self.main = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d( - latent_vector_size, features_g * 4, 4, 1, 0, bias=False), - nn.BatchNorm2d(features_g * 4), + latent_vector_size, features * 4, 4, 1, 0, bias=False), + nn.BatchNorm2d(features * 4), nn.ReLU(True), nn.ConvTranspose2d( - features_g * 4, features_g * 2, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_g * 2), + features * 4, features * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(features * 2), nn.ReLU(True), - nn.ConvTranspose2d( - features_g * 2, features_g, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_g), + nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False), + nn.BatchNorm2d(features), nn.ReLU(True), - nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False), + nn.ConvTranspose2d(features, num_channels, 4, 2, 1, bias=False), nn.Tanh()) - def forward(self, input): - return self.main(input) + def 
forward(self, x): + return self.main(x) class Discriminator(nn.Module): - def __init__(self): + def __init__(self, features=32, num_channels=1): super(Discriminator, self).__init__() self.main = nn.Sequential( - nn.Conv2d(num_channels, features_d, 4, 2, 1, bias=False), + nn.Conv2d(num_channels, features, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_d * 2), nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_d * 4), nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(features_d * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid()) + nn.Conv2d(features, features * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(features * 2), nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(features * 4), nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid()) - def forward(self, input): - return self.main(input) + def forward(self, x): + return self.main(x) -class Net(nn.Module): +class LeNet(nn.Module): """LeNet for MNist classification, used for inception_score.""" def __init__(self): - super(Net, self).__init__() + super(LeNet, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() @@ -114,92 +95,22 @@ class Net(nn.Module): return F.log_softmax(x, dim=1) -def inception_score(imgs, batch_size=32, splits=1): - N = len(imgs) - dtype = torch.FloatTensor - dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) - cm = Net() - cm.load_state_dict(torch.load(model_path)) - cm.eval() - up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype) - - def get_pred(x): - x = up(x) - x = cm(x) - return F.softmax(x).data.cpu().numpy() - - preds = np.zeros((N, 10)) - for i, batch in enumerate(dataloader, 0): - batch = batch.type(dtype) - batchv = 
Variable(batch) - batch_size_i = batch.size()[0] - preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv) - - # Now compute the mean kl-div - split_scores = [] - for k in range(splits): - part = preds[k * (N // splits):(k + 1) * (N // splits), :] - py = np.mean(part, axis=0) - scores = [] - for i in range(part.shape[0]): - pyx = part[i, :] - scores.append(entropy(pyx, py)) - split_scores.append(np.exp(np.mean(scores))) - - return np.mean(split_scores), np.std(split_scores) - - def model_creator(config): - netD = Discriminator() - netD.apply(weights_init) + def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) - netG = Generator() - netG.apply(weights_init) - return netD, netG + discriminator = Discriminator() + discriminator.apply(weights_init) - -def train(config, models, dataloader, criterion, optimizers, **kwargs): - netD, netG = models - optimD, optimG = optimizers - real_label = 1 - fake_label = 0 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - for i, data in enumerate(dataloader, 0): - if i >= TRAIN_BATCHES and config.get(TEST_MODE): - break - - netD.zero_grad() - real_cpu = data[0].to(device) - b_size = real_cpu.size(0) - label = torch.full((b_size, ), real_label, device=device) - output = netD(real_cpu).view(-1) - errD_real = criterion(output, label) - errD_real.backward() - - noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device) - fake = netG(noise) - label.fill_(fake_label) - output = netD(fake.detach()).view(-1) - errD_fake = criterion(output, label) - errD_fake.backward() - errD = errD_real + errD_fake - optimD.step() - - netG.zero_grad() - label.fill_(real_label) - output = netD(fake).view(-1) - errG = criterion(output, label) - errG.backward() - optimG.step() - - is_score, is_std = 
inception_score(fake) - - return { - "loss_g": errG.item(), - "loss_d": errD.item(), - "inception": is_score - } + generator = Generator( + latent_vector_size=config.get("latent_vector_size", 100)) + generator.apply(weights_init) + return discriminator, generator def optimizer_creator(models, config): @@ -211,22 +122,122 @@ def optimizer_creator(models, config): return discriminator_opt, generator_opt +class GANOperator(TrainingOperator): + def setup(self, config): + self.device = torch.device("cuda" + if torch.cuda.is_available() else "cpu") + + self.classifier = LeNet() + self.classifier.load_state_dict( + torch.load(config["classification_model_path"])) + self.classifier.eval() + + def inception_score(self, imgs, batch_size=32, splits=1): + """Calculate the inception score of the generated images.""" + N = len(imgs) + dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) + up = nn.Upsample( + size=(28, 28), mode="bilinear").type(torch.FloatTensor) + + def get_pred(x): + x = up(x) + x = self.classifier(x) + return F.softmax(x).data.cpu().numpy() + + # Obtain predictions for the fake provided images + preds = np.zeros((N, 10)) + for i, batch in enumerate(dataloader, 0): + batch = batch.type(torch.FloatTensor) + batchv = Variable(batch) + batch_size_i = batch.size()[0] + preds[i * batch_size:i * batch_size + + batch_size_i] = get_pred(batchv) + + # Now compute the mean kl-div + split_scores = [] + for k in range(splits): + part = preds[k * (N // splits):(k + 1) * (N // splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + + return np.mean(split_scores), np.std(split_scores) + + @override(TrainingOperator) + def train_batch(self, batch, batch_info): + """Trains on one batch of data from the data creator.""" + real_label = 1 + fake_label = 0 + discriminator, generator = self.models + optimD, optimG = self.optimizers + 
+ # Compute a discriminator update for real images + discriminator.zero_grad() + real_cpu = batch[0].to(self.device) + batch_size = real_cpu.size(0) + label = torch.full((batch_size, ), real_label, device=self.device) + output = discriminator(real_cpu).view(-1) + errD_real = self.criterion(output, label) + errD_real.backward() + + # Compute a discriminator update for fake images + noise = torch.randn( + batch_size, + self.config.get("latent_vector_size", 100), + 1, + 1, + device=self.device) + fake = generator(noise) + label.fill_(fake_label) + output = discriminator(fake.detach()).view(-1) + errD_fake = self.criterion(output, label) + errD_fake.backward() + errD = errD_real + errD_fake + + # Update the discriminator + optimD.step() + + # Update the generator + generator.zero_grad() + label.fill_(real_label) + output = discriminator(fake).view(-1) + errG = self.criterion(output, label) + errG.backward() + optimG.step() + + is_score, is_std = self.inception_score(fake) + + return { + "loss_g": errG.item(), + "loss_d": errD.item(), + "inception": is_score, + "num_samples": batch_size + } + + def train_example(num_replicas=1, use_gpu=False, test_mode=False): - config = {TEST_MODE: test_mode} + config = { + "test_mode": test_mode, + "classification_model_path": os.path.join( + os.path.dirname(ray.__file__), + "util/sgd/pytorch/examples/mnist_cnn.pt") + } trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, nn.BCELoss, - train_function=train, - validation_function=False, + training_operator_cls=GANOperator, num_replicas=num_replicas, config=config, use_gpu=use_gpu, batch_size=16 if test_mode else 512, backend="nccl" if use_gpu else "gloo") - for i in range(10): - stats = trainer.train(max_retries=3) + for i in range(5): + stats = trainer.train() print(stats) return trainer @@ -240,7 +251,7 @@ if __name__ == "__main__": "--address", required=False, type=str, - help="the address to use for Redis") + help="the address to use to connect to a cluster.") 
parser.add_argument( "--num-replicas", "-n", @@ -255,10 +266,6 @@ if __name__ == "__main__": args, _ = parser.parse_known_args() ray.init(address=args.address) - path = os.path.dirname(ray.__file__) - model_path = os.path.join(path, "util/sgd/pytorch/examples/mnist_cnn.pt") - # load the pretrained mnist classification model for inception_score - trainer = train_example( num_replicas=args.num_replicas, use_gpu=args.use_gpu, diff --git a/python/ray/util/sgd/pytorch/pytorch_runner.py b/python/ray/util/sgd/pytorch/pytorch_runner.py index 7c8803896..0be7ba9db 100644 --- a/python/ray/util/sgd/pytorch/pytorch_runner.py +++ b/python/ray/util/sgd/pytorch/pytorch_runner.py @@ -2,13 +2,15 @@ import collections from filelock import FileLock import logging import inspect +import itertools import os import torch import torch.utils.data from torch.utils.data import Dataset import ray -from ray.util.sgd.pytorch import utils as pytorch_utils +from ray.util.sgd.pytorch.constants import USE_FP16, SCHEDULER_STEP +from ray.util.sgd.pytorch.training_operator import TrainingOperator from ray.util.sgd import utils logger = logging.getLogger(__name__) @@ -31,8 +33,7 @@ class PyTorchRunner: loss_creator (dict -> loss | Loss class): see pytorch_trainer.py. scheduler_creator (optimizers, dict -> schedulers): see pytorch_trainer.py. - train_function: see pytorch_trainer.py - validation_function: see pytorch_trainer.py + training_operator_cls: see pytorch_trainer.py config (dict): see pytorch_trainer.py. dataloader_config (dict): See pytorch_trainer.py. batch_size (int): see pytorch_trainer.py. 
@@ -47,8 +48,7 @@ class PyTorchRunner: optimizer_creator, loss_creator, scheduler_creator=None, - train_function=None, - validation_function=None, + training_operator_cls=None, config=None, dataloader_config=None, batch_size=16, @@ -60,17 +60,15 @@ class PyTorchRunner: self.optimizer_creator = optimizer_creator self.loss_creator = loss_creator self.scheduler_creator = scheduler_creator + self.training_operator_cls = training_operator_cls or TrainingOperator self.config = {} if config is None else config self.dataloader_config = { "num_workers": 2 } if dataloader_config is None else dataloader_config - self.train_function = train_function or pytorch_utils.train - self.validation_function = (validation_function - or pytorch_utils.validate) self.batch_size = batch_size self.verbose = True - self.epoch = 0 + self.epochs = 0 self._timers = { k: utils.TimerStat(window_size=1) for k in [ @@ -160,6 +158,14 @@ class PyTorchRunner: self.validation_loader = torch.utils.data.DataLoader( val_set, batch_size=self.batch_size, **self.dataloader_config) + self.training_operator = self.training_operator_cls( + self.config, + models=self.models, + optimizers=self.optimizers, + criterion=self.criterion, + schedulers=self.schedulers, + use_fp16=self.use_fp16) + def get_node_ip(self): """Returns the IP address of the current node.""" return ray.services.get_node_ip_address() @@ -168,47 +174,42 @@ class PyTorchRunner: """Finds a free port on the current node.""" return utils.find_free_port() - def step(self): + def train_epoch(self, num_steps=None, info=None): """Runs a training epoch and updates the model parameters.""" - logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) - training_config = self.config.copy() - training_config.update({ - pytorch_utils.USE_FP16: self.use_fp16, - pytorch_utils.SCHEDULER_STEP: self.scheduler_step_freq + logger.debug("Begin Training Step {}".format(self.epochs + 1)) + info = info or {} + info.update({ + USE_FP16: self.use_fp16, + 
SCHEDULER_STEP: self.scheduler_step_freq }) with self._timers["training"]: - train_stats = self.train_function( - training_config, - self.given_models, - self.train_loader, - self.criterion, - self.given_optimizers, - scheduler=self.given_schedulers) - train_stats["epoch"] = self.epoch - - self.epoch += 1 + iterator = self.train_loader + if num_steps: + iterator = itertools.islice(iter(self.train_loader), num_steps) + train_stats = self.training_operator.train_epoch(iterator, info) + self.epochs += 1 train_stats.update(self.stats()) return train_stats - def validate(self): + def validate(self, num_steps=None, info=None): """Evaluates the model on the validation data set.""" if self.validation_loader is None: raise ValueError("No validation dataloader provided.") + info = info or {} with self._timers["validation"]: - validation_stats = self.validation_function( - self.config, - self.given_models, - self.validation_loader, - self.criterion, - scheduler=self.given_schedulers) + iterator = self.validation_loader + if num_steps: + iterator = itertools.islice( + iter(self.validation_loader), num_steps) + validation_stats = self.training_operator.validate(iterator, info) validation_stats.update(self.stats()) return validation_stats def stats(self): """Returns a dictionary of statistics collected.""" - stats = {"epoch": self.epoch} + stats = {"epoch": self.epochs} for k, t in self._timers.items(): stats[k + "_time_mean"] = t.mean stats[k + "_time_total"] = t.sum @@ -233,7 +234,8 @@ class PyTorchRunner: """Returns the state of the runner.""" state = { - "epoch": self.epoch, + "epoch": self.epochs, + "operator": self.training_operator.state_dict(), "models": self._get_model_state_dicts(), "optimizers": [opt.state_dict() for opt in self.optimizers], "stats": self.stats() @@ -262,13 +264,18 @@ class PyTorchRunner: if self.use_fp16 and "amp" in state and amp: amp.load_state_dict(state["amp"]) - self.epoch = state["stats"]["epoch"] + self.epochs = state["stats"]["epoch"] + 
self.training_operator.load_state_dict(state_dict) - def apply_fn(self, fn): - return fn(self) + def apply(self, fn): + return fn() + + def apply_operator(self, fn): + return fn(self.training_operator) def shutdown(self): """Attempts to shut down the worker.""" + del self.training_operator del self.validation_loader del self.train_loader del self.criterion diff --git a/python/ray/util/sgd/pytorch/pytorch_trainer.py b/python/ray/util/sgd/pytorch/pytorch_trainer.py index 6eec29ba3..0316cd11a 100644 --- a/python/ray/util/sgd/pytorch/pytorch_trainer.py +++ b/python/ray/util/sgd/pytorch/pytorch_trainer.py @@ -15,12 +15,20 @@ from ray.util.sgd.pytorch.distributed_pytorch_runner import ( DistributedPyTorchRunner) from ray.util.sgd import utils from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner -from ray.util.sgd.pytorch import utils as pytorch_utils +from ray.util.sgd.pytorch.constants import VALID_SCHEDULER_STEP logger = logging.getLogger(__name__) RESIZE_COOLDOWN_S = 10 +def _validate_scheduler_step_freq(scheduler_step_freq): + if scheduler_step_freq: + if scheduler_step_freq not in VALID_SCHEDULER_STEP: + raise ValueError( + "Scheduler step freq must be in {}. Got {}".format( + VALID_SCHEDULER_STEP, scheduler_step_freq)) + + class PyTorchTrainer: """Train a PyTorch model using distributed PyTorch. @@ -48,14 +56,15 @@ class PyTorchTrainer: loss_creator=nn.MSELoss, use_gpu=True ) - trainer.train() + for i in range(4): + trainer.train() Args: model_creator (dict -> Model(s)): Constructor function that takes in config and returns the model(s) to be optimized. These must be ``torch.nn.Module`` objects. If multiple models are returned, - a ``train_function`` must be specified. You do not need to + a ``training_operator_cls`` must be specified. You do not need to handle GPU/devices in this function; RaySGD will do that under the hood. 
data_creator (dict -> Dataset(s)): Constructor function @@ -75,22 +84,18 @@ class PyTorchTrainer: of ``torch.nn.modules.loss._Loss``, which is most Pytorch loss classes. For example, ``loss_creator=torch.nn.BCELoss``. scheduler_creator (optimizers, dict -> loss): - A constructor function for the scheduler loss. This is + A constructor function for the torch scheduler. This is a function that takes in the generated optimizers (from ``optimizer_creator``) provided config for customization. Be sure to set ``scheduler_step_freq`` to increment the scheduler correctly. - train_function: Custom function for training. This function - will be executed in parallel across all workers at once. The - function needs to take in (models, train_dataloader, criterion, - optimizers, config), and return a dict of training stats. - validation_function: Custom function for validation. This function - will be executed in parallel across all workers at once. - This takes in (model, val_dataloader, criterion, config) - and returns a dict of validation stats. + training_operator_cls (type): Custom training operator class + that subclasses the TrainingOperator class. This class + will be copied onto all remote workers and used to specify + custom training and validation operations. Defaults to + TrainingOperator. config (dict): Custom configuration value to be passed to - "model_creator", "data_creator", "optimizer_creator", and - "loss_creator". + all creator and operator constructors. 
dataloader_config (dict): Configuration values to be passed into the ``torch.utils.data.DataLoader`` object that wraps the dataset on each parallel worker for both training @@ -130,8 +135,7 @@ class PyTorchTrainer: optimizer_creator, loss_creator, scheduler_creator=None, - train_function=None, - validation_function=None, + training_operator_cls=None, initialization_hook=None, config=None, dataloader_config=None, @@ -151,11 +155,10 @@ class PyTorchTrainer: self.model_creator = model_creator self.data_creator = data_creator - self.train_function = train_function self.optimizer_creator = optimizer_creator self.loss_creator = loss_creator self.scheduler_creator = scheduler_creator - self.validation_function = validation_function + self.training_operator_cls = training_operator_cls self.initialization_hook = initialization_hook self.config = {} if config is None else config self.dataloader_config = dataloader_config @@ -166,6 +169,8 @@ class PyTorchTrainer: logger.info("Using {} as backend.".format(backend)) self.backend = backend + + # TODO: Have an auto "use_gpu" option to detect and use GPUs. self.use_gpu = use_gpu self.batch_size = batch_size self.max_replicas = num_replicas @@ -180,12 +185,7 @@ class PyTorchTrainer: self._num_failures = 0 self._last_resize = float("-inf") - if scheduler_step_freq and ( - scheduler_step_freq not in pytorch_utils.VALID_SCHEDULER_STEP): - raise ValueError( - "Scheduler step freq must be in {}. 
Got {}".format( - pytorch_utils.VALID_SCHEDULER_STEP, scheduler_step_freq)) - + _validate_scheduler_step_freq(scheduler_step_freq) self.scheduler_step_freq = scheduler_step_freq self._start_workers(self.max_replicas) @@ -204,8 +204,7 @@ class PyTorchTrainer: self.optimizer_creator, self.loss_creator, self.scheduler_creator, - train_function=self.train_function, - validation_function=self.validation_function, + training_operator_cls=self.training_operator_cls, config=self.config, dataloader_config=self.dataloader_config, batch_size=self.batch_size, @@ -243,8 +242,7 @@ class PyTorchTrainer: self.loss_creator, self.scheduler_creator, backend=self.backend, - train_function=self.train_function, - validation_function=self.validation_function, + training_operator_cls=self.training_operator_cls, config=self.config, dataloader_config=self.dataloader_config, batch_size=batch_size_per_replica, @@ -266,21 +264,35 @@ class PyTorchTrainer: for i, worker in enumerate(self.workers) ]) - def train(self, max_retries=0, checkpoint="auto"): + def train(self, + num_steps=None, + max_retries=0, + checkpoint="auto", + info=None): """Runs a training epoch. Runs an average over all values returned from workers. Set `max_retries` to enable fault handling in case of instance preemption. Args: + num_steps (int): Number of batches to compute update steps on. + This corresponds also to the number of times + ``TrainingOperator.train_batch`` is called. max_retries (int): Must be non-negative. If set to N, will kill all current workers, query the Ray global state for total available resources, and re-launch up to the available resources. Behavior is not well-defined in case of shared cluster usage. checkpoint (str): Path to checkpoint to restore from if retrying. - If max_retries is set and checkpoint == "auto", PyTorchTrainer - will save a checkpoint before starting to train. + If max_retries is set and ``checkpoint == "auto"``, + PyTorchTrainer will save a checkpoint before starting to train. 
+ info (dict): Optional dictionary passed to the training + operator for ``train_epoch`` and ``train_batch``. + + Returns: + A dictionary of metrics for training. + You can provide custom metrics by passing in a custom + ``training_operator_cls``. """ assert max_retries >= 0, "`max_retries` must be non-negative." if max_retries: @@ -296,7 +308,8 @@ class PyTorchTrainer: self._resize_workers(checkpoint=checkpoint) with self.optimizer_timer: - success, worker_stats = self._train_step() + success, worker_stats = self._train_epoch( + num_steps=num_steps, info=info) # Fault handling for i in range(max_retries): if success: @@ -306,7 +319,8 @@ class PyTorchTrainer: self._resize_workers(checkpoint=checkpoint) logger.info("Retrying training step with %d workers." % len( self.workers)) - success, worker_stats = self._train_step() + success, worker_stats = self._train_epoch( + num_steps=num_steps, info=info) if not success: raise RuntimeError("Training run failed.") @@ -321,19 +335,58 @@ class PyTorchTrainer: train_stats[stat_key] = worker_stats[0][stat_key] return train_stats - def _train_step(self): - worker_stats = [w.step.remote() for w in self.workers] + def _train_epoch(self, num_steps=None, info=None): + worker_stats = [ + w.train_epoch.remote(num_steps=num_steps, info=info) + for w in self.workers + ] success = utils.check_for_failure(worker_stats) return success, worker_stats def apply_all_workers(self, fn): - return ray.get([w.apply_fn.remote(fn) for w in self.workers]) + """Run a function on all workers. - def validate(self): - """Evaluates the model on the validation data set.""" - if self.validation_function is False: - return {} - worker_stats = ray.get([w.validate.remote() for w in self.workers]) + Args: + fn (Callable): A function that takes in no arguments. + + Returns: + A list of objects returned by ``fn`` on each worker.
+ + """ + return ray.get([w.apply.remote(fn) for w in self.workers]) + + def apply_all_operators(self, fn): + """Run a function on all operators on the workers. + + Args: + fn (Callable[TrainingOperator]): A function that takes in a + TrainingOperator. + + Returns: + A list of objects returned by ``fn`` on each operator. + + """ + return ray.get([w.apply_operator.remote(fn) for w in self.workers]) + + def validate(self, num_steps=None, info=None): + """Evaluates the model on the validation data set. + + Args: + num_steps (int): Number of batches to compute update steps on. + This corresponds also to the number of times + ``TrainingOperator.validate_batch`` is called. + info (dict): Optional dictionary passed to the training + operator for `validate` and `validate_batch`. + + Returns: + A dictionary of metrics for validation. + You can provide custom metrics by passing in a custom + ``training_operator_cls``. + """ + worker_stats = ray.get([ + w.validate.remote(num_steps=num_steps, info=info) + for w in self.workers + ]) validation_stats = {} for stat_key in worker_stats[0]: @@ -346,8 +399,8 @@ class PyTorchTrainer: This is useful for lr_schedulers such as ``ReduceLROnPlateau``. """ - self.apply_all_workers( - lambda runner: [sched.step(metric) for sched in runner.schedulers]) + self.apply_all_operators( + lambda op: [sched.step(metric) for sched in op.schedulers]) def get_model(self): """Returns the learned model(s).""" @@ -366,17 +419,18 @@ class PyTorchTrainer: Args: checkpoint (str): Path to target checkpoint file. + Returns: + checkpoint (str): Path to target checkpoint file. """ state = ray.get(self.workers[0].get_state.remote()) torch.save(state, checkpoint) return checkpoint def restore(self, checkpoint): - """Restores the model from the provided checkpoint. + """Restores the Trainer and all workers from the provided checkpoint. Args: checkpoint (str): Path to target checkpoint file. 
- """ state = torch.load(checkpoint) state_id = ray.put(state) @@ -450,7 +504,6 @@ class PyTorchTrainable(Trainable): validation_stats = self._trainer.validate() train_stats.update(validation_stats) - # output {"mean_loss": test_loss, "mean_accuracy": accuracy} return train_stats diff --git a/python/ray/util/sgd/pytorch/training_operator.py b/python/ray/util/sgd/pytorch/training_operator.py new file mode 100644 index 000000000..9b4c5f090 --- /dev/null +++ b/python/ray/util/sgd/pytorch/training_operator.py @@ -0,0 +1,343 @@ +import collections +import torch + +from ray.util.sgd.utils import TimerStat, AverageMeter +from ray.util.sgd.pytorch.constants import ( + SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_BATCH, SCHEDULER_STEP, BATCH_COUNT) + +amp = None + +try: + from apex import amp +except ImportError: + # Apex library is not installed, so we cannot enable mixed precision. + # We don't log here because logging happens in the pytorch_runner, + # where amp is initialized. + pass + + +def _is_multiple(component): + """Checks if a component (optimizer, model, etc) is not singular.""" + return isinstance(component, collections.Iterable) and len(component) > 1 + + +class TrainingOperator: + """Abstract class for custom training or validation loops. + + The scheduler will only be called at a batch or epoch frequency, depending + on the user parameter. Be sure to set ``scheduler_step_freq`` in + ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler + correctly during training. If using a learning rate scheduler + that depends on validation loss, you can use ``trainer.update_scheduler``. + + For both training and validation, there are two granularities that + you can provide customization: per epoch or per batch. + You do not need to override both. + + .. image:: raysgd-custom.jpg + :scale: 80% + :align: center + + Raises: + ValueError if multiple models/optimizers/schedulers are provided. 
+ You are expected to subclass this class if you wish + to train over multiple models/optimizers/schedulers. + """ + + def __init__(self, + config, + models, + optimizers, + criterion, + schedulers=None, + use_fp16=False): + # You are not expected to override this method. + self.timers = { + k: TimerStat() + for k in ["fwd", "grad", "apply", "epoch_time"] + } + self._validated_customization = False + self._models = models # List of models + assert isinstance(models, collections.Iterable), ( + "Components need to be iterable. Got: {}".format(type(models))) + self._optimizers = optimizers # List of optimizers + assert isinstance(optimizers, collections.Iterable), ( + "Components need to be iterable. Got: {}".format(type(optimizers))) + self._criterion = criterion + self._schedulers = schedulers + if schedulers: + assert isinstance(schedulers, collections.Iterable), ( + "Components need to be iterable. Got: {}".format( + type(schedulers))) + self._config = config + self._use_fp16 = use_fp16 + self.global_step = 0 + + if type(self) is TrainingOperator: + for component in (models, schedulers, optimizers): + if _is_multiple(component): + raise ValueError( + "Need to provide a custom operator subclassing " + "TrainingOperator if using multi-scheduler, " + "multi-model or multi-optimizer training/validation.") + + self.setup(config) + + def setup(self, config): + """Override this method to implement custom operator setup. + + Args: + config (dict): Custom configuration value to be passed to + all creator and operator constructors. Same as ``self.config``. + """ + pass + + def train_epoch(self, iterator, info): + """Runs one standard training pass over the train_iterator. + + By default, this method will iterate over the given iterator and + call ``self.train_batch`` over each batch. + + If ``scheduler_step_freq`` is set, this class will also step the + scheduler accordingly. 
+ + You do not need to call ``train_batch`` in this method if you plan + to implement a custom optimization/training routine here. + + Args: + iterator (iter): Iterator over the training data for the entire + epoch. This iterator is expected to be entirely consumed. + info (dict): Dictionary for information to be used for custom + training operations. + + Returns: + A dict of metrics from training. + """ + self._losses = AverageMeter() + + self.model.train() + with self.timers["epoch_time"]: + for batch_idx, batch in enumerate(iterator): + batch_info = { + "batch_idx": batch_idx, + "global_step": self.global_step + } + batch_info.update(info) + metrics = self.train_batch(batch, batch_info=batch_info) + + if self.scheduler and batch_info.get( + SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: + self.scheduler.step() + + if "loss" in metrics: + self._losses.update( + metrics["loss"], n=metrics.get("num_samples", 1)) + self.global_step += 1 + + if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: + self.scheduler.step() + + stats = { + BATCH_COUNT: batch_idx + 1, + "mean_train_loss": self._losses.avg, + "last_train_loss": self._losses.val, + "epoch_time": self.timers["epoch_time"].last + } + stats.update({ + timer_tag: timer.mean + for timer_tag, timer in self.timers.items() + }) + return stats + + def train_batch(self, batch, batch_info): + """Computes loss and updates the model over one batch. + + This method is responsible for computing the loss and gradient and + updating the model. + + By default, this method implementation assumes that batches + are in (features, labels) format. If using amp/fp16 + training, it will also scale the loss automatically. + + You can provide custom loss metrics and training operations if you + override this method. If overriding this method, you can access model, + optimizer, criterion via ``self.model``, ``self.optimizer``, + and ``self.criterion``. 
+ + You do not need to override this method if you plan to + override ``train_epoch``. + + Args: + batch: One item of the training iterator. + batch_info (dict): Information dict passed in from ``train_epoch``. + + Returns: + A dictionary of metrics. + By default, this dictionary contains "loss" and "num_samples". + "num_samples" corresponds to number of datapoints in the batch. + However, you can provide any number of other values. + + """ + features, target = batch + # Create non_blocking tensors for distributed training + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # Compute output. + with self.timers["fwd"]: + output = self.model(features) + loss = self.criterion(output, target) + + # Compute gradients in a backward pass. + with self.timers["grad"]: + self.optimizer.zero_grad() + if self.use_fp16: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + # Call step of optimizer to update model params. + with self.timers["apply"]: + self.optimizer.step() + return {"loss": loss.item(), "num_samples": features.size(0)} + + def validate(self, val_iterator, info): + """Runs one standard validation pass over the val_iterator. + + This will call ``model.eval()`` and ``torch.no_grad`` when iterating + over the validation dataset. + + If overriding this method, you can access model, criterion via + ``self.model`` and ``self.criterion``. You also do not need to call + ``validate_batch`` if overriding this method. + + Args: + val_iterator (iter): Iterable constructed over the + validation dataset. + info (dict): Dictionary for information to be used for custom + validation operations. + + Returns: + A dict of metrics from the evaluation.
+ By default, returns "mean_accuracy" and "mean_validation_loss" + which are computed by aggregating "loss" and "num_correct" values + from ``validate_batch`` and dividing them by the sum of + ``num_samples`` from all calls to ``self.validate_batch``. + """ + losses = AverageMeter() + total_correct = 0 + + # switch to evaluate mode + self.model.eval() + with torch.no_grad(): + for batch_idx, batch in enumerate(val_iterator): + batch_info = {"batch_idx": batch_idx} + batch_info.update(info) + metrics = self.validate_batch(batch, batch_info) + if "loss" in metrics: + losses.update( + metrics["loss"], n=metrics.get("num_samples", 1)) + + if "num_correct" in metrics: + total_correct += metrics["num_correct"] + + stats = { + "batch_count": batch_idx + 1, + "mean_validation_loss": losses.avg, + "mean_accuracy": total_correct / losses.count + } + return stats + + def validate_batch(self, batch, batch_info): + """Calculates the loss and accuracy over a given batch. + + You can override this method to provide arbitrary metrics. + + Args: + batch: One item of the validation iterator. + batch_info (dict): Contains information per batch from + ``validate()``. + + Returns: + A dict of metrics. + By default, returns "loss", "num_correct", and "num_samples". 
+ """ + features, target = batch + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = self.model(features) + loss = self.criterion(output, target) + _, predicted = torch.max(output.data, 1) + + return { + "loss": loss.item(), + "num_correct": (predicted == target).sum().item(), + "num_samples": target.size(0) + } + + def state_dict(self): + """Returns a serializable representation of the operator state.""" + pass + + def load_state_dict(self, state_dict): + """Loads a serializable representation of the operator state.""" + pass + + @property + def config(self): + """Dictionary as provided into PyTorchTrainer.""" + return self._config + + @property + def model(self): + """First or only model created by the provided ``model_creator``.""" + return self._models[0] + + @property + def models(self): + """List of models created by the provided ``model_creator``.""" + return self._models + + @property + def optimizer(self): + """First or only optimizer(s) created by the ``optimizer_creator``.""" + return self._optimizers[0] + + @property + def optimizers(self): + """List of optimizers created by the ``optimizer_creator``.""" + return self._optimizers + + @property + def criterion(self): + """Criterion created by the provided ``loss_creator``.""" + return self._criterion + + @property + def scheduler(self): + """First or only scheduler(s) created by the ``scheduler_creator``.""" + if self._schedulers: + return self._schedulers[0] + + @property + def schedulers(self): + """List of schedulers created by the ``scheduler_creator``.""" + return self._schedulers + + @property + def use_fp16(self): + """Whether the model and optimizer have been FP16 enabled.""" + return self._use_fp16 + + +class _TestingOperator(TrainingOperator): + def train_epoch(self, iterator, info): + func = self.config.get("custom_func") + if callable(func): + return func(self, iterator, info) + return {"done": 
1} diff --git a/python/ray/util/sgd/pytorch/utils.py b/python/ray/util/sgd/pytorch/utils.py deleted file mode 100644 index a28407983..000000000 --- a/python/ray/util/sgd/pytorch/utils.py +++ /dev/null @@ -1,229 +0,0 @@ -import collections -import time -import torch - -from ray.util.sgd.utils import TimerStat - -amp = None - -try: - from apex import amp -except ImportError: - # Apex library is not installed, so we cannot enable mixed precision. - # We don't log here because logging happens in the pytorch_runner, - # where amp is initialized. - pass - -USE_FP16 = "__use_fp16__" -TEST_MODE = "__test_mode__" -BATCH_COUNT = "batch_processed" -SCHEDULER_STEP = "scheduler_step" -SCHEDULER_STEP_BATCH = "batch" -SCHEDULER_STEP_EPOCH = "epoch" - -VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH} - - -def train(config, model, train_iterator, criterion, optimizer, scheduler=None): - """Runs one standard training pass over the train_iterator. - - This function automatically measures timing for various operations such - as host to device transfer, gradient calculation, and gradient application. - - It also automatically detects and places the data on the given GPU device - if available. - - The scheduler will only be called at a batch or epoch frequency, depending - on the user parameter. Be sure to set ``scheduler_step_freq`` in - ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler - correctly during training. If using a learning rate scheduler - that depends on validation loss, you can use ``trainer.update_scheduler``. - - Raises: - ValueError if multiple models/optimizers/schedulers are provided. You - are expected to have a custom training function if you wish - to use multiple models/optimizers/schedulers. - - Args: - config: (dict): A user configuration provided into the Trainer - constructor. - model: The model as created by the model_creator. 
- train_iterator: An iterator created from the DataLoader which - wraps the provided Dataset. - criterion: The loss object created by the loss_creator. - optimizer: The torch.optim.Optimizer object as created by the - optimizer_creator. - scheduler (optional): The torch.optim.lr_scheduler object - as created by the scheduler_creator. Be sure to set - ``scheduler_step_freq`` in ``PyTorchTrainer`` - to increment the scheduler correctly. - - Returns: - A dict of metrics from training. - """ - if isinstance(model, collections.Iterable) or isinstance( - optimizer, collections.Iterable) or isinstance( - scheduler, collections.Iterable): - raise ValueError( - "Need to provide custom training function if using multi-model " - "or multi-scheduler or multi-optimizer training.") - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - - timers = {k: TimerStat() for k in ["h2d", "fwd", "grad", "apply"]} - - # switch to train mode - model.train() - - end = time.time() - - for batch_idx, (features, target) in enumerate(train_iterator): - # measure data loading time - data_time.update(time.time() - end) - - # Create non_blocking tensors for distributed training - with timers["h2d"]: - if torch.cuda.is_available(): - features = features.cuda(non_blocking=True) - target = target.cuda(non_blocking=True) - - # compute output - with timers["fwd"]: - output = model(features) - loss = criterion(output, target) - - # measure accuracy and record loss - losses.update(loss.item(), features.size(0)) - - with timers["grad"]: - # compute gradients in a backward pass - optimizer.zero_grad() - - if config.get(USE_FP16): - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() - - with timers["apply"]: - # Call step of optimizer to update model params - optimizer.step() - - if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: - scheduler.step() - - # measure elapsed time - 
batch_time.update(time.time() - end) - end = time.time() - - if config.get(TEST_MODE) and batch_idx == 0: - break - - if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: - scheduler.step() - - stats = { - "batch_time": batch_time.avg, - BATCH_COUNT: batch_idx + 1, - "train_loss": losses.avg, - "data_time": data_time.avg, - } - stats.update({k: t.mean for k, t in timers.items()}) - return stats - - -def validate(config, model, val_iterator, criterion, scheduler=None): - """Runs one standard validation pass over the val_iterator. - - This function automatically measures timing for various operations such - as host to device transfer and processing time for the batch. - - It also automatically detects and places the data on the given GPU device - if available. - - Raises: - ValueError if multiple models/schedulers are provided. You - are expected to have a custom validation function if you wish - to use multiple models/schedulers. - - Args: - config: (dict): A user configuration provided into the Trainer - constructor. - model: The model as created by the model_creator. - train_iterator: An iterator created from the DataLoader which - wraps the provided Dataset. - criterion: The loss object created by the loss_creator. - scheduler (optional): The torch.optim.lr_scheduler object - as created by the scheduler_creator. By default, - this is not used in this function. - - Returns: - A dict of metrics from the evaluation. 
- """ - - if isinstance(model, collections.Iterable) or isinstance( - scheduler, collections.Iterable): - raise ValueError( - "Need to provide custom validation function if using multi-model " - "or multi-scheduler training.") - batch_time = AverageMeter() - losses = AverageMeter() - - # switch to evaluate mode - model.eval() - correct = 0 - total = 0 - batch_idx = 0 - with torch.no_grad(): - end = time.time() - for batch_idx, (features, target) in enumerate(val_iterator): - if torch.cuda.is_available(): - features = features.cuda(non_blocking=True) - target = target.cuda(non_blocking=True) - - # compute output - output = model(features) - loss = criterion(output, target) - _, predicted = torch.max(output.data, 1) - total += target.size(0) - correct += (predicted == target).sum().item() - - # measure accuracy and record loss - losses.update(loss.item(), features.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if config.get(TEST_MODE) and batch_idx == 0: - break - - stats = { - BATCH_COUNT: batch_idx + 1, - "batch_time": batch_time.avg, - "validation_loss": losses.avg, - "mean_accuracy": correct / total, - "mean_loss": losses.sum / total, - } - return stats - - -class AverageMeter: - """Computes and stores the average and current value.""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count diff --git a/python/ray/util/sgd/tests/test_pytorch.py b/python/ray/util/sgd/tests/test_pytorch.py index 08eb28172..d70f2e8c8 100644 --- a/python/ray/util/sgd/tests/test_pytorch.py +++ b/python/ray/util/sgd/tests/test_pytorch.py @@ -10,28 +10,34 @@ import torch.distributed as dist import ray from ray import tune -from ray.tests.conftest import ray_start_2_cpus # noqa: F401 from ray.util.sgd.pytorch import PyTorchTrainer, PyTorchTrainable 
-from ray.util.sgd.pytorch.utils import (train, BATCH_COUNT, TEST_MODE, - SCHEDULER_STEP) +from ray.util.sgd.pytorch.training_operator import _TestingOperator +from ray.util.sgd.pytorch.constants import BATCH_COUNT, SCHEDULER_STEP from ray.util.sgd.utils import check_for_failure from ray.util.sgd.pytorch.examples.train_example import ( model_creator, optimizer_creator, data_creator, LinearDataset) -def test_test_mode(ray_start_2_cpus): # noqa: F811 +@pytest.fixture +def ray_start_2_cpus(): + address_info = ray.init(num_cpus=2) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +def test_single_step(ray_start_2_cpus): # noqa: F811 trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - config={TEST_MODE: True}, num_replicas=1) - metrics = trainer.train() + metrics = trainer.train(num_steps=1) assert metrics[BATCH_COUNT] == 1 - val_metrics = trainer.validate() + val_metrics = trainer.validate(num_steps=1) assert val_metrics[BATCH_COUNT] == 1 @@ -45,29 +51,51 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811 loss_creator=lambda config: nn.MSELoss(), num_replicas=num_replicas) for i in range(3): - train_loss1 = trainer.train()["train_loss"] - validation_loss1 = trainer.validate()["validation_loss"] + train_loss1 = trainer.train()["mean_train_loss"] + validation_loss1 = trainer.validate()["mean_validation_loss"] for i in range(3): - train_loss2 = trainer.train()["train_loss"] - validation_loss2 = trainer.validate()["validation_loss"] + train_loss2 = trainer.train()["mean_train_loss"] + validation_loss2 = trainer.validate()["mean_validation_loss"] - print(train_loss1, train_loss2) - print(validation_loss1, validation_loss2) - - assert train_loss2 <= train_loss1 - assert validation_loss2 <= validation_loss1 + assert train_loss2 <= train_loss1, (train_loss2, train_loss1) + assert validation_loss2 <= validation_loss1, (validation_loss2, + 
validation_loss1) @pytest.mark.parametrize("num_replicas", [1, 2] if dist.is_available() else [1]) -def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 - def custom_train(config, models, dataloader, criterion, optimizers, - **kwargs): +def test_multi_model(ray_start_2_cpus, num_replicas): + def train(*, model=None, criterion=None, optimizer=None, dataloader=None): + model.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(dataloader): + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + return { + "accuracy": correct / total, + "train_loss": train_loss / (batch_idx + 1) + } + + def train_epoch(self, iterator, info): result = {} - for i, (model, optimizer) in enumerate(zip(models, optimizers)): - result["model_{}".format(i)] = train(config, model, dataloader, - criterion, optimizer) + for i, (model, optimizer) in enumerate( + zip(self.models, self.optimizers)): + result["model_{}".format(i)] = train( + model=model, + criterion=self.criterion, + optimizer=optimizer, + dataloader=iterator) return result def multi_model_creator(config): @@ -84,7 +112,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 data_creator, multi_optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - train_function=custom_train, + config={"custom_func": train_epoch}, + training_operator_cls=_TestingOperator, num_replicas=num_replicas) trainer1.train() @@ -100,6 +129,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 data_creator, multi_optimizer_creator, loss_creator=lambda config: nn.MSELoss(), + config={"custom_func": train_epoch}, + training_operator_cls=_TestingOperator, num_replicas=num_replicas) trainer2.restore(filename) @@ -123,16 +154,17 @@ def 
test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 @pytest.mark.parametrize("num_replicas", [1, 2] if dist.is_available() else [1]) def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811 - def custom_train(config, model, dataloader, criterion, optimizer, - scheduler): - if config.get("models", 1) > 1: - assert len(model) == config["models"], config + def train_epoch(self, iterator, info): + if self.config.get("models", 1) > 1: + assert len(self.models) == self.config["models"], self.config - if config.get("optimizers", 1) > 1: - assert len(optimizer) == config["optimizers"], config + if self.config.get("optimizers", 1) > 1: + assert len( + self.optimizers) == self.config["optimizers"], self.config - if config.get("schedulers", 1) > 1: - assert len(scheduler) == config["schedulers"], config + if self.config.get("schedulers", 1) > 1: + assert len( + self.schedulers) == self.config["schedulers"], self.config return {"done": 1} def multi_model_creator(config): @@ -167,12 +199,13 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811 multi_optimizer_creator, loss_creator=nn.MSELoss, scheduler_creator=multi_scheduler_creator, - train_function=custom_train, + training_operator_cls=_TestingOperator, num_replicas=num_replicas, config={ "models": model_count, "optimizers": optimizer_count, - "schedulers": scheduler_count + "schedulers": scheduler_count, + "custom_func": train_epoch }) trainer.train() trainer.shutdown() @@ -180,9 +213,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811 @pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"]) def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 - def custom_train(config, model, dataloader, criterion, optimizer, - scheduler): - assert config[SCHEDULER_STEP] == scheduler_freq + def train_epoch(self, iterator, info): + assert info[SCHEDULER_STEP] == scheduler_freq return {"done": 1} def scheduler_creator(optimizer, config): @@ 
-194,18 +226,17 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 data_creator, optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - scheduler_creator=scheduler_creator) + config={"custom_func": train_epoch}, + training_operator_cls=_TestingOperator, + scheduler_creator=scheduler_creator, + scheduler_step_freq=scheduler_freq) for i in range(3): - trainer.train()["train_loss"] + trainer.train() trainer.shutdown() def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 - def custom_train(config, model, dataloader, criterion, optimizer, - scheduler): - return {"done": 1} - from torch.optim.lr_scheduler import ReduceLROnPlateau trainer = PyTorchTrainer( @@ -213,11 +244,13 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 data_creator, optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer)) + scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer), + training_operator_cls=_TestingOperator) trainer.update_scheduler(0.5) trainer.update_scheduler(0.5) assert all( - trainer.apply_all_workers(lambda r: r.schedulers[0].last_epoch == 2)) + trainer.apply_all_operators( + lambda op: op.schedulers[0].last_epoch == 2)) trainer.shutdown() @@ -248,13 +281,13 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811 # checks loss decreasing for every trials for path, df in analysis.trial_dataframes.items(): - train_loss1 = df.loc[0, "train_loss"] - train_loss2 = df.loc[1, "train_loss"] - validation_loss1 = df.loc[0, "validation_loss"] - validation_loss2 = df.loc[1, "validation_loss"] + mean_train_loss1 = df.loc[0, "mean_train_loss"] + mean_train_loss2 = df.loc[1, "mean_train_loss"] + mean_validation_loss1 = df.loc[0, "mean_validation_loss"] + mean_validation_loss2 = df.loc[1, "mean_validation_loss"] - assert train_loss2 <= train_loss1 - assert validation_loss2 <= validation_loss1 + assert mean_train_loss2 <= mean_train_loss1 + 
assert mean_validation_loss2 <= mean_validation_loss1 @pytest.mark.parametrize("num_replicas", [1, 2] @@ -303,15 +336,17 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 def single_loader(config): return LinearDataset(2, 5, size=1000000) - def step_with_fail(self): - worker_stats = [w.step.remote() for w in self.workers] + def step_with_fail(self, *args, **kwargs): + worker_stats = [ + w.train_epoch.remote(*args, **kwargs) for w in self.workers + ] if self._num_failures < 3: time.sleep(1) # Make the batch will fail correctly. self.workers[0].__ray_kill__() success = check_for_failure(worker_stats) return success, worker_stats - with patch.object(PyTorchTrainer, "_train_step", step_with_fail): + with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail): trainer1 = PyTorchTrainer( model_creator, single_loader, @@ -331,15 +366,17 @@ def test_resize(ray_start_2_cpus): # noqa: F811 def single_loader(config): return LinearDataset(2, 5, size=1000000) - def step_with_fail(self): - worker_stats = [w.step.remote() for w in self.workers] + def step_with_fail(self, *args, **kwargs): + worker_stats = [ + w.train_epoch.remote(*args, **kwargs) for w in self.workers + ] if self._num_failures < 1: time.sleep(1) # Make the batch will fail correctly. 
self.workers[0].__ray_kill__() success = check_for_failure(worker_stats) return success, worker_stats - with patch.object(PyTorchTrainer, "_train_step", step_with_fail): + with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail): trainer1 = PyTorchTrainer( model_creator, single_loader, @@ -365,15 +402,17 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 def single_loader(config): return LinearDataset(2, 5, size=1000000) - def step_with_fail(self): - worker_stats = [w.step.remote() for w in self.workers] + def step_with_fail(self, *args, **kwargs): + worker_stats = [ + w.train_epoch.remote(*args, **kwargs) for w in self.workers + ] if self._num_failures < 2: time.sleep(1) self.workers[0].__ray_kill__() success = check_for_failure(worker_stats) return success, worker_stats - with patch.object(PyTorchTrainer, "_train_step", step_with_fail): + with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail): trainer1 = PyTorchTrainer( model_creator, single_loader, @@ -383,3 +422,9 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 num_replicas=2) trainer1.train(max_retries=2) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/sgd/tests/test_pytorch_runner.py b/python/ray/util/sgd/tests/test_pytorch_runner.py index e341194a6..cd7a262c1 100644 --- a/python/ray/util/sgd/tests/test_pytorch_runner.py +++ b/python/ray/util/sgd/tests/test_pytorch_runner.py @@ -4,6 +4,7 @@ import torch.nn as nn import unittest from unittest.mock import MagicMock +from ray.util.sgd.pytorch.training_operator import TrainingOperator from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner @@ -46,39 +47,55 @@ def create_dataloaders(config): class TestPyTorchRunner(unittest.TestCase): def testValidate(self): - mock_function = MagicMock(returns=dict(mean_accuracy=10)) + class MockOperator(TrainingOperator): + def setup(self, config): + self.train_epoch = 
MagicMock(returns=dict(mean_accuracy=10)) + self.validate = MagicMock(returns=dict(mean_accuracy=10)) + runner = PyTorchRunner( model_creator, create_dataloaders, optimizer_creator, loss_creator, - validation_function=mock_function) + training_operator_cls=MockOperator) runner.setup() - runner.step() - runner.step() - runner.step() - self.assertEqual(mock_function.call_count, 0) + runner.train_epoch() + runner.train_epoch() + runner.train_epoch() + self.assertEqual(runner.training_operator.validate.call_count, 0) runner.validate() - self.assertTrue(mock_function.called) + self.assertTrue(runner.training_operator.validate.called) self.assertEqual(runner.stats()["epoch"], 3) - def testStep(self): - mock_function = MagicMock(return_value=dict(mean_accuracy=10)) + def testtrain_epoch(self): + class MockOperator(TrainingOperator): + def setup(self, config): + self.count = 0 + + def train_epoch(self, *args, **kwargs): + self.count += 1 + return {"count": self.count} + runner = PyTorchRunner( model_creator, create_dataloaders, optimizer_creator, loss_creator, - train_function=mock_function) + training_operator_cls=MockOperator) runner.setup() - runner.step() - runner.step() - result = runner.step() - self.assertEqual(mock_function.call_count, 3) - self.assertEqual(result["epoch"], 3) + runner.train_epoch(num_steps=1) + runner.train_epoch(num_steps=1) + result = runner.train_epoch() + self.assertEqual(runner.training_operator.count, 3) + self.assertEqual(result["count"], 3) self.assertEqual(runner.stats()["epoch"], 3) def testGivens(self): + class MockOperator(TrainingOperator): + def setup(self, config): + self.train_epoch = MagicMock(returns=dict(mean_accuracy=10)) + self.validate = MagicMock(returns=dict(mean_accuracy=10)) + def three_model_creator(config): return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1) @@ -88,8 +105,12 @@ class TestPyTorchRunner(unittest.TestCase): ] return opts[0], opts[1], opts[2] - runner = PyTorchRunner(three_model_creator, single_loader, 
- three_optimizer_creator, loss_creator) + runner = PyTorchRunner( + three_model_creator, + single_loader, + three_optimizer_creator, + loss_creator, + training_operator_cls=MockOperator) runner.setup() self.assertEqual(len(runner.given_models), 3) @@ -121,7 +142,7 @@ class TestPyTorchRunner(unittest.TestCase): runner = PyTorchRunner(model_creator, single_loader, optimizer_creator, loss_creator) runner.setup() - runner.step() + runner.train_epoch() with self.assertRaises(ValueError): runner.validate() @@ -132,7 +153,7 @@ class TestPyTorchRunner(unittest.TestCase): optimizer_creator, loss_creator=nn.MSELoss) runner.setup() - runner.step() + runner.train_epoch() def testMultiModel(self): def multi_model_creator(config): @@ -146,6 +167,6 @@ class TestPyTorchRunner(unittest.TestCase): runner = PyTorchRunner(multi_model_creator, single_loader, multi_optimizer_creator, loss_creator) - runner.setup() + with self.assertRaises(ValueError): - runner.step() + runner.setup() diff --git a/python/ray/util/sgd/utils.py b/python/ray/util/sgd/utils.py index fc5d3ce8f..5e2f14087 100644 --- a/python/ray/util/sgd/utils.py +++ b/python/ray/util/sgd/utils.py @@ -149,3 +149,11 @@ def check_for_failure(remote_values): except RayActorError as exc: logger.exception(str(exc)) return False + + +def override(interface_class): + def overrider(method): + assert (method.__name__ in dir(interface_class)) + return method + + return overrider