Added validation set loss tracking (#47)

* added validation loss print

* resolved batch index bug on validation set

* leave notebooks unchanged

* fixed bug using zip iteration

* tracking validation loss

* cleaned up and working validation tracker

* updates to validation tracker

* added torch no grad for validation

* updates to rolling inference

* updated working inference

* updated issue with causal deep ar lengths

* removed testing script
This commit is contained in:
Larkin Liu
2021-04-30 15:09:01 +02:00
committed by GitHub
parent cbb7bb9089
commit 4c066aa6cb
+44 -13
View File
@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader
from gluonts.core.component import validated
class Trainer:
@validated()
def __init__(
@@ -62,27 +61,59 @@ class Trainer:
tic = time.time()
avg_epoch_loss = 0.0
with tqdm(train_iter) as it:
for batch_no, data_entry in enumerate(it, start=1):
optimizer.zero_grad()
inputs = [v.to(self.device) for v in data_entry.values()]
if validation_iter is not None:
avg_epoch_loss_val = 0.0
train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter)))
if validation_iter is not None:
val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter)))
with tqdm(train_iter) as it:
for batch_no, data_entry in train_iter_obj:
optimizer.zero_grad()
# Strong assumption that validation_iter and train_iter are same iter size
if validation_iter is not None:
with torch.no_grad():
val_data_entry = val_iter_obj[batch_no-1][1]
inputs_val = [v.to(self.device) for v in val_data_entry.values()]
output_val = net(*inputs_val)
if isinstance(output_val, (list, tuple)):
loss_val = output_val[0]
else:
loss_val = output_val
avg_epoch_loss_val += loss_val.item()
inputs = [v.to(self.device) for v in data_entry.values()]
output = net(*inputs)
if isinstance(output, (list, tuple)):
loss = output[0]
else:
loss = output
avg_epoch_loss += loss.item()
it.set_postfix(
ordered_dict={
"avg_epoch_loss": avg_epoch_loss / batch_no,
"epoch": epoch_no,
},
refresh=False,
)
if validation_iter is not None:
post_fix_dict = ordered_dict={
"avg_epoch_loss": avg_epoch_loss / batch_no,
"avg_epoch_loss_val": avg_epoch_loss_val / batch_no,
"epoch": epoch_no,
}
wandb.log({"loss_val": loss_val.item()})
else:
post_fix_dict={
"avg_epoch_loss": avg_epoch_loss / batch_no,
"epoch": epoch_no,
}
wandb.log({"loss": loss.item()})
it.set_postfix(post_fix_dict, refresh=False)
loss.backward()
if self.clip_gradient is not None:
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
@@ -92,7 +123,7 @@ class Trainer:
if self.num_batches_per_epoch == batch_no:
break
# mark epoch end time and log time cost of current epoch
toc = time.time()