diff --git a/.gitignore b/.gitignore index 6f18569..8a95c35 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ venv.bak/ .mypy_cache/ # other +wandb/ .idea/ runs/ .vscode/ diff --git a/pts/trainer.py b/pts/trainer.py index 1585de4..44a06fa 100644 --- a/pts/trainer.py +++ b/pts/trainer.py @@ -2,6 +2,7 @@ import time from typing import List, Optional, Union from tqdm import tqdm +import wandb import torch import torch.nn as nn @@ -20,6 +21,7 @@ class Trainer: learning_rate: float = 1e-3, weight_decay: float = 1e-6, device: Optional[Union[torch.device, str]] = None, + **kwargs, ) -> None: self.epochs = epochs self.batch_size = batch_size @@ -27,6 +29,7 @@ class Trainer: self.learning_rate = learning_rate self.weight_decay = weight_decay self.device = device + wandb.init(**kwargs) def __call__( self, @@ -34,6 +37,8 @@ class Trainer: train_iter: DataLoader, validation_iter: Optional[DataLoader] = None, ) -> None: + wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch) + optimizer = torch.optim.Adam( net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay ) @@ -62,17 +67,12 @@ class Trainer: }, refresh=False, ) - n_iter = epoch_no * self.num_batches_per_epoch + batch_no - # .add_scalar("Loss/train", loss.item(), n_iter) + wandb.log({"loss": loss.item()}) loss.backward() optimizer.step() if self.num_batches_per_epoch == batch_no: - # for name, param in net.named_parameters(): - # writer.add_histogram( - # name, param.clone().cpu().data.numpy(), n_iter - # ) break # mark epoch end time and log time cost of current epoch diff --git a/setup.py b/setup.py index 06b71dc..0afc123 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ setup( 'pydantic', 'matplotlib', 'tensorboard', + 'wandb', ], test_suite='tests',