diff --git a/pts/trainer.py b/pts/trainer.py index 2e163e5..e224c74 100644 --- a/pts/trainer.py +++ b/pts/trainer.py @@ -12,6 +12,7 @@ from torch.utils.data import DataLoader from gluonts.core.component import validated + class Trainer: @validated() def __init__( @@ -59,35 +60,14 @@ class Trainer: for epoch_no in range(self.epochs): # mark epoch start time tic = time.time() - avg_epoch_loss = 0.0 - - if validation_iter is not None: - avg_epoch_loss_val = 0.0 - - train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter))) - if validation_iter is not None: - val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter))) - + cumm_epoch_loss = 0.0 + # training loop with tqdm(train_iter) as it: - for batch_no, data_entry in train_iter_obj: - + for batch_no, data_entry in enumerate(it, start=1): + it.update(1) optimizer.zero_grad() - # Strong assumption that validation_iter and train_iter are same iter size - if validation_iter is not None: - with torch.no_grad(): - val_data_entry = val_iter_obj[batch_no-1][1] - inputs_val = [v.to(self.device) for v in val_data_entry.values()] - output_val = net(*inputs_val) - - if isinstance(output_val, (list, tuple)): - loss_val = output_val[0] - else: - loss_val = output_val - - avg_epoch_loss_val += loss_val.item() - inputs = [v.to(self.device) for v in data_entry.values()] output = net(*inputs) @@ -96,24 +76,18 @@ class Trainer: else: loss = output - avg_epoch_loss += loss.item() - if validation_iter is not None: - post_fix_dict = ordered_dict={ - "avg_epoch_loss": avg_epoch_loss / batch_no, - "avg_epoch_loss_val": avg_epoch_loss_val / batch_no, - "epoch": epoch_no, - } - wandb.log({"loss_val": loss_val.item()}) - else: - post_fix_dict={ - "avg_epoch_loss": avg_epoch_loss / batch_no, - "epoch": epoch_no, - } - + cumm_epoch_loss += loss.item() + avg_epoch_loss = cumm_epoch_loss / batch_no + it.set_postfix( + { + "epoch": f"{epoch_no + 1}/{self.epochs}", + "avg_loss": avg_epoch_loss, + }, + refresh=False + ) + wandb.log({"loss": loss.item()}) - it.set_postfix(post_fix_dict, refresh=False) - loss.backward() if self.clip_gradient is not None: nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient) @@ -123,8 +97,37 @@ class Trainer: if self.num_batches_per_epoch == batch_no: break - + + # validation loop + if validation_iter is not None: + cumm_epoch_loss_val = 0.0 + + for batch_no, data_entry in enumerate(validation_iter, start=1): + it.update(1) + inputs = [v.to(self.device) for v in data_entry.values()] + with torch.no_grad(): + output = net(*inputs) + if isinstance(output, (list, tuple)): + loss = output[0] + else: + loss = output + + cumm_epoch_loss_val += loss.item() + avg_epoch_loss_val = cumm_epoch_loss_val / batch_no + it.set_postfix( + { + "epoch": f"{epoch_no + 1}/{self.epochs}", + "avg_loss": avg_epoch_loss, + "avg_val_loss": avg_epoch_loss_val, + }, + ) + + if self.num_batches_per_epoch == batch_no: + break + + wandb.log({"avg_val_loss": avg_epoch_loss_val}) + print(it) # TODO fix this + it.close() + # mark epoch end time and log time cost of current epoch toc = time.time() - - # writer.close()