[sgd] Add support for multi-model multi-optimizer training (#6317)

This commit is contained in:
Richard Liaw
2019-12-15 15:19:45 -08:00
committed by GitHub
parent c2499c802f
commit 5719a05757
12 changed files with 646 additions and 50 deletions
@@ -89,13 +89,8 @@ def cifar_creator(batch_size, config):
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
from filelock import FileLock
with FileLock(os.path.expanduser("~/data.lock")):
train_dataset = torchvision.datasets.CIFAR10(
root="~/data",
train=True,
download=True,
transform=transform_train)
train_dataset = torchvision.datasets.CIFAR10(
root="~/data", train=True, download=True, transform=transform_train)
validation_dataset = torchvision.datasets.CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
@@ -2,9 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from filelock import FileLock
import logging
import os
import torch.nn as nn
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
@@ -51,20 +56,29 @@ class DistributedPyTorchRunner(PyTorchRunner):
def _setup_training(self):
logger.debug("Creating model")
self.model = self.model_creator(self.config)
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
if torch.cuda.is_available():
self.model = self.model.cuda()
self.model = torch.nn.parallel.DistributedDataParallel(self.model)
self.models = [model.cuda() for model in self.models]
self.models = [DistributedDataParallel(model) for model in self.models]
logger.debug("Creating optimizer.")
self.optimizer = self.optimizer_creator(self.model, self.config)
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, collections.Iterable):
self.optimizers = [self.optimizers]
self.criterion = self.loss_creator(self.config)
if torch.cuda.is_available():
self.criterion = self.criterion.cuda()
logger.debug("Creating dataset")
self.train_loader, self.validation_loader = self.data_creator(
self.batch_size, self.config)
logger.debug("Creating dataset.")
with FileLock(os.path.expanduser("~/.ray_data.lock")):
data_loaders = self.data_creator(self.batch_size, self.config)
self.train_loader, self.validation_loader = self._validate_loaders(
data_loaders)
def step(self):
"""Runs a training epoch and updates the model parameters.
@@ -80,16 +94,21 @@ class DistributedPyTorchRunner(PyTorchRunner):
"""Returns the state of the runner."""
return {
"epoch": self.epoch,
"model": self.model.module.cpu().state_dict(),
"optimizer": self.optimizer.state_dict(),
"models": [
model.module.cpu().state_dict() for model in self.models
],
"optimizers": [opt.state_dict() for opt in self.optimizers],
"stats": self.stats()
}
def set_state(self, state):
"""Sets the state of the model."""
# TODO: restore timer stats
self.model.module.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
for model, model_state_dict in zip(self.models, state["models"]):
model.module.load_state_dict(model_state_dict)
for optimizer, opt_state_dict in zip(self.optimizers,
state["optimizers"]):
optimizer.load_state_dict(opt_state_dict)
self.epoch = state["stats"]["epoch"]
def shutdown(self):
@@ -0,0 +1,284 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import torch
import torch.nn as nn
from torch import distributed
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data.distributed import DistributedSampler
from scipy.stats import entropy
import ray
from ray.experimental.sgd.pytorch import PyTorchTrainer
# Training parameters
# Number of batches after which `train` stops an epoch when
# config["test_mode"] is set.
TRAIN_BATCHES = 5
# Number of channels in the training images. For color images this is 3;
# it is 1 here because the example trains on MNIST.
num_channels = 1
# Size of z latent vector (i.e. size of generator input)
latent_vector_size = 100
# Size of feature maps in generator
features_g = 32
# Size of feature maps in discriminator
features_d = 32
def data_creator(batch_size, config):
    """Builds the MNIST training dataloader for the DCGAN example.

    Args:
        batch_size (int): Number of samples per batch.
        config (dict): Trainer configuration; not used here.

    Returns:
        A ``(train_loader, None)`` tuple. No validation loader is created.
    """
    dataset = dset.MNIST(
        root="~/mnist/",
        download=True,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
        ]))

    # Default to no sampler so that single-process runs (no process group)
    # fall back to plain shuffling instead of hitting an unbound name.
    train_sampler = None
    if distributed.is_available() and distributed.is_initialized():
        train_sampler = DistributedSampler(dataset)

    # Create the dataloader. `shuffle` and `sampler` are mutually exclusive,
    # so shuffling is only requested when no sampler is in use.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=3,
        shuffle=(train_sampler is None),
        sampler=train_sampler)
    return dataloader, None
def weights_init(m):
    """Re-initializes a module in place, keyed on its class name.

    Modules whose class name contains "Conv" get weights drawn from
    N(0.0, 0.02). Otherwise, modules whose class name contains "BatchNorm"
    get weights drawn from N(1.0, 0.02) and a zero bias. Any other module is
    left untouched.
    """
    layer_kind = m.__class__.__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class Generator(nn.Module):
    """DCGAN generator.

    A stack of transposed convolutions that turns a latent input with
    ``latent_vector_size`` channels into a ``num_channels`` image squashed
    to [-1, 1] by a final Tanh.
    """

    def __init__(self):
        super(Generator, self).__init__()
        wide, mid, narrow = features_g * 4, features_g * 2, features_g
        layers = [
            # input is Z, going into a convolution
            nn.ConvTranspose2d(latent_vector_size, wide, 4, 1, 0, bias=False),
            nn.BatchNorm2d(wide),
            nn.ReLU(True),
            nn.ConvTranspose2d(wide, mid, 4, 2, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(True),
            nn.ConvTranspose2d(mid, narrow, 4, 2, 1, bias=False),
            nn.BatchNorm2d(narrow),
            nn.ReLU(True),
            nn.ConvTranspose2d(narrow, num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        # Layers are registered in list order, so submodule indices (and
        # therefore state_dict keys) follow the order above.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Runs ``input`` through the generator stack."""
        return self.main(input)
class Discriminator(nn.Module):
    """DCGAN discriminator.

    A stack of strided convolutions that reduces a ``num_channels`` image
    to a single sigmoid-activated channel.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        narrow, mid, wide = features_d, features_d * 2, features_d * 4
        layers = [
            nn.Conv2d(num_channels, narrow, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(narrow, mid, 4, 2, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(mid, wide, 4, 2, 1, bias=False),
            nn.BatchNorm2d(wide),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(wide, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Runs ``input`` through the discriminator stack."""
        return self.main(input)
class Net(nn.Module):
    """LeNet-style MNIST classifier that inception_score uses as its scorer.

    ``forward`` returns log-probabilities over the 10 classes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Returns per-class log-probabilities for a batch of images."""
        # First conv stage: conv -> 2x2 max-pool -> ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Second conv stage adds channel dropout before pooling.
        dropped = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(dropped, 2))
        # Flatten to the 320 features that fc1 expects.
        flat = stage2.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def inception_score(imgs, batch_size=32, splits=1):
    """Computes an inception-style score for ``imgs`` using the MNIST ``Net``.

    The classifier weights are loaded from the module-level ``model_path``,
    which this file only assigns inside its ``__main__`` block.

    Args:
        imgs: Indexable collection of image tensors (must support ``len``).
        batch_size (int): Batch size used while scoring.
        splits (int): Number of equal chunks the predictions are divided
            into; any remainder after integer division is ignored.

    Returns:
        A ``(mean, std)`` tuple of the per-split scores, where each split
        score is ``exp(mean_i KL(p(y|x_i) || p(y)))``.
    """
    N = len(imgs)
    dtype = torch.FloatTensor
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
    cm = Net()
    cm.load_state_dict(torch.load(model_path))
    cm.eval()
    up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)

    def get_pred(x):
        # Scoring only: skip autograd bookkeeping instead of building a
        # graph and discarding it afterwards.
        with torch.no_grad():
            log_probs = cm(up(x))
            # `Net` emits log-probabilities; softmax over them recovers the
            # class probabilities. The class axis is named explicitly
            # because the implicit `dim` is deprecated.
            return F.softmax(log_probs, dim=1).cpu().numpy()

    preds = np.zeros((N, 10))
    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batch_size_i = batch.size()[0]
        start = i * batch_size
        preds[start:start + batch_size_i] = get_pred(batch)

    # Now compute the mean kl-div
    split_size = N // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * split_size:(k + 1) * split_size, :]
        py = np.mean(part, axis=0)
        # scipy's entropy(pk, qk) is the KL divergence of pk from qk.
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
def model_creator(config):
    """Creates the ``(discriminator, generator)`` pair for the DCGAN example.

    Each network is constructed and then immediately re-initialized with
    ``weights_init`` before the next one is built. ``config`` is unused.
    """
    networks = []
    for network_cls in (Discriminator, Generator):
        network = network_cls()
        network.apply(weights_init)
        networks.append(network)
    return tuple(networks)
def train(models, dataloader, criterion, optimizers, config):
    """Runs one DCGAN training epoch.

    Args:
        models: ``(discriminator, generator)`` pair.
        dataloader: Iterable of batches; ``batch[0]`` holds the real images.
        criterion: Loss applied to discriminator outputs and labels.
        optimizers: ``(discriminator_optimizer, generator_optimizer)`` pair.
        config (dict): When ``config["test_mode"]`` is truthy, the epoch
            stops after ``TRAIN_BATCHES`` batches.

    Returns:
        dict with the generator loss (``loss_g``) and discriminator loss
        (``loss_d``) of the last processed batch, plus the inception score
        (``inception``) of the last generated batch.
    """
    netD, netG = models
    optimD, optimG = optimizers
    real_label = 1
    fake_label = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for i, data in enumerate(dataloader, 0):
        if i >= TRAIN_BATCHES and config.get("test_mode"):
            break

        # Discriminator update, real batch.
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Pin the dtype: an integer fill value would otherwise produce an
        # integer tensor on newer torch versions, which does not match the
        # floating-point discriminator output passed to the criterion.
        label = torch.full(
            (b_size, ), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        # Discriminator update, generated batch. `detach` keeps these
        # gradients from flowing into the generator.
        noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimD.step()

        # Generator update: generated images are scored against real labels.
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimG.step()

    is_score, is_std = inception_score(fake)

    return {
        "loss_g": errG.item(),
        "loss_d": errD.item(),
        "inception": is_score
    }
def optimizer_creator(models, config):
    """Creates one Adam optimizer per network.

    Args:
        models: ``(discriminator, generator)`` pair.
        config (dict): ``config["lr"]`` sets the learning rate for both
            optimizers (default 0.01).

    Returns:
        ``(discriminator_optimizer, generator_optimizer)``.
    """
    net_d, net_g = models

    def build_adam(network):
        return optim.Adam(
            network.parameters(),
            lr=config.get("lr", 0.01),
            betas=(0.5, 0.999))

    discriminator_opt = build_adam(net_d)
    generator_opt = build_adam(net_g)
    return discriminator_opt, generator_opt
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
    """Trains the DCGAN example for five epochs and returns the trainer.

    Args:
        num_replicas (int): Number of training replicas.
        use_gpu (bool): Selects the "nccl" backend when true, "gloo"
            otherwise, and is forwarded to the trainer.
        test_mode (bool): Uses a batch size of 16 instead of 512 and is
            passed to the training function through the config.

    Returns:
        The ``PyTorchTrainer`` after training.
    """
    if test_mode:
        batch_size = 16
    else:
        batch_size = 512

    if use_gpu:
        backend = "nccl"
    else:
        backend = "gloo"

    def loss_creator(config):
        return nn.BCELoss()

    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        loss_creator,
        train_function=train,
        validation_function=False,
        num_replicas=num_replicas,
        config={"test_mode": test_mode},
        use_gpu=use_gpu,
        batch_size=batch_size,
        backend=backend)

    for _ in range(5):
        print(trainer.train())

    return trainer
if __name__ == "__main__":
    # Command-line options, registered in the order listed here.
    cli_options = [
        (("--smoke-test", ),
         dict(action="store_true", help="Finish quickly for testing")),
        (("--address", ),
         dict(required=False, type=str, help="the address to use for Redis")),
        (("--num-replicas", "-n"),
         dict(
             type=int,
             default=1,
             help="Sets number of replicas for training.")),
        (("--use-gpu", ),
         dict(
             action="store_true",
             default=False,
             help="Enables GPU training")),
    ]
    parser = argparse.ArgumentParser()
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    # Unknown arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()

    ray.init(address=args.address)

    # Path of the pretrained mnist classification model that
    # inception_score loads through the module-level `model_path`.
    path = os.path.dirname(ray.__file__)
    model_path = os.path.join(
        path, "experimental/sgd/pytorch/examples/mnist_cnn.pt")

    trainer = train_example(
        num_replicas=args.num_replicas,
        use_gpu=args.use_gpu,
        test_mode=args.smoke_test)
    models = trainer.get_model()
@@ -2,9 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from filelock import FileLock
import logging
import os
import torch
import torch.utils.data
from torch.utils.data import DataLoader
import ray
from ray.experimental.sgd.pytorch import utils as pytorch_utils
@@ -59,22 +63,46 @@ class PyTorchRunner(object):
]
}
self.models = None
self.optimizers = None
self.criterion = None
self.train_loader = None
self.validation_loader = None
def _validate_loaders(self, data_loaders):
    """Normalizes the return value of ``data_creator``.

    Args:
        data_loaders: Either a single ``DataLoader`` or a two-element
            sequence whose first element is a ``DataLoader``.

    Returns:
        A ``(train_loader, validation_loader)`` pair; the validation loader
        is ``None`` when only one loader was given.

    Raises:
        AssertionError: If nothing (e.g. ``None`` or an empty sequence) was
            returned by ``data_creator``.
        ValueError: If the value is not one of the accepted shapes.
    """
    # Check for a bare DataLoader first. Truth-testing a DataLoader calls
    # its __len__, which is 0 for an empty dataset and raises for datasets
    # without a length, so it must not go through the emptiness check below.
    if isinstance(data_loaders, DataLoader):
        return data_loaders, None
    assert data_loaders, "Dataloaders need to be returned in data_creator."
    if len(data_loaders) == 2 and isinstance(data_loaders[0], DataLoader):
        return data_loaders
    raise ValueError(
        "Dataloaders must be <= 2. Got {}".format(data_loaders))
def setup(self):
"""Initializes the model."""
logger.debug("Creating model")
self.model = self.model_creator(self.config)
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
self.models = [self.models]
if torch.cuda.is_available():
self.model = self.model.cuda()
self.models = [model.cuda() for model in self.models]
logger.debug("Creating optimizer")
self.optimizer = self.optimizer_creator(self.model, self.config)
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, collections.Iterable):
self.optimizers = [self.optimizers]
self.criterion = self.loss_creator(self.config)
if torch.cuda.is_available():
self.criterion = self.criterion.cuda()
logger.debug("Creating dataset")
self.train_loader, self.validation_loader = self.data_creator(
self.batch_size, self.config)
with FileLock(os.path.expanduser("~/.ray_data.lock")):
dataloaders = self.data_creator(self.batch_size, self.config)
self.train_loader, self.validation_loader = self._validate_loaders(
dataloaders)
def get_node_ip(self):
"""Returns the IP address of the current node."""
@@ -88,9 +116,9 @@ class PyTorchRunner(object):
"""Runs a training epoch and updates the model parameters."""
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
with self._timers["training"]:
train_stats = self.train_function(self.model, self.train_loader,
self.criterion, self.optimizer,
self.config)
train_stats = self.train_function(
self.given_models, self.train_loader, self.criterion,
self.given_optimizers, self.config)
train_stats["epoch"] = self.epoch
self.epoch += 1
@@ -100,9 +128,11 @@ class PyTorchRunner(object):
def validate(self):
"""Evaluates the model on the validation data set."""
if self.validation_loader is None:
raise ValueError("No validation dataloader provided.")
with self._timers["validation"]:
validation_stats = self.validation_function(
self.model, self.validation_loader, self.criterion,
self.given_models, self.validation_loader, self.criterion,
self.config)
validation_stats.update(self.stats())
@@ -121,16 +151,18 @@ class PyTorchRunner(object):
"""Returns the state of the runner."""
return {
"epoch": self.epoch,
"model": self.model.cpu().state_dict(),
"optimizer": self.optimizer.state_dict(),
"models": [model.cpu().state_dict() for model in self.models],
"optimizers": [opt.state_dict() for opt in self.optimizers],
"stats": self.stats()
}
def set_state(self, state):
"""Sets the state of the model."""
# TODO: restore timer stats
self.model.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
for model, state_dict in zip(self.models, state["models"]):
model.load_state_dict(state_dict)
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
optimizer.load_state_dict(state_dict)
self.epoch = state["stats"]["epoch"]
def apply_fn(self, fn):
@@ -141,7 +173,21 @@ class PyTorchRunner(object):
del self.validation_loader
del self.train_loader
del self.criterion
del self.optimizer
del self.model
del self.optimizers
del self.models
if torch.cuda.is_available():
torch.cuda.empty_cache()
@property
def given_optimizers(self):
    """Optimizers in the shape the user supplied them.

    Returns the list itself when it holds more than one optimizer, and the
    sole optimizer otherwise.
    """
    return self.optimizers if len(self.optimizers) > 1 else self.optimizers[0]
@property
def given_models(self):
    """Models in the shape the user supplied them.

    Returns the list itself when it holds more than one model, and the sole
    model otherwise.
    """
    return self.models if len(self.models) > 1 else self.models[0]
@@ -7,6 +7,7 @@ import os
import torch
import torch.distributed as dist
import logging
import numbers
import ray
@@ -77,6 +78,8 @@ class PyTorchTrainer(object):
"https://github.com/pytorch/examples/issues/467."))
self.model_creator = model_creator
self.train_function = train_function
self.validation_function = validation_function
self.config = {} if config is None else config
self.optimizer_timer = utils.TimerStat(window_size=1)
@@ -148,13 +151,20 @@ class PyTorchTrainer(object):
])
def train(self):
    """Runs a training epoch.

    Numeric stats are averaged (ignoring NaN) over all workers; every other
    stat is taken from the first worker as-is.

    Returns:
        dict keyed like the first worker's stats.
    """
    with self.optimizer_timer:
        worker_stats = ray.get([w.step.remote() for w in self.workers])

    first_worker = worker_stats[0]
    train_stats = {}
    for stat_key in first_worker:
        # Inspect the stat *value*: the stats container itself is a dict
        # and never a Number, so checking it would skip all averaging.
        if isinstance(first_worker[stat_key], numbers.Number):
            train_stats[stat_key] = np.nanmean(
                [s.get(stat_key, np.nan) for s in worker_stats])
        else:
            train_stats[stat_key] = first_worker[stat_key]
    return train_stats
def apply_all_workers(self, fn):
@@ -162,22 +172,29 @@ class PyTorchTrainer(object):
def validate(self):
"""Evaluates the model on the validation data set."""
if self.validation_function is False:
return {}
worker_stats = ray.get([w.validate.remote() for w in self.workers])
validation_stats = worker_stats[0].copy()
if "validation_loss" in validation_stats:
validation_stats["validation_loss"] = np.nanmean(
[s.get("validation_loss", np.nan) for s in worker_stats])
validation_stats = {}
for stat_key in worker_stats[0]:
validation_stats[stat_key] = np.nanmean(
[s.get(stat_key, np.nan) for s in worker_stats])
return validation_stats
def get_model(self):
"""Returns the learned model."""
model = self.model_creator(self.config)
"""Returns the learned model(s)."""
models = self.model_creator(self.config)
state = ray.get(self.workers[0].get_state.remote())
model.load_state_dict(state["model"])
return model
if len(state["models"]) == 1:
models.load_state_dict(state["models"][0])
else:
for model, state_dict in zip(models, state["models"]):
model.load_state_dict(state_dict)
return models
def save(self, checkpoint):
"""Saves the model at the provided checkpoint.
"""Saves the model(s) to the provided checkpoint.
Args:
checkpoint (str): Path to target checkpoint file.
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import time
import torch
@@ -10,6 +11,12 @@ from ray.experimental.sgd.utils import TimerStat
def train(model, train_iterator, criterion, optimizer, config):
"""Runs 1 training epoch"""
if isinstance(model, collections.Iterable) or isinstance(
optimizer, collections.Iterable):
raise ValueError(
"Need to provide custom training function if using multi-model "
"or multi-optimizer training.")
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
@@ -63,6 +70,10 @@ def train(model, train_iterator, criterion, optimizer, config):
def validate(model, val_iterator, criterion, config):
if isinstance(model, collections.Iterable):
raise ValueError(
"Need to provide custom validation function if using multi-model "
"training.")
batch_time = AverageMeter()
losses = AverageMeter()
@@ -12,6 +12,7 @@ import torch.distributed as dist
from ray import tune
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
from ray.experimental.sgd.pytorch.utils import train
from ray.experimental.sgd.examples.train_example import (
model_creator, optimizer_creator, data_creator)
@@ -39,6 +40,64 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
assert validation_loss2 <= validation_loss1
@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2] if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Checkpoints a two-model trainer and checks a restore reproduces it."""

    def custom_train(models, dataloader, criterion, optimizers, config):
        # Train each (model, optimizer) pair independently and report the
        # stats under a per-model key.
        result = {}
        for idx, pair in enumerate(zip(models, optimizers)):
            net, opt = pair
            result["model_{}".format(idx)] = train(net, dataloader, criterion,
                                                   opt, config)
        return result

    def multi_model_creator(config):
        first = nn.Linear(1, 1)
        second = nn.Linear(1, 1)
        return first, second

    def multi_optimizer_creator(models, config):
        opts = []
        for model in models:
            opts.append(torch.optim.SGD(model.parameters(), lr=0.0001))
        return opts[0], opts[1]

    trainer1 = PyTorchTrainer(
        multi_model_creator,
        data_creator,
        multi_optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        train_function=custom_train,
        num_replicas=num_replicas)
    trainer1.train()

    checkpoint_dir = tempfile.mkdtemp()
    filename = os.path.join(checkpoint_dir, "checkpoint")
    trainer1.save(filename)
    models1 = trainer1.get_model()
    trainer1.shutdown()

    # The second trainer only restores, so it needs no train function.
    trainer2 = PyTorchTrainer(
        multi_model_creator,
        data_creator,
        multi_optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_replicas=num_replicas)
    trainer2.restore(filename)
    os.remove(filename)
    models2 = trainer2.get_model()

    for model_1, model_2 in zip(models1, models2):
        expected_state = model_1.state_dict()
        restored_state = model_2.state_dict()
        assert set(expected_state.keys()) == set(restored_state.keys())
        for key in expected_state:
            assert torch.equal(expected_state[key], restored_state[key])
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
@@ -0,0 +1,160 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sys
import torch
import torch.nn as nn
import unittest
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
if sys.version_info >= (3, 3):
from unittest.mock import MagicMock
else:
from mock import MagicMock
class LinearDataset(torch.utils.data.Dataset):
    """y = a * x + b

    ``x`` holds ``size`` evenly spaced float32 values starting at 0 with
    step ``10 / size``. Items are returned as ``(x, y)`` pairs of
    one-element tensors.
    """

    def __init__(self, a, b, size=1000):
        # The inputs are a deterministic grid. The random draw that used to
        # precede this line was overwritten immediately and has been dropped.
        x = np.arange(0, 10, 10 / size, dtype=np.float32)
        self.x = torch.from_numpy(x)
        self.y = torch.from_numpy(a * x + b)

    def __getitem__(self, index):
        # Indexing with None adds a trailing axis, giving shape (1,) items.
        return self.x[index, None], self.y[index, None]

    def __len__(self):
        return len(self.x)
def model_creator(config):
    """Returns a one-input, one-output linear model; ``config`` is unused."""
    linear = nn.Linear(1, 1)
    return linear
def optimizer_creator(models, config):
    """Returns an SGD optimizer (lr=0.1) over the parameters of ``models``.

    ``models`` is a single module here; ``config`` is unused.
    """
    parameters = models.parameters()
    return torch.optim.SGD(parameters, lr=0.1)
def loss_creator(config):
    """Returns a mean-squared-error criterion; ``config`` is unused."""
    criterion = nn.MSELoss()
    return criterion
def single_loader(batch_size, config):
    """Returns only a training loader over ``LinearDataset(2, 5)``.

    Neither ``batch_size`` nor ``config`` is used; the loader keeps the
    DataLoader defaults.
    """
    return torch.utils.data.DataLoader(LinearDataset(2, 5))
def create_dataloaders(batch_size, config):
    """Returns ``(train_loader, validation_loader)`` over linear datasets.

    The training set uses the default dataset size and the validation set
    400 samples. Neither ``batch_size`` nor ``config`` is used.
    """
    datasets = (LinearDataset(2, 5), LinearDataset(2, 5, size=400))
    train_loader, validation_loader = [
        torch.utils.data.DataLoader(dataset) for dataset in datasets
    ]
    return train_loader, validation_loader
class TestPyTorchRunner(unittest.TestCase):
    """Unit tests for PyTorchRunner with single and multiple models."""

    def testValidate(self):
        """validate() calls the custom validation function; step() doesn't."""
        # `return_value` (not `returns`) is what configures the result of
        # calling the mock, so validate() receives a real dict of stats.
        mock_function = MagicMock(return_value=dict(mean_accuracy=10))
        runner = PyTorchRunner(
            model_creator,
            create_dataloaders,
            optimizer_creator,
            loss_creator,
            validation_function=mock_function)
        runner.setup()
        runner.step()
        runner.step()
        runner.step()
        self.assertEqual(mock_function.call_count, 0)
        runner.validate()
        self.assertTrue(mock_function.called)
        self.assertEqual(runner.stats()["epoch"], 3)

    def testStep(self):
        """Each step() calls the custom train function once."""
        mock_function = MagicMock(return_value=dict(mean_accuracy=10))
        runner = PyTorchRunner(
            model_creator,
            create_dataloaders,
            optimizer_creator,
            loss_creator,
            train_function=mock_function)
        runner.setup()
        runner.step()
        runner.step()
        result = runner.step()
        self.assertEqual(mock_function.call_count, 3)
        self.assertEqual(result["epoch"], 3)
        self.assertEqual(runner.stats()["epoch"], 3)

    def testGivens(self):
        """given_* return the list for many items and the item for one."""

        def three_model_creator(config):
            return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)

        def three_optimizer_creator(models, config):
            opts = [
                torch.optim.SGD(model.parameters(), lr=0.1) for model in models
            ]
            return opts[0], opts[1], opts[2]

        runner = PyTorchRunner(three_model_creator, single_loader,
                               three_optimizer_creator, loss_creator)
        runner.setup()

        self.assertEqual(len(runner.given_models), 3)
        self.assertEqual(len(runner.given_optimizers), 3)

        runner2 = PyTorchRunner(model_creator, single_loader,
                                optimizer_creator, loss_creator)
        runner2.setup()

        # With a single model/optimizer the "given" view is the bare object,
        # not the one-element list stored on the runner.
        self.assertNotEqual(runner2.given_models, runner2.models)
        self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)

    def testMultiLoaders(self):
        """Returning three dataloaders from data_creator is rejected."""

        def three_data_loader(batch_size, config):
            train_dataset = LinearDataset(2, 5)
            validation_dataset = LinearDataset(2, 5, size=400)
            train_loader = torch.utils.data.DataLoader(train_dataset)
            validation_loader = torch.utils.data.DataLoader(validation_dataset)
            return train_loader, validation_loader, validation_loader

        runner = PyTorchRunner(model_creator, three_data_loader,
                               optimizer_creator, loss_creator)
        with self.assertRaises(ValueError):
            runner.setup()

        runner2 = PyTorchRunner(model_creator, three_data_loader,
                                optimizer_creator, loss_creator)
        with self.assertRaises(ValueError):
            runner2.setup()

    def testSingleLoader(self):
        """validate() raises when no validation loader was provided."""
        runner = PyTorchRunner(model_creator, single_loader, optimizer_creator,
                               loss_creator)
        runner.setup()
        runner.step()
        with self.assertRaises(ValueError):
            runner.validate()

    def testMultiModel(self):
        """Multi-model training without a custom train function raises."""

        def multi_model_creator(config):
            return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)

        def multi_optimizer_creator(models, config):
            opts = [
                torch.optim.SGD(model.parameters(), lr=0.1) for model in models
            ]
            return opts[0], opts[1], opts[2]

        runner = PyTorchRunner(multi_model_creator, single_loader,
                               multi_optimizer_creator, loss_creator)
        runner.setup()

        with self.assertRaises(ValueError):
            runner.step()