mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
Added validation set loss tracking (#47)
* added validation loss print * resolved batch index bug on validation set * leave notebooks unchanged * fixed bug using zip iteration * tracking validation loss * cleaned up and working validation tracker * updates to validation tracker * added torch no grad for validation * updates to rolling inference * updated working inference * updated issue with causal deep ar lengths * removed testing script
This commit is contained in:
+44
-13
@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from gluonts.core.component import validated
|
||||
|
||||
|
||||
class Trainer:
|
||||
@validated()
|
||||
def __init__(
|
||||
@@ -62,27 +61,59 @@ class Trainer:
|
||||
tic = time.time()
|
||||
avg_epoch_loss = 0.0
|
||||
|
||||
with tqdm(train_iter) as it:
|
||||
for batch_no, data_entry in enumerate(it, start=1):
|
||||
optimizer.zero_grad()
|
||||
inputs = [v.to(self.device) for v in data_entry.values()]
|
||||
if validation_iter is not None:
|
||||
avg_epoch_loss_val = 0.0
|
||||
|
||||
train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter)))
|
||||
if validation_iter is not None:
|
||||
val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter)))
|
||||
|
||||
|
||||
with tqdm(train_iter) as it:
|
||||
for batch_no, data_entry in train_iter_obj:
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Strong assumption that validation_iter and train_iter are same iter size
|
||||
if validation_iter is not None:
|
||||
with torch.no_grad():
|
||||
val_data_entry = val_iter_obj[batch_no-1][1]
|
||||
inputs_val = [v.to(self.device) for v in val_data_entry.values()]
|
||||
output_val = net(*inputs_val)
|
||||
|
||||
if isinstance(output_val, (list, tuple)):
|
||||
loss_val = output_val[0]
|
||||
else:
|
||||
loss_val = output_val
|
||||
|
||||
avg_epoch_loss_val += loss_val.item()
|
||||
|
||||
inputs = [v.to(self.device) for v in data_entry.values()]
|
||||
output = net(*inputs)
|
||||
|
||||
if isinstance(output, (list, tuple)):
|
||||
loss = output[0]
|
||||
else:
|
||||
loss = output
|
||||
|
||||
avg_epoch_loss += loss.item()
|
||||
it.set_postfix(
|
||||
ordered_dict={
|
||||
"avg_epoch_loss": avg_epoch_loss / batch_no,
|
||||
"epoch": epoch_no,
|
||||
},
|
||||
refresh=False,
|
||||
)
|
||||
if validation_iter is not None:
|
||||
post_fix_dict = ordered_dict={
|
||||
"avg_epoch_loss": avg_epoch_loss / batch_no,
|
||||
"avg_epoch_loss_val": avg_epoch_loss_val / batch_no,
|
||||
"epoch": epoch_no,
|
||||
}
|
||||
wandb.log({"loss_val": loss_val.item()})
|
||||
else:
|
||||
post_fix_dict={
|
||||
"avg_epoch_loss": avg_epoch_loss / batch_no,
|
||||
"epoch": epoch_no,
|
||||
}
|
||||
|
||||
wandb.log({"loss": loss.item()})
|
||||
|
||||
it.set_postfix(post_fix_dict, refresh=False)
|
||||
|
||||
loss.backward()
|
||||
if self.clip_gradient is not None:
|
||||
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
|
||||
@@ -92,7 +123,7 @@ class Trainer:
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
break
|
||||
|
||||
|
||||
# mark epoch end time and log time cost of current epoch
|
||||
toc = time.time()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user