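Added wandb (Weights & Biases) logging to the Trainer: initialise a run, watch the network, and log the training loss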

This commit is contained in:
Dr. Kashif Rasul
2020-12-23 16:15:50 +01:00
parent 7a3f5d2961
commit 1ac4bf70d9
3 changed files with 8 additions and 6 deletions
+1
View File
@@ -104,6 +104,7 @@ venv.bak/
.mypy_cache/ .mypy_cache/
# other # other
wandb/
.idea/ .idea/
runs/ runs/
.vscode/ .vscode/
+6 -6
View File
@@ -2,6 +2,7 @@ import time
from typing import List, Optional, Union from typing import List, Optional, Union
from tqdm import tqdm from tqdm import tqdm
import wandb
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -20,6 +21,7 @@ class Trainer:
learning_rate: float = 1e-3, learning_rate: float = 1e-3,
weight_decay: float = 1e-6, weight_decay: float = 1e-6,
device: Optional[Union[torch.device, str]] = None, device: Optional[Union[torch.device, str]] = None,
**kwargs,
) -> None: ) -> None:
self.epochs = epochs self.epochs = epochs
self.batch_size = batch_size self.batch_size = batch_size
@@ -27,6 +29,7 @@ class Trainer:
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.weight_decay = weight_decay self.weight_decay = weight_decay
self.device = device self.device = device
wandb.init(**kwargs)
def __call__( def __call__(
self, self,
@@ -34,6 +37,8 @@ class Trainer:
train_iter: DataLoader, train_iter: DataLoader,
validation_iter: Optional[DataLoader] = None, validation_iter: Optional[DataLoader] = None,
) -> None: ) -> None:
wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
) )
@@ -62,17 +67,12 @@ class Trainer:
}, },
refresh=False, refresh=False,
) )
n_iter = epoch_no * self.num_batches_per_epoch + batch_no wandb.log({"loss": loss.item()})
# .add_scalar("Loss/train", loss.item(), n_iter)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if self.num_batches_per_epoch == batch_no: if self.num_batches_per_epoch == batch_no:
# for name, param in net.named_parameters():
# writer.add_histogram(
# name, param.clone().cpu().data.numpy(), n_iter
# )
break break
# mark epoch end time and log time cost of current epoch # mark epoch end time and log time cost of current epoch
+1
View File
@@ -26,6 +26,7 @@ setup(
'pydantic', 'pydantic',
'matplotlib', 'matplotlib',
'tensorboard', 'tensorboard',
'wandb',
], ],
test_suite='tests', test_suite='tests',