added wandb

This commit is contained in:
Dr. Kashif Rasul
2020-12-23 16:15:50 +01:00
parent 7a3f5d2961
commit 1ac4bf70d9
3 changed files with 8 additions and 6 deletions
+1
View File
@@ -104,6 +104,7 @@ venv.bak/
.mypy_cache/
# other
wandb/
.idea/
runs/
.vscode/
+6 -6
View File
@@ -2,6 +2,7 @@ import time
from typing import List, Optional, Union
from tqdm import tqdm
import wandb
import torch
import torch.nn as nn
@@ -20,6 +21,7 @@ class Trainer:
learning_rate: float = 1e-3,
weight_decay: float = 1e-6,
device: Optional[Union[torch.device, str]] = None,
**kwargs,
) -> None:
self.epochs = epochs
self.batch_size = batch_size
@@ -27,6 +29,7 @@ class Trainer:
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.device = device
wandb.init(**kwargs)
def __call__(
self,
@@ -34,6 +37,8 @@ class Trainer:
train_iter: DataLoader,
validation_iter: Optional[DataLoader] = None,
) -> None:
wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
optimizer = torch.optim.Adam(
net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
@@ -62,17 +67,12 @@ class Trainer:
},
refresh=False,
)
n_iter = epoch_no * self.num_batches_per_epoch + batch_no
# .add_scalar("Loss/train", loss.item(), n_iter)
wandb.log({"loss": loss.item()})
loss.backward()
optimizer.step()
if self.num_batches_per_epoch == batch_no:
# for name, param in net.named_parameters():
# writer.add_histogram(
# name, param.clone().cpu().data.numpy(), n_iter
# )
break
# mark epoch end time and log time cost of current epoch
+1
View File
@@ -26,6 +26,7 @@ setup(
'pydantic',
'matplotlib',
'tensorboard',
'wandb',
],
test_suite='tests',