This commit is contained in:
Kashif Rasul
2021-06-18 20:10:11 +02:00
parent 5b8dc69617
commit 5da3be5abd
+11 -10
View File
@@ -1,7 +1,7 @@
import time
from typing import List, Optional, Union
from tqdm import tqdm
from tqdm.auto import tqdm
import wandb
import torch
@@ -61,11 +61,11 @@ class Trainer:
# mark epoch start time
tic = time.time()
cumm_epoch_loss = 0.0
total = self.num_batches_per_epoch - 1
# training loop
with tqdm(train_iter) as it:
with tqdm(train_iter, total=total) as it:
for batch_no, data_entry in enumerate(it, start=1):
it.update(1)
optimizer.zero_grad()
inputs = [v.to(self.device) for v in data_entry.values()]
@@ -83,7 +83,7 @@ class Trainer:
"epoch": f"{epoch_no + 1}/{self.epochs}",
"avg_loss": avg_epoch_loss,
},
refresh=False
refresh=False,
)
wandb.log({"loss": loss.item()})
@@ -97,13 +97,14 @@ class Trainer:
if self.num_batches_per_epoch == batch_no:
break
it.close()
# validation loop
if validation_iter is not None:
cumm_epoch_loss_val = 0.0
# validation loop
if validation_iter is not None:
cumm_epoch_loss_val = 0.0
with tqdm(validation_iter, total=total, colour="green") as it:
for batch_no, data_entry in enumerate(validation_iter, start=1):
it.update(1)
for batch_no, data_entry in enumerate(it, start=1):
inputs = [v.to(self.device) for v in data_entry.values()]
with torch.no_grad():
output = net(*inputs)
@@ -120,13 +121,13 @@ class Trainer:
"avg_loss": avg_epoch_loss,
"avg_val_loss": avg_epoch_loss_val,
},
refresh=False,
)
if self.num_batches_per_epoch == batch_no:
break
wandb.log({"avg_val_loss": avg_epoch_loss_val})
print(it) # TODO fix this
it.close()
# mark epoch end time and log time cost of current epoch