From 580051d191b59295487267b8449541e86de126e7 Mon Sep 17 00:00:00 2001
From: ilennaj
Date: Tue, 1 Sep 2020 10:54:33 -0400
Subject: [PATCH] added comments to traintestloop.py

---
 .../traintestloop-checkpoint.py              | 63 +++++++++++++-----
 .../__pycache__/traintestloop.cpython-36.pyc | Bin 3152 -> 4122 bytes
 custompackage/traintestloop.py               | 63 +++++++++++++-----
 3 files changed, 96 insertions(+), 30 deletions(-)

diff --git a/custompackage/.ipynb_checkpoints/traintestloop-checkpoint.py b/custompackage/.ipynb_checkpoints/traintestloop-checkpoint.py
index 6d98716..9021452 100644
--- a/custompackage/.ipynb_checkpoints/traintestloop-checkpoint.py
+++ b/custompackage/.ipynb_checkpoints/traintestloop-checkpoint.py
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 import torch.optim as optim
 from torch.optim import Optimizer
 
+from pytorchtools import EarlyStopping
+
 
 
 def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, randorder=False, patience=60):
@@ -30,14 +32,17 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
     # to track the average validation loss per epoch as the model trains
     avg_valid_losses = []
 
-    #
+    # if randorder == True, generate the randomizer index array for randomizing the input image pixel order
     if randorder == True:
         ordering = torch.randperm(len(trainloader.dataset.tensors[0][0]))
 
+    # Initialize early stopping object
     early_stopping = EarlyStopping(patience=patience, verbose=False)
 
     for epoch in range(epochs):  # loop over the dataset multiple times
-
+        ######################
+        # train the model    #
+        ######################
         running_loss = 0.0
         running_acc = 0.0
         model.train()
@@ -46,6 +51,7 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
             # get the inputs; data is a list of [inputs, labels]
             inputs, labels, _ = data
             if randorder == True:
+                # Randomize pixel order
                 inputs = inputs[:,ordering].cuda()
             else:
                 inputs = inputs.cuda()
@@ -72,10 +78,12 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
             # print statistics
             running_loss += loss.item()
             running_acc += (torch.round(outputs) == labels.float().reshape(-1,1)).sum().item()/trainloader.batch_size
-            if (i % 4) == 3:    # print every 80 mini-batches
-                loss_curve.append(running_loss/3)
-                acc_curve.append(running_acc/3)
+            # Generate loss and accuracy curves by saving average every 4th minibatch
+            if (i % 4) == 3:
+                loss_curve.append(running_loss/4)
+                acc_curve.append(running_acc/4)
                 running_loss = 0.0
+                running_acc = 0.0
         ######################
         # validate the model #
         ######################
@@ -83,7 +91,11 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
         model.eval() # prep model for evaluation
         for _, data in enumerate(validloader):
             inputs, labels, _ = data
-            inputs = inputs.cuda()
+            if randorder == True:
+                # Randomize pixel order
+                inputs = inputs[:,ordering].cuda()
+            else:
+                inputs = inputs.cuda()
             labels = labels.cuda()
             # forward pass: compute predicted outputs by passing inputs to the model
             output = model(inputs)
@@ -95,7 +107,7 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
 
         valid_loss = np.average(valid_losses)
 
-        # early_stopping needs the validation loss to check if it has decresed,
+        # early_stopping needs the validation loss to check if it has decreased,
         # and if it has, it will make a checkpoint of the current model
         early_stopping(valid_loss, model)
@@ -108,22 +120,29 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
 
     print('Finished Training, %d epochs' % (epoch+1))
 
+    ######################
+    # test the model     #
+    ######################
     correct = 0
     total = 0
     with torch.no_grad():
         for data in testloader:
             images, labels, _ = data
             if randorder == True:
+                # Randomize pixel order
                 images = images[:,ordering].cuda()
             else:
                 images = images.cuda()
             labels = labels.cuda()
+            # forward pass: compute predicted outputs by passing inputs to the model
             outputs = model(images)
+            # calculate the loss
             loss = criterion(outputs, labels.float().reshape(-1,1))
+            # Sum up correct labelings
             predicted = torch.round(outputs)
             total += labels.size(0)
             correct += (predicted == labels.float().reshape(-1,1)).sum().item()
-
+    # Calculate test accuracy
     accuracy = correct/total
 
     print('Accuracy of the network on the test images: %2f %%' % (
@@ -135,6 +154,12 @@ def train_test_ktree(model, trainloader, validloader, testloader, epochs=10, ran
     return(loss_curve, acc_curve, loss, accuracy, model)
 
 def train_test_fc(model, trainloader, validloader, testloader, epochs=10, patience=60):
+    '''
+    Trains and tests fcnn models
+    Inputs: model, trainloader, validloader, testloader, epochs, patience
+    Outputs: train loss_curve, train acc_curve, test ave_loss, test accuracy, trained model
+    '''
+    # Initialize loss function and optimizer
     criterion = nn.BCELoss()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
 
@@ -144,15 +169,18 @@ def train_test_fc(model, trainloader, validloader, testloader, epochs=10, patien
     # to track the average validation loss per epoch as the model trains
     avg_valid_losses = []
 
-
+    # to track training loss and accuracy as model trains
     loss_curve = []
     acc_curve = []
 
 
+    # Initialize early stopping object
     early_stopping = EarlyStopping(patience=patience, verbose=False)
 
     for epoch in range(epochs):  # loop over the dataset multiple times
-
+        ######################
+        # train the model    #
+        ######################
         running_loss = 0.0
         running_acc = 0.0
         model.train()
@@ -176,9 +204,9 @@ def train_test_fc(model, trainloader, validloader, testloader, epochs=10, patien
             # print statistics
             running_loss += loss.item()
             running_acc += (torch.round(outputs) == labels.float().reshape(-1,1)).sum().item()/trainloader.batch_size
-            if i % 4 == 3:    # print every 80 mini-batches
-                loss_curve.append(running_loss/3)
-                acc_curve.append(running_acc/3)
+            if i % 4 == 3:    # Generate loss and accuracy curves by saving average every 4th minibatch
+                loss_curve.append(running_loss/4)
+                acc_curve.append(running_acc/4)
                 running_loss = 0.0
                 running_acc = 0.0
 
@@ -221,13 +249,18 @@ def train_test_fc(model, trainloader, validloader, testloader, epochs=10, patien
             images, labels, _ = data
             images = images.cuda()
             labels = labels.cuda()
+            # forward pass: compute predicted outputs by passing inputs to the model
             outputs = model(images)
+            # calculate the loss
             loss = criterion(outputs, labels.float().reshape(-1,1))
+            # Sum up correct labelings
             predicted = torch.round(outputs)
             total += labels.size(0)
             correct += (predicted == labels.float().reshape(-1,1)).sum().item()
             all_loss += loss
+    # Calculate test accuracy
     accuracy = correct/total
+    # Calculate average loss
     ave_loss = all_loss.item()/total
 
     print('Accuracy of the network on the 10000 test images: %d %%' % (
diff --git a/custompackage/__pycache__/traintestloop.cpython-36.pyc b/custompackage/__pycache__/traintestloop.cpython-36.pyc
index f0e82eefaa95f647a8cac454764f665bf36f9f8d..a46b8a7acfa672e4ebecf9cfd4cbe1f1e371743a 100644
GIT binary patch
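
The functional change threaded through both loops above is the EarlyStopping helper imported from pytorchtools (most likely the early-stopping-pytorch utility): it is called once per epoch with the average validation loss, checkpoints the model whenever that loss improves, and trips a flag once patience epochs pass without improvement (the code that reads that flag and breaks out of training sits outside the hunks shown here). The class below is only a minimal, compatible sketch of that interface for readers without pytorchtools installed; the delta and path arguments are assumptions and are not confirmed by this patch.

import torch


class EarlyStopping:
    # Minimal stand-in for pytorchtools.EarlyStopping; interface sketch only.
    def __init__(self, patience=7, verbose=False, delta=0.0, path='checkpoint.pt'):
        self.patience = patience    # epochs to wait after the last improvement
        self.verbose = verbose
        self.delta = delta          # assumed: minimum decrease that counts as improvement
        self.path = path            # assumed: where the best weights are checkpointed
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            # Improvement: reset the counter and checkpoint the current weights.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            if self.verbose:
                print('Validation loss improved to %f; checkpoint saved.' % val_loss)
        else:
            # No improvement: count up and trip the flag once patience runs out.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True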
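
For reference, here is a minimal usage sketch of the patched entry points. Everything in it is an illustrative assumption rather than part of the repository: the synthetic tensors, the make_loader helper, and the small fully connected model merely mimic what the loops expect, namely CUDA tensors, DataLoader batches that unpack as (inputs, labels, index), and a model ending in a sigmoid so the BCELoss inside applies.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from custompackage.traintestloop import train_test_fc

# Synthetic stand-in for the flattened-image binary classification tasks.
n_samples, n_pixels = 512, 64
X = torch.randn(n_samples, n_pixels)
y = (X[:, 0] > 0).float()
idx = torch.arange(n_samples)

def make_loader(lo, hi, bs=16):
    # Batches unpack as (inputs, labels, index), matching the loops in the patch.
    ds = TensorDataset(X[lo:hi], y[lo:hi], idx[lo:hi])
    return DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)

trainloader = make_loader(0, 384)
validloader = make_loader(384, 448)
testloader = make_loader(448, 512)

# Any CUDA model that ends in a sigmoid works with the BCELoss used inside the loops.
model = nn.Sequential(nn.Linear(n_pixels, 32), nn.ReLU(),
                      nn.Linear(32, 1), nn.Sigmoid()).cuda()

loss_curve, acc_curve, test_loss, test_acc, model = train_test_fc(
    model, trainloader, validloader, testloader, epochs=10, patience=60)
print('final test accuracy:', test_acc)

drop_last=True keeps every batch at the full batch_size, which matches how the loops normalize running_acc by trainloader.batch_size.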