From 1ac4bf70d984b1d028fc4341fb2fe71cf36138ff Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Wed, 23 Dec 2020 16:15:50 +0100 Subject: [PATCH] added wandb --- .gitignore | 1 + pts/trainer.py | 12 ++++++------ setup.py | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 6f18569..8a95c35 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ venv.bak/ .mypy_cache/ # other +wandb/ .idea/ runs/ .vscode/ diff --git a/pts/trainer.py b/pts/trainer.py index 1585de4..44a06fa 100644 --- a/pts/trainer.py +++ b/pts/trainer.py @@ -2,6 +2,7 @@ import time from typing import List, Optional, Union from tqdm import tqdm +import wandb import torch import torch.nn as nn @@ -20,6 +21,7 @@ class Trainer: learning_rate: float = 1e-3, weight_decay: float = 1e-6, device: Optional[Union[torch.device, str]] = None, + **kwargs, ) -> None: self.epochs = epochs self.batch_size = batch_size @@ -27,6 +29,7 @@ class Trainer: self.learning_rate = learning_rate self.weight_decay = weight_decay self.device = device + wandb.init(**kwargs) def __call__( self, @@ -34,6 +37,8 @@ class Trainer: train_iter: DataLoader, validation_iter: Optional[DataLoader] = None, ) -> None: + wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch) + optimizer = torch.optim.Adam( net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay ) @@ -62,17 +67,12 @@ class Trainer: }, refresh=False, ) - n_iter = epoch_no * self.num_batches_per_epoch + batch_no - # .add_scalar("Loss/train", loss.item(), n_iter) + wandb.log({"loss": loss.item()}) loss.backward() optimizer.step() if self.num_batches_per_epoch == batch_no: - # for name, param in net.named_parameters(): - # writer.add_histogram( - # name, param.clone().cpu().data.numpy(), n_iter - # ) break # mark epoch end time and log time cost of current epoch diff --git a/setup.py b/setup.py index 06b71dc..0afc123 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ setup( 'pydantic', 'matplotlib', 'tensorboard', + 'wandb', ], test_suite='tests',