diff --git a/pts/trainer.py b/pts/trainer.py index c6e062e..2e163e5 100644 --- a/pts/trainer.py +++ b/pts/trainer.py @@ -12,7 +12,6 @@ from torch.utils.data import DataLoader from gluonts.core.component import validated - class Trainer: @validated() def __init__( @@ -62,27 +61,59 @@ class Trainer: tic = time.time() avg_epoch_loss = 0.0 - with tqdm(train_iter) as it: - for batch_no, data_entry in enumerate(it, start=1): - optimizer.zero_grad() - inputs = [v.to(self.device) for v in data_entry.values()] + if validation_iter is not None: + avg_epoch_loss_val = 0.0 + train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter))) + if validation_iter is not None: + val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter))) + + + with tqdm(train_iter) as it: + for batch_no, data_entry in train_iter_obj: + + optimizer.zero_grad() + + # Strong assumption that validation_iter and train_iter are same iter size + if validation_iter is not None: + with torch.no_grad(): + val_data_entry = val_iter_obj[batch_no-1][1] + inputs_val = [v.to(self.device) for v in val_data_entry.values()] + output_val = net(*inputs_val) + + if isinstance(output_val, (list, tuple)): + loss_val = output_val[0] + else: + loss_val = output_val + + avg_epoch_loss_val += loss_val.item() + + inputs = [v.to(self.device) for v in data_entry.values()] output = net(*inputs) + if isinstance(output, (list, tuple)): loss = output[0] else: loss = output avg_epoch_loss += loss.item() - it.set_postfix( - ordered_dict={ - "avg_epoch_loss": avg_epoch_loss / batch_no, - "epoch": epoch_no, - }, - refresh=False, - ) + if validation_iter is not None: + post_fix_dict = ordered_dict={ + "avg_epoch_loss": avg_epoch_loss / batch_no, + "avg_epoch_loss_val": avg_epoch_loss_val / batch_no, + "epoch": epoch_no, + } + wandb.log({"loss_val": loss_val.item()}) + else: + post_fix_dict={ + "avg_epoch_loss": avg_epoch_loss / batch_no, + "epoch": epoch_no, + } + wandb.log({"loss": loss.item()}) + it.set_postfix(post_fix_dict, refresh=False) + loss.backward() if self.clip_gradient is not None: nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient) @@ -92,7 +123,7 @@ class Trainer: if self.num_batches_per_epoch == batch_no: break - + # mark epoch end time and log time cost of current epoch toc = time.time()