mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 17:49:41 +08:00
partially fix tqdm with validation dataset
This commit is contained in:
+47
-44
@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from gluonts.core.component import validated
|
||||
|
||||
|
||||
class Trainer:
|
||||
@validated()
|
||||
def __init__(
|
||||
@@ -59,35 +60,14 @@ class Trainer:
|
||||
for epoch_no in range(self.epochs):
|
||||
# mark epoch start time
|
||||
tic = time.time()
|
||||
avg_epoch_loss = 0.0
|
||||
|
||||
if validation_iter is not None:
|
||||
avg_epoch_loss_val = 0.0
|
||||
|
||||
train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter)))
|
||||
if validation_iter is not None:
|
||||
val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter)))
|
||||
|
||||
cumm_epoch_loss = 0.0
|
||||
|
||||
# training loop
|
||||
with tqdm(train_iter) as it:
|
||||
for batch_no, data_entry in train_iter_obj:
|
||||
|
||||
for batch_no, data_entry in enumerate(it, start=1):
|
||||
it.update(1)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Strong assumption that validation_iter and train_iter are same iter size
|
||||
if validation_iter is not None:
|
||||
with torch.no_grad():
|
||||
val_data_entry = val_iter_obj[batch_no-1][1]
|
||||
inputs_val = [v.to(self.device) for v in val_data_entry.values()]
|
||||
output_val = net(*inputs_val)
|
||||
|
||||
if isinstance(output_val, (list, tuple)):
|
||||
loss_val = output_val[0]
|
||||
else:
|
||||
loss_val = output_val
|
||||
|
||||
avg_epoch_loss_val += loss_val.item()
|
||||
|
||||
inputs = [v.to(self.device) for v in data_entry.values()]
|
||||
output = net(*inputs)
|
||||
|
||||
@@ -96,24 +76,18 @@ class Trainer:
|
||||
else:
|
||||
loss = output
|
||||
|
||||
avg_epoch_loss += loss.item()
|
||||
if validation_iter is not None:
|
||||
post_fix_dict = ordered_dict={
|
||||
"avg_epoch_loss": avg_epoch_loss / batch_no,
|
||||
"avg_epoch_loss_val": avg_epoch_loss_val / batch_no,
|
||||
"epoch": epoch_no,
|
||||
}
|
||||
wandb.log({"loss_val": loss_val.item()})
|
||||
else:
|
||||
post_fix_dict={
|
||||
"avg_epoch_loss": avg_epoch_loss / batch_no,
|
||||
"epoch": epoch_no,
|
||||
}
|
||||
|
||||
cumm_epoch_loss += loss.item()
|
||||
avg_epoch_loss = cumm_epoch_loss / batch_no
|
||||
it.set_postfix(
|
||||
{
|
||||
"epoch": f"{epoch_no + 1}/{self.epochs}",
|
||||
"avg_loss": avg_epoch_loss,
|
||||
},
|
||||
refresh=False
|
||||
)
|
||||
|
||||
wandb.log({"loss": loss.item()})
|
||||
|
||||
it.set_postfix(post_fix_dict, refresh=False)
|
||||
|
||||
loss.backward()
|
||||
if self.clip_gradient is not None:
|
||||
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
|
||||
@@ -123,8 +97,37 @@ class Trainer:
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
break
|
||||
|
||||
|
||||
# validation loop
|
||||
if validation_iter is not None:
|
||||
cumm_epoch_loss_val = 0.0
|
||||
|
||||
for batch_no, data_entry in enumerate(validation_iter, start=1):
|
||||
it.update(1)
|
||||
inputs = [v.to(self.device) for v in data_entry.values()]
|
||||
with torch.no_grad():
|
||||
output = net(*inputs)
|
||||
if isinstance(output, (list, tuple)):
|
||||
loss = output[0]
|
||||
else:
|
||||
loss = output
|
||||
|
||||
cumm_epoch_loss_val += loss.item()
|
||||
avg_epoch_loss_val = cumm_epoch_loss_val / batch_no
|
||||
it.set_postfix(
|
||||
{
|
||||
"epoch": f"{epoch_no + 1}/{self.epochs}",
|
||||
"avg_loss": avg_epoch_loss,
|
||||
"avg_val_loss": avg_epoch_loss_val,
|
||||
},
|
||||
)
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
break
|
||||
|
||||
wandb.log({"avg_val_loss": avg_epoch_loss_val})
|
||||
print(it) # TODO fix this
|
||||
it.close()
|
||||
|
||||
# mark epoch end time and log time cost of current epoch
|
||||
toc = time.time()
|
||||
|
||||
# writer.close()
|
||||
|
||||
Reference in New Issue
Block a user