partially fix tqdm with validation dataset

This commit is contained in:
Kashif Rasul
2021-06-18 13:31:33 +02:00
parent 729c4a1b57
commit 5b8dc69617
+47 -44
View File
@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
from gluonts.core.component import validated
class Trainer:
@validated()
def __init__(
@@ -59,35 +60,14 @@ class Trainer:
for epoch_no in range(self.epochs):
# mark epoch start time
tic = time.time()
avg_epoch_loss = 0.0
if validation_iter is not None:
avg_epoch_loss_val = 0.0
train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter)))
if validation_iter is not None:
val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter)))
cumm_epoch_loss = 0.0
# training loop
with tqdm(train_iter) as it:
for batch_no, data_entry in train_iter_obj:
for batch_no, data_entry in enumerate(it, start=1):
it.update(1)
optimizer.zero_grad()
# Strong assumption that validation_iter and train_iter are same iter size
if validation_iter is not None:
with torch.no_grad():
val_data_entry = val_iter_obj[batch_no-1][1]
inputs_val = [v.to(self.device) for v in val_data_entry.values()]
output_val = net(*inputs_val)
if isinstance(output_val, (list, tuple)):
loss_val = output_val[0]
else:
loss_val = output_val
avg_epoch_loss_val += loss_val.item()
inputs = [v.to(self.device) for v in data_entry.values()]
output = net(*inputs)
@@ -96,24 +76,18 @@ class Trainer:
else:
loss = output
avg_epoch_loss += loss.item()
if validation_iter is not None:
post_fix_dict = ordered_dict={
"avg_epoch_loss": avg_epoch_loss / batch_no,
"avg_epoch_loss_val": avg_epoch_loss_val / batch_no,
"epoch": epoch_no,
}
wandb.log({"loss_val": loss_val.item()})
else:
post_fix_dict={
"avg_epoch_loss": avg_epoch_loss / batch_no,
"epoch": epoch_no,
}
cumm_epoch_loss += loss.item()
avg_epoch_loss = cumm_epoch_loss / batch_no
it.set_postfix(
{
"epoch": f"{epoch_no + 1}/{self.epochs}",
"avg_loss": avg_epoch_loss,
},
refresh=False
)
wandb.log({"loss": loss.item()})
it.set_postfix(post_fix_dict, refresh=False)
loss.backward()
if self.clip_gradient is not None:
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
@@ -123,8 +97,37 @@ class Trainer:
if self.num_batches_per_epoch == batch_no:
break
# validation loop
if validation_iter is not None:
cumm_epoch_loss_val = 0.0
for batch_no, data_entry in enumerate(validation_iter, start=1):
it.update(1)
inputs = [v.to(self.device) for v in data_entry.values()]
with torch.no_grad():
output = net(*inputs)
if isinstance(output, (list, tuple)):
loss = output[0]
else:
loss = output
cumm_epoch_loss_val += loss.item()
avg_epoch_loss_val = cumm_epoch_loss_val / batch_no
it.set_postfix(
{
"epoch": f"{epoch_no + 1}/{self.epochs}",
"avg_loss": avg_epoch_loss,
"avg_val_loss": avg_epoch_loss_val,
},
)
if self.num_batches_per_epoch == batch_no:
break
wandb.log({"avg_val_loss": avg_epoch_loss_val})
print(it) # TODO fix this
it.close()
# mark epoch end time and log time cost of current epoch
toc = time.time()
# writer.close()