diff --git a/pts/trainer.py b/pts/trainer.py index e224c74..19c6453 100644 --- a/pts/trainer.py +++ b/pts/trainer.py @@ -1,7 +1,7 @@ import time from typing import List, Optional, Union -from tqdm import tqdm +from tqdm.auto import tqdm import wandb import torch @@ -61,11 +61,11 @@ class Trainer: # mark epoch start time tic = time.time() cumm_epoch_loss = 0.0 + total = self.num_batches_per_epoch - 1 # training loop - with tqdm(train_iter) as it: + with tqdm(train_iter, total=total) as it: for batch_no, data_entry in enumerate(it, start=1): - it.update(1) optimizer.zero_grad() inputs = [v.to(self.device) for v in data_entry.values()] @@ -83,7 +83,7 @@ class Trainer: "epoch": f"{epoch_no + 1}/{self.epochs}", "avg_loss": avg_epoch_loss, }, - refresh=False + refresh=False, ) wandb.log({"loss": loss.item()}) @@ -97,13 +97,14 @@ class Trainer: if self.num_batches_per_epoch == batch_no: break + it.close() - # validation loop - if validation_iter is not None: - cumm_epoch_loss_val = 0.0 + # validation loop + if validation_iter is not None: + cumm_epoch_loss_val = 0.0 + with tqdm(validation_iter, total=total, colour="green") as it: - for batch_no, data_entry in enumerate(validation_iter, start=1): - it.update(1) + for batch_no, data_entry in enumerate(it, start=1): inputs = [v.to(self.device) for v in data_entry.values()] with torch.no_grad(): output = net(*inputs) @@ -120,13 +121,13 @@ class Trainer: "avg_loss": avg_epoch_loss, "avg_val_loss": avg_epoch_loss_val, }, + refresh=False, ) if self.num_batches_per_epoch == batch_no: break wandb.log({"avg_val_loss": avg_epoch_loss_val}) - print(it) # TODO fix this it.close() # mark epoch end time and log time cost of current epoch