From e0b6984dce2220f0c447dbdfac696375f0c03d51 Mon Sep 17 00:00:00 2001
From: krfricke
Date: Tue, 30 Jun 2020 01:52:07 +0200
Subject: [PATCH] [tune] pytorch lightning template and walkthrough (#9151)

Co-authored-by: Kai Fricke
---
 ci/jenkins_tests/run_tune_tests.sh            |   4 +
 doc/source/images/pytorch_lightning_full.png  | Bin 0 -> 5365 bytes
 doc/source/images/pytorch_lightning_small.png | Bin 0 -> 5740 bytes
 doc/source/tune/_tutorials/overview.rst       |   7 +
 .../_tutorials/tune-pytorch-lightning.rst     | 297 ++++++++++++++++++
 docker/tune_test/requirements.txt             |   1 +
 .../tune/examples/mnist_pytorch_lightning.py  | 254 +++++++++++++++
 7 files changed, 563 insertions(+)
 create mode 100644 doc/source/images/pytorch_lightning_full.png
 create mode 100644 doc/source/images/pytorch_lightning_small.png
 create mode 100644 doc/source/tune/_tutorials/tune-pytorch-lightning.rst
 create mode 100644 python/ray/tune/examples/mnist_pytorch_lightning.py

diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh
index 00d792f87..8783a10eb 100755
--- a/ci/jenkins_tests/run_tune_tests.sh
+++ b/ci/jenkins_tests/run_tune_tests.sh
@@ -112,6 +112,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
 $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
     python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test
 
+$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/mnist_pytorch_lightning.py \
+    --smoke-test
+
 $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
     python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \
     --smoke-test

diff --git a/doc/source/images/pytorch_lightning_full.png b/doc/source/images/pytorch_lightning_full.png
new file mode 100644
index 0000000000000000000000000000000000000000..86c781c97fb9e48d68a3c05c4da71114eab0625c
GIT binary patch
[binary PNG data (5365 bytes) omitted]

diff --git a/doc/source/images/pytorch_lightning_small.png b/doc/source/images/pytorch_lightning_small.png
new file mode 100644
index 0000000000000000000000000000000000000000..0688f162e7d9593550d82894762059702c9cae75
GIT binary patch
[binary PNG data (5740 bytes) omitted]
zX0V8+uygN7vu|vW-dXxX^L@Cjpv6qrN{Uq*r*fF(f)bOvdH0hh)Geze>u{7Us_^#D z-PVX^?cToJYa#9<5d8b2_>rr(fzzL^yTqR%G-Z;(>I1Kq*=bE>v2deA86Oz9^SHVj& zN12u1BmT5l+zUUI8Bsj`rFTJ6ITnsW=SxboJ=jPyVZQQb5Wps-mS^8k^qbalNeAn= z?sHjl9D4r59}W*DLd!B%ClvjplA*$oI$AHxFbcn??|(}AC+XEkcb0=y-I<=P5nIUi`aENvCVMr5uktY4JCU3Z{Ky))LQuZ;T$nAX7SAZ?7s64$+ zKRoQ-QHiB>5aE{7qLebgb;pp}_GdOd>>_1pIx=xFZBt2*we^Hw1B+ufGZ)$4=XMuY z))AgQ&~g1l_8vP{=4MgRiqBvRdM;{t`Gewr?&r0vCMjq9Rz>AXV`*s>S?|)-9qZdDV8kK)&rM(+ z<(NM%ecBF;auk(E`##O7OtWRn`NIZ8qBZGNo`}3ukC*k9o|-Bty!Sfw%-_MhtaHG9 zt@z6j7)pZ$TW%UKheRw(L{F=_M+r3ed{7jmlaN(@l26-<)ZH%E3(+si{C;Mx${N4Lp$e4B%(44`+H8o`M`+=nLz#Nv}V2JM{MHz}j zCsHl9X6c9sf|5XQPO-1o=+E>=8Swf=EROOaarSG^nSmqf>vdhLfxH$duQs+LwVzH; zbHY*v>%>Hm6zB;Qk#^Mg_;A3AvK@4c5jByla*~fm)09!h;<$z`orb{U!{*(m4sj@a z?z*d$wuE9$pZr%@236ZKZ-n>7qLALetI`+!eLo!7!|PFSEJO9ER1wQ!ctAQy;Vi^F zdobpV!AF6{wpx=e<-9=(4+M$sHEX7=Ia0@;1qjgza(-m!)g37@%M z-+$%Y28|ex^u8VpbkQv-vHU#2v)1p|mkt4PJSLw6jK(C{UKxP$l z5ng8%3$HQ*S_g}mJpuk!61);XvX z&NSSG(d=@SztpZ^iKn`U8rS(TVSACeolhz4BLh!4x?a>O?nc-*t2ox<^b^Tej4 ziosQS-ozyDqod@N5Q10Em=kTGy8AEnHlibQ4fg#u+c@mERXKE`NeX9DM+UZ+$#|=< zTP55UQ+umMh{$PgcZq0U1_jJ5EZv8@Ymdexh((5~mgyeY9$Dplm^T)S^Qb!m~ z>a0F;x~i*9nOaup$O`0K<$>2rHsA)*6+Ko81V;YsW4-KruODS*NIK#u&tN*Pn#MuM zS~KmT*H9}CZhPW&ic=YW2N7IpdSsUr*P)n#9(h3#Gkdy8WO5!|4hu?IenU^O=1**^ld zl0RpD))oBuOKrl;As&(fq-2zzSYwoql zE&6RdlblO4q9%|6Ecu>;rI24<%8A#tH;2COs(SQb4iQTVu%hW{NI?-MXJ?iRjRrRg z-*|36Pc$~xT1AVn3TU%yf&N(5t&EMP7C6hEFJ{kelUL=5e)*$SSB9o=TXH#@k@%f% zc5UPR^V?>o$QNn!Lwdwh()#r)2Iib=;N2e`$5Re_wSVuk1FS_38BQqlWy9|2rB@Hb zcgRa`4SAgf?x=Aysk!U|sUvj3!R?|T*FwBtHyig;xm9CI>oz)MMR>Z8y%f{D>_z8; zA561N^`MZRUWu160_;3i*r-V|H#0XB!gsbExiwa!-ce}13k?!*7s=2Nevf?T+vaU9+Jvmxw4@Qlp|9Pue(( zTgW4vzy#}YYxRQGiJP8jon6OaOsT8L`O5OWC%izo786Ee?2^S^a>1L%*WJkl3nhMI zp@F!@5|tvqu}MRJxu;AJ__L&TV;_N=0Ch!5O32G0@kOZw?M53EQrcGi9~^REm7a!F zUix3^f5VpoOf+~C`Cr)O!2fgez&DG@XSJNTztc&?f@mEMGaFYzwZ{2BQPujUo^yW; RnQ2Hda2sK$Ri}Y|_` +.. customgalleryitem:: + :tooltip: Tuning PyTorch Lightning modules + :figure: /images/pytorch_lightning_small.png + :description: :doc:`Tuning PyTorch Lightning modules ` + + .. raw:: html diff --git a/doc/source/tune/_tutorials/tune-pytorch-lightning.rst b/doc/source/tune/_tutorials/tune-pytorch-lightning.rst new file mode 100644 index 000000000..0a3845c97 --- /dev/null +++ b/doc/source/tune/_tutorials/tune-pytorch-lightning.rst @@ -0,0 +1,297 @@ +.. _tune-pytorch-lightning: + +Using PyTorch Lightning with Tune +================================= + +PyTorch Lightning is a framework which brings structure into training PyTorch models. It +aims to avoid boilerplate code, so you don't have to write the same training +loops all over again when building a new model. + +.. image:: /images/pytorch_lightning_full.png + +The main abstraction of PyTorch Lightning is the ``LightningModule`` class, which +should be extended by your application. There is `a great post on how to transfer +your models from vanilla PyTorch to Lightning `_. + +The class structure of PyTorch Lightning makes it very easy to define and tune model +parameters. This tutorial will show you how to use Tune to find the best set of +parameters for your application on the example of training a MNIST classifier. Notably, +the ``LightningModule`` does not have to be altered at all for this - so you can +use it plug and play for your existing models, assuming their parameters are configurable! + +.. note:: + + To run this example, you will need to install the following: + + .. code-block:: bash + + $ pip install ray torch torchvision pytorch-lightning + +.. 
+
+PyTorch Lightning classifier for MNIST
+--------------------------------------
+Let's first start with the basic PyTorch Lightning implementation of an MNIST classifier.
+This classifier does not include any tuning code at this point.
+
+Our example builds on the MNIST example from the `blog post we talked about
+earlier `_.
+
+First, we run some imports:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __import_lightning_begin__
+    :end-before: __import_lightning_end__
+
+And then there is the Lightning model, adapted from the blog post.
+Note that we left out the test set validation and made the model parameters
+configurable through a ``config`` dict that is passed on initialization.
+Also, we specify a ``data_dir`` where the MNIST data will be stored.
+Lastly, we added a new metric, the validation accuracy, to the logs.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __lightning_begin__
+    :end-before: __lightning_end__
+
+And that's it! You can now run ``train_mnist(config)`` to train the classifier, e.g.
+like so:
+
+.. code-block:: python
+
+    config = {
+        "layer_1_size": 128,
+        "layer_2_size": 256,
+        "lr": 1e-3,
+        "batch_size": 64
+    }
+    train_mnist(config)
+
+Tuning the model parameters
+---------------------------
+The parameters above should already give you a good accuracy of over 90%. However,
+we might improve on this simply by changing some of the hyperparameters. For instance,
+maybe we would get an even higher accuracy if we used a larger batch size.
+
+Instead of guessing the parameter values, let's use Tune to systematically try out
+parameter combinations and find the best performing set.
+
+First, we need some additional imports:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __import_tune_begin__
+    :end-before: __import_tune_end__
+
+Talking to Tune with a PyTorch Lightning callback
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+PyTorch Lightning introduced `Callbacks `_
+that can be used to plug custom functions into the training loop. This way the original
+``LightningModule`` does not have to be altered at all. We could even reuse the same
+callback for multiple modules.
+
+The callback just reports some metrics back to Tune after each validation epoch:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_callback_begin__
+    :end-before: __tune_callback_end__
+
+Adding the Tune training function
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Then we specify our training function. Note that we added ``data_dir`` as a config
+parameter here, even though it should not be tuned. We only pass it so that each
+training run does not download the full MNIST dataset again; instead, all trials
+access a shared data location.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_train_begin__
+    :end-before: __tune_train_end__
+
+Sharing the data
+~~~~~~~~~~~~~~~~
+
+All our trials use the MNIST data. To avoid having each training instance download
+its own copy of the dataset, we download it once and share the ``data_dir`` between runs.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_asha_begin__
+    :end-before: __tune_asha_end__
+    :lines: 2-3
+    :dedent: 4
+
+We also delete this data after training to avoid filling up our disk or memory space.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_asha_begin__
+    :end-before: __tune_asha_end__
+    :lines: 27
+    :dedent: 4
+
+Configuring the search space
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Now we configure the parameter search space. We would like to choose between three
+different layer sizes and three different batch sizes. The learning rate should be
+sampled between ``0.0001`` and ``0.1``. Here, ``tune.loguniform()`` comes in handy:
+it samples the learning rate log-uniformly, so that values from the smaller orders
+of magnitude are drawn just as often as larger ones.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_asha_begin__
+    :end-before: __tune_asha_end__
+    :lines: 4-10
+    :dedent: 4
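+
+For illustration, a concrete configuration drawn from this search space could look
+like the one below. The values are made up for this sketch and are not taken from an
+actual run; ``data_dir`` is simply the shared directory we created above:
+
+.. code-block:: python
+
+    # One possible sample from the search space (illustrative values only)
+    config = {
+        "layer_1_size": 64,
+        "layer_2_size": 256,
+        "lr": 0.0007,          # drawn log-uniformly from [1e-4, 1e-1]
+        "batch_size": 32,
+        "data_dir": data_dir,  # fixed value, not tuned
+    }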
+
+Selecting a scheduler
+~~~~~~~~~~~~~~~~~~~~~
+
+In this example, we use an `Asynchronous Hyperband `_
+scheduler. This scheduler decides at each iteration which trials are likely to perform
+badly, and stops these trials. This way we don't waste any resources on bad hyperparameter
+configurations.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_asha_begin__
+    :end-before: __tune_asha_end__
+    :lines: 11-16
+    :dedent: 4
+
+
+Changing the CLI output
+~~~~~~~~~~~~~~~~~~~~~~~
+
+We instantiate a ``CLIReporter`` to specify which metrics we would like to see in our
+output tables on the command line. If we didn't specify this, Tune would print all
+hyperparameters by default. Since ``data_dir`` is not a real hyperparameter, we avoid
+printing it by leaving it out of the ``parameter_columns`` parameter.
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_asha_begin__
+    :end-before: __tune_asha_end__
+    :lines: 17-19
+    :dedent: 4
+
+Putting it together
+~~~~~~~~~~~~~~~~~~~
+
+Lastly, we need to start Tune with ``tune.run()``.
+
+The full code looks like this:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_asha_begin__
+    :end-before: __tune_asha_end__
+
+
+In the example above, Tune runs 10 trials with different hyperparameter configurations.
+An example output could look like this:
+
+.. code-block::
+    :emphasize-lines: 12
+
+    +------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+
+    | Trial name                   | status     | loc   |   layer_1_size |   layer_2_size |          lr |   batch_size |     loss |   mean_accuracy |   training_iteration |
+    |------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------|
+    | train_mnist_tune_63ecc_00000 | TERMINATED |       |            128 |             64 |  0.00121197 |          128 | 0.120173 |        0.972461 |                   10 |
+    | train_mnist_tune_63ecc_00001 | TERMINATED |       |             64 |            128 |   0.0301395 |          128 | 0.454836 |        0.868164 |                    4 |
+    | train_mnist_tune_63ecc_00002 | TERMINATED |       |             64 |            128 |   0.0432097 |          128 | 0.718396 |        0.718359 |                    1 |
+    | train_mnist_tune_63ecc_00003 | TERMINATED |       |             32 |            128 | 0.000294669 |           32 | 0.111475 |        0.965764 |                   10 |
+    | train_mnist_tune_63ecc_00004 | TERMINATED |       |             32 |            256 | 0.000386664 |           64 | 0.133538 |        0.960839 |                    8 |
+    | train_mnist_tune_63ecc_00005 | TERMINATED |       |            128 |            128 |   0.0837395 |           32 |  2.32628 |       0.0991242 |                    1 |
+    | train_mnist_tune_63ecc_00006 | TERMINATED |       |             64 |            128 | 0.000158761 |          128 | 0.134595 |        0.959766 |                   10 |
+    | train_mnist_tune_63ecc_00007 | TERMINATED |       |             64 |             64 | 0.000672126 |           64 | 0.118182 |        0.972903 |                   10 |
+    | train_mnist_tune_63ecc_00008 | TERMINATED |       |            128 |             64 | 0.000502428 |           32 |  0.11082 |        0.975518 |                   10 |
+    | train_mnist_tune_63ecc_00009 | TERMINATED |       |             64 |            256 |  0.00112894 |           32 |  0.13472 |        0.971935 |                    8 |
+    +------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+
+
+As you can see in the ``training_iteration`` column, trials with a high loss
+(and low accuracy) have been terminated early. The best performing trial used
+``layer_1_size=128``, ``layer_2_size=64``, ``lr=0.000502428`` and
+``batch_size=32``.
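+
+If you would rather process the results programmatically than read them off the
+console output, note that ``tune.run()`` also returns an analysis object. The snippet
+below is only a sketch of how this could be used (the exact analysis API may differ
+slightly between Tune versions):
+
+.. code-block:: python
+
+    # tune.run() returns an ExperimentAnalysis object.
+    analysis = tune.run(
+        train_mnist_tune,
+        resources_per_trial={"cpu": 1},
+        config=config,
+        num_samples=10,
+        scheduler=scheduler,
+        progress_reporter=reporter)
+
+    # Ask for the configuration of the best trial, judged by validation loss.
+    best_config = analysis.get_best_config(metric="loss", mode="min")
+    print("Best hyperparameters found were:", best_config)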
+
+Using Population Based Training to find the best parameters
+-----------------------------------------------------------
+The ``ASHAScheduler`` terminates badly performing trials early. Sometimes this stops
+trials that would have improved with more training steps, and which might eventually
+even have outperformed other configurations.
+
+Another popular method for hyperparameter tuning, called
+`Population Based Training `_,
+instead perturbs hyperparameters during the training run. Tune implements PBT, and
+we only need to make some slight adjustments to our code.
+
+Adding checkpoints to the PyTorch Lightning module
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+First, we need to introduce another callback to save model checkpoints:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_checkpoint_callback_begin__
+    :end-before: __tune_checkpoint_callback_end__
+
+We also include checkpoint loading in our training function:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_train_checkpoint_begin__
+    :end-before: __tune_train_checkpoint_end__
+
+
+Configuring and running Population Based Training
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We need to call Tune slightly differently:
+
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+    :language: python
+    :start-after: __tune_pbt_begin__
+    :end-before: __tune_pbt_end__
+
+Instead of filling the ``config`` dict with search space samples only, we now start
+from fixed values for the learning rate and the batch size, though we still sample
+some parameters, like the layer sizes. Additionally, we have to tell PBT how to
+perturb the hyperparameters. Note that the layer sizes are not part of these
+mutations: we cannot simply change the layer sizes during a training run, which is
+exactly what PBT would do.
+
+An example output could look like this:
+
+.. code-block::
+
+    +-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------+
+    | Trial name                              | status     | loc   |   layer_1_size |   layer_2_size |        lr |   batch_size |      loss |   mean_accuracy |   training_iteration |
+    |-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------|
+    | train_mnist_tune_checkpoint_85489_00000 | TERMINATED |       |            128 |            128 |     0.001 |           64 |  0.108734 |        0.973101 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00001 | TERMINATED |       |            128 |            128 |     0.001 |           64 |  0.093577 |        0.978639 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00002 | TERMINATED |       |            128 |            256 |    0.0008 |           32 | 0.0922348 |        0.979299 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00003 | TERMINATED |       |             64 |            256 |     0.001 |           64 |  0.124648 |        0.973892 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00004 | TERMINATED |       |            128 |             64 |     0.001 |           64 |  0.101717 |        0.975079 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00005 | TERMINATED |       |             64 |             64 |     0.001 |           64 |  0.121467 |        0.969146 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00006 | TERMINATED |       |            128 |            256 |   0.00064 |           32 |  0.053446 |        0.987062 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00007 | TERMINATED |       |            128 |            256 |     0.001 |           64 |  0.129804 |        0.973497 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00008 | TERMINATED |       |             64 |            256 | 0.0285125 |          128 |  0.363236 |        0.913867 |                   10 |
+    | train_mnist_tune_checkpoint_85489_00009 | TERMINATED |       |             32 |            256 |     0.001 |           64 |  0.150946 |        0.964201 |                   10 |
+    +-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------+
+
+As you can see, each sample ran the full number of 10 iterations.
+All trials ended up with quite good parameter combinations and showed relatively good
+performance. In some runs, the parameters were perturbed, and the best configuration
+even reached a mean validation accuracy of ``0.987062``!
+
+In summary, PyTorch Lightning modules are easy to extend for use with Tune. All it
+took was one or two callbacks and a small wrapper function, and we ended up with
+well-performing parameter configurations.
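+
+If you want to try this out yourself, the full example script ships with Ray as
+``python/ray/tune/examples/mnist_pytorch_lightning.py`` and exposes both entry
+points. Its ``__main__`` section simply picks one of the two tuning functions
+shown above:
+
+.. code-block:: python
+
+    if __name__ == "__main__":
+        # tune_mnist_asha()  # ASHA scheduler
+        tune_mnist_pbt()  # Population Based Training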
diff --git a/docker/tune_test/requirements.txt b/docker/tune_test/requirements.txt index 28479c4cb..a2603736a 100644 --- a/docker/tune_test/requirements.txt +++ b/docker/tune_test/requirements.txt @@ -18,6 +18,7 @@ opencv-python-headless pandas pytest-remotedata>=0.3.1 pytest-timeout +pytorch-lightning scikit-learn==0.22.2 scikit-optimize sigopt diff --git a/python/ray/tune/examples/mnist_pytorch_lightning.py b/python/ray/tune/examples/mnist_pytorch_lightning.py new file mode 100644 index 000000000..2ad9cea5e --- /dev/null +++ b/python/ray/tune/examples/mnist_pytorch_lightning.py @@ -0,0 +1,254 @@ +# flake8: noqa +# yapf: disable + +# __import_lightning_begin__ +import torch +import pytorch_lightning as pl +from torch.utils.data import DataLoader, random_split +from torch.nn import functional as F +from torchvision.datasets import MNIST +from torchvision import transforms +import os +# __import_lightning_end__ + +# __import_tune_begin__ +import shutil +from tempfile import mkdtemp +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.cloud_io import load as pl_load +from ray import tune +from ray.tune import CLIReporter +from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining +# __import_tune_end__ + + +# __lightning_begin__ +class LightningMNISTClassifier(pl.LightningModule): + """ + This has been adapted from + https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09 + """ + + def __init__(self, config, data_dir=None): + super(LightningMNISTClassifier, self).__init__() + + self.data_dir = data_dir or os.getcwd() + + self.layer_1_size = config["layer_1_size"] + self.layer_2_size = config["layer_2_size"] + self.lr = config["lr"] + self.batch_size = config["batch_size"] + + # mnist images are (1, 28, 28) (channels, width, height) + self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size) + self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size) + self.layer_3 = torch.nn.Linear(self.layer_2_size, 10) + + def forward(self, x): + batch_size, channels, width, height = x.size() + x = x.view(batch_size, -1) + + x = self.layer_1(x) + x = torch.relu(x) + + x = self.layer_2(x) + x = torch.relu(x) + + x = self.layer_3(x) + x = torch.log_softmax(x, dim=1) + + return x + + def cross_entropy_loss(self, logits, labels): + return F.nll_loss(logits, labels) + + def accuracy(self, logits, labels): + _, predicted = torch.max(logits.data, 1) + correct = (predicted == labels).sum().item() + accuracy = correct / len(labels) + return torch.tensor(accuracy) + + def training_step(self, train_batch, batch_idx): + x, y = train_batch + logits = self.forward(x) + loss = self.cross_entropy_loss(logits, y) + accuracy = self.accuracy(logits, y) + + logs = {"train_loss": loss, "train_accuracy": accuracy} + return {"loss": loss, "log": logs} + + def validation_step(self, val_batch, batch_idx): + x, y = val_batch + logits = self.forward(x) + loss = self.cross_entropy_loss(logits, y) + accuracy = self.accuracy(logits, y) + + return {"val_loss": loss, "val_accuracy": accuracy} + + def validation_epoch_end(self, outputs): + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean() + tensorboard_logs = {"val_loss": avg_loss, "val_accuracy": avg_acc} + + return { + "avg_val_loss": avg_loss, + "avg_val_accuracy": avg_acc, + "log": tensorboard_logs + } + + @staticmethod + def download_data(data_dir): + transform = transforms.Compose([ + 
transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, )) + ]) + return MNIST(data_dir, train=True, download=True, transform=transform) + + def prepare_data(self): + mnist_train = self.download_data(self.data_dir) + + self.mnist_train, self.mnist_val = random_split( + mnist_train, [55000, 5000]) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=int(self.batch_size)) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=int(self.batch_size)) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + +def train_mnist(config): + model = LightningMNISTClassifier(config) + trainer = pl.Trainer(max_epochs=10, show_progress_bar=False) + + trainer.fit(model) +# __lightning_end__ + + +# __tune_callback_begin__ +class TuneReportCallback(Callback): + def on_validation_end(self, trainer, pl_module): + tune.report( + loss=trainer.callback_metrics["avg_val_loss"], + mean_accuracy=trainer.callback_metrics["avg_val_accuracy"]) +# __tune_callback_end__ + + +# __tune_train_begin__ +def train_mnist_tune(config): + model = LightningMNISTClassifier(config, config["data_dir"]) + trainer = pl.Trainer( + max_epochs=10, + progress_bar_refresh_rate=0, + callbacks=[TuneReportCallback()]) + + trainer.fit(model) +# __tune_train_end__ + + +# __tune_checkpoint_callback_begin__ +class CheckpointCallback(Callback): + def on_validation_end(self, trainer, pl_module): + path = tune.make_checkpoint_dir(trainer.global_step) + trainer.save_checkpoint(os.path.join(path, "checkpoint")) + tune.save_checkpoint(path) +# __tune_checkpoint_callback_end__ + + +# __tune_train_checkpoint_begin__ +def train_mnist_tune_checkpoint(config, checkpoint=None): + trainer = pl.Trainer( + max_epochs=10, + progress_bar_refresh_rate=0, + callbacks=[CheckpointCallback(), + TuneReportCallback()]) + if checkpoint: + # Currently, this leads to errors: + # model = LightningMNISTClassifier.load_from_checkpoint( + # os.path.join(checkpoint, "checkpoint")) + # Workaround: + ckpt = pl_load( + os.path.join(checkpoint, "checkpoint"), + map_location=lambda storage, loc: storage) + model = LightningMNISTClassifier._load_model_state(ckpt, config=config) + trainer.current_epoch = ckpt["epoch"] + else: + model = LightningMNISTClassifier( + config=config, data_dir=config["data_dir"]) + + trainer.fit(model) +# __tune_train_checkpoint_end__ + + +# __tune_asha_begin__ +def tune_mnist_asha(): + data_dir = mkdtemp(prefix="mnist_data_") + LightningMNISTClassifier.download_data(data_dir) + config = { + "layer_1_size": tune.choice([32, 64, 128]), + "layer_2_size": tune.choice([64, 128, 256]), + "lr": tune.loguniform(1e-4, 1e-1), + "batch_size": tune.choice([32, 64, 128]), + "data_dir": data_dir + } + scheduler = ASHAScheduler( + metric="loss", + mode="min", + max_t=10, + grace_period=1, + reduction_factor=2) + reporter = CLIReporter( + parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"], + metric_columns=["loss", "mean_accuracy", "training_iteration"]) + tune.run( + train_mnist_tune, + resources_per_trial={"cpu": 1}, + config=config, + num_samples=10, + scheduler=scheduler, + progress_reporter=reporter) + shutil.rmtree(data_dir) +# __tune_asha_end__ + + +# __tune_pbt_begin__ +def tune_mnist_pbt(): + data_dir = mkdtemp(prefix="mnist_data_") + LightningMNISTClassifier.download_data(data_dir) + config = { + "layer_1_size": tune.choice([32, 64, 128]), + "layer_2_size": tune.choice([64, 128, 256]), + "lr": 1e-3, + "batch_size": 64, + 
"data_dir": data_dir + } + scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="loss", + mode="min", + perturbation_interval=4, + hyperparam_mutations={ + "lr": lambda: tune.loguniform(1e-4, 1e-1).func(None), + "batch_size": [32, 64, 128] + }) + reporter = CLIReporter( + parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"], + metric_columns=["loss", "mean_accuracy", "training_iteration"]) + tune.run( + train_mnist_tune_checkpoint, + resources_per_trial={"cpu": 1}, + config=config, + num_samples=10, + scheduler=scheduler, + progress_reporter=reporter) + shutil.rmtree(data_dir) +# __tune_pbt_end__ + + +if __name__ == "__main__": + # tune_mnist_asha() # ASHA scheduler + tune_mnist_pbt() # population based training