mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 16:31:20 +08:00
fix tqdm
This commit is contained in:
+11
-10
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
from tqdm.auto import tqdm
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
@@ -61,11 +61,11 @@ class Trainer:
|
||||
# mark epoch start time
|
||||
tic = time.time()
|
||||
cumm_epoch_loss = 0.0
|
||||
total = self.num_batches_per_epoch - 1
|
||||
|
||||
# training loop
|
||||
with tqdm(train_iter) as it:
|
||||
with tqdm(train_iter, total=total) as it:
|
||||
for batch_no, data_entry in enumerate(it, start=1):
|
||||
it.update(1)
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = [v.to(self.device) for v in data_entry.values()]
|
||||
@@ -83,7 +83,7 @@ class Trainer:
|
||||
"epoch": f"{epoch_no + 1}/{self.epochs}",
|
||||
"avg_loss": avg_epoch_loss,
|
||||
},
|
||||
refresh=False
|
||||
refresh=False,
|
||||
)
|
||||
|
||||
wandb.log({"loss": loss.item()})
|
||||
@@ -97,13 +97,14 @@ class Trainer:
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
break
|
||||
it.close()
|
||||
|
||||
# validation loop
|
||||
if validation_iter is not None:
|
||||
cumm_epoch_loss_val = 0.0
|
||||
# validation loop
|
||||
if validation_iter is not None:
|
||||
cumm_epoch_loss_val = 0.0
|
||||
with tqdm(validation_iter, total=total, colour="green") as it:
|
||||
|
||||
for batch_no, data_entry in enumerate(validation_iter, start=1):
|
||||
it.update(1)
|
||||
for batch_no, data_entry in enumerate(it, start=1):
|
||||
inputs = [v.to(self.device) for v in data_entry.values()]
|
||||
with torch.no_grad():
|
||||
output = net(*inputs)
|
||||
@@ -120,13 +121,13 @@ class Trainer:
|
||||
"avg_loss": avg_epoch_loss,
|
||||
"avg_val_loss": avg_epoch_loss_val,
|
||||
},
|
||||
refresh=False,
|
||||
)
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
break
|
||||
|
||||
wandb.log({"avg_val_loss": avg_epoch_loss_val})
|
||||
print(it) # TODO fix this
|
||||
it.close()
|
||||
|
||||
# mark epoch end time and log time cost of current epoch
|
||||
|
||||
Reference in New Issue
Block a user