mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 16:46:32 +08:00
added wandb
This commit is contained in:
@@ -104,6 +104,7 @@ venv.bak/
|
|||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
|
|
||||||
# other
|
# other
|
||||||
|
wandb/
|
||||||
.idea/
|
.idea/
|
||||||
runs/
|
runs/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|||||||
+6
-6
@@ -2,6 +2,7 @@ import time
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import wandb
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -20,6 +21,7 @@ class Trainer:
|
|||||||
learning_rate: float = 1e-3,
|
learning_rate: float = 1e-3,
|
||||||
weight_decay: float = 1e-6,
|
weight_decay: float = 1e-6,
|
||||||
device: Optional[Union[torch.device, str]] = None,
|
device: Optional[Union[torch.device, str]] = None,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.epochs = epochs
|
self.epochs = epochs
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -27,6 +29,7 @@ class Trainer:
|
|||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.weight_decay = weight_decay
|
self.weight_decay = weight_decay
|
||||||
self.device = device
|
self.device = device
|
||||||
|
wandb.init(**kwargs)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -34,6 +37,8 @@ class Trainer:
|
|||||||
train_iter: DataLoader,
|
train_iter: DataLoader,
|
||||||
validation_iter: Optional[DataLoader] = None,
|
validation_iter: Optional[DataLoader] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
|
||||||
|
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
|
net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
|
||||||
)
|
)
|
||||||
@@ -62,17 +67,12 @@ class Trainer:
|
|||||||
},
|
},
|
||||||
refresh=False,
|
refresh=False,
|
||||||
)
|
)
|
||||||
n_iter = epoch_no * self.num_batches_per_epoch + batch_no
|
wandb.log({"loss": loss.item()})
|
||||||
# .add_scalar("Loss/train", loss.item(), n_iter)
|
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if self.num_batches_per_epoch == batch_no:
|
if self.num_batches_per_epoch == batch_no:
|
||||||
# for name, param in net.named_parameters():
|
|
||||||
# writer.add_histogram(
|
|
||||||
# name, param.clone().cpu().data.numpy(), n_iter
|
|
||||||
# )
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# mark epoch end time and log time cost of current epoch
|
# mark epoch end time and log time cost of current epoch
|
||||||
|
|||||||
Reference in New Issue
Block a user