mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 15:16:30 +08:00
added wandb
This commit is contained in:
@@ -104,6 +104,7 @@ venv.bak/
|
||||
.mypy_cache/
|
||||
|
||||
# other
|
||||
wandb/
|
||||
.idea/
|
||||
runs/
|
||||
.vscode/
|
||||
|
||||
+6
-6
@@ -2,6 +2,7 @@ import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -20,6 +21,7 @@ class Trainer:
|
||||
learning_rate: float = 1e-3,
|
||||
weight_decay: float = 1e-6,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.epochs = epochs
|
||||
self.batch_size = batch_size
|
||||
@@ -27,6 +29,7 @@ class Trainer:
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.device = device
|
||||
wandb.init(**kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -34,6 +37,8 @@ class Trainer:
|
||||
train_iter: DataLoader,
|
||||
validation_iter: Optional[DataLoader] = None,
|
||||
) -> None:
|
||||
wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
|
||||
)
|
||||
@@ -62,17 +67,12 @@ class Trainer:
|
||||
},
|
||||
refresh=False,
|
||||
)
|
||||
n_iter = epoch_no * self.num_batches_per_epoch + batch_no
|
||||
# .add_scalar("Loss/train", loss.item(), n_iter)
|
||||
wandb.log({"loss": loss.item()})
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if self.num_batches_per_epoch == batch_no:
|
||||
# for name, param in net.named_parameters():
|
||||
# writer.add_histogram(
|
||||
# name, param.clone().cpu().data.numpy(), n_iter
|
||||
# )
|
||||
break
|
||||
|
||||
# mark epoch end time and log time cost of current epoch
|
||||
|
||||
Reference in New Issue
Block a user