mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 12:36:50 +08:00
[sgd] Add support for multi-model multi-optimizer training (#6317)
This commit is contained in:
@@ -89,13 +89,8 @@ def cifar_creator(batch_size, config):
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
from filelock import FileLock
|
||||
with FileLock(os.path.expanduser("~/data.lock")):
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform_train)
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data", train=True, download=True, transform=transform_train)
|
||||
validation_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
|
||||
|
||||
@@ -2,9 +2,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import os
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
|
||||
@@ -51,20 +56,29 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
|
||||
def _setup_training(self):
|
||||
logger.debug("Creating model")
|
||||
self.model = self.model_creator(self.config)
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, collections.Iterable):
|
||||
self.models = [self.models]
|
||||
assert all(isinstance(model, nn.Module) for model in self.models), (
|
||||
"All models must be PyTorch models: {}.".format(self.models))
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
self.model = torch.nn.parallel.DistributedDataParallel(self.model)
|
||||
self.models = [model.cuda() for model in self.models]
|
||||
self.models = [DistributedDataParallel(model) for model in self.models]
|
||||
|
||||
logger.debug("Creating optimizer.")
|
||||
self.optimizer = self.optimizer_creator(self.model, self.config)
|
||||
self.optimizers = self.optimizer_creator(self.given_models,
|
||||
self.config)
|
||||
if not isinstance(self.optimizers, collections.Iterable):
|
||||
self.optimizers = [self.optimizers]
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
logger.debug("Creating dataset")
|
||||
self.train_loader, self.validation_loader = self.data_creator(
|
||||
self.batch_size, self.config)
|
||||
logger.debug("Creating dataset.")
|
||||
with FileLock(os.path.expanduser("~/.ray_data.lock")):
|
||||
data_loaders = self.data_creator(self.batch_size, self.config)
|
||||
self.train_loader, self.validation_loader = self._validate_loaders(
|
||||
data_loaders)
|
||||
|
||||
def step(self):
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
@@ -80,16 +94,21 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
"""Returns the state of the runner."""
|
||||
return {
|
||||
"epoch": self.epoch,
|
||||
"model": self.model.module.cpu().state_dict(),
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"models": [
|
||||
model.module.cpu().state_dict() for model in self.models
|
||||
],
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers],
|
||||
"stats": self.stats()
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
"""Sets the state of the model."""
|
||||
# TODO: restore timer stats
|
||||
self.model.module.load_state_dict(state["model"])
|
||||
self.optimizer.load_state_dict(state["optimizer"])
|
||||
for model, model_state_dict in zip(self.models, state["models"]):
|
||||
model.module.load_state_dict(model_state_dict)
|
||||
for optimizer, opt_state_dict in zip(self.optimizers,
|
||||
state["optimizers"]):
|
||||
optimizer.load_state_dict(opt_state_dict)
|
||||
self.epoch = state["stats"]["epoch"]
|
||||
|
||||
def shutdown(self):
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from scipy.stats import entropy
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
|
||||
# Training parameters
|
||||
TRAIN_BATCHES = 5
|
||||
# Number of channels in the training images. For color images this is 3
|
||||
num_channels = 1
|
||||
|
||||
# Size of z latent vector (i.e. size of generator input)
|
||||
latent_vector_size = 100
|
||||
|
||||
# Size of feature maps in generator
|
||||
features_g = 32
|
||||
|
||||
# Size of feature maps in discriminator
|
||||
features_d = 32
|
||||
|
||||
|
||||
def data_creator(batch_size, config):
|
||||
dataset = dset.MNIST(
|
||||
root="~/mnist/",
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize(32),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, ), (0.5, )),
|
||||
]))
|
||||
|
||||
# Create the dataloader
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(dataset)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler)
|
||||
|
||||
return dataloader, None
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
# input is Z, going into a convolution
|
||||
nn.ConvTranspose2d(
|
||||
latent_vector_size, features_g * 4, 4, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(features_g * 4),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 4, features_g * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_g * 2),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 2, features_g, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_g),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False),
|
||||
nn.Tanh())
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Discriminator, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
nn.Conv2d(num_channels, features_d, 4, 2, 1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_d * 2), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_d * 4), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features_d * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
"""LeNet for MNist classification, used for inception_score."""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.conv2_drop = nn.Dropout2d()
|
||||
self.fc1 = nn.Linear(320, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
||||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
||||
x = x.view(-1, 320)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.dropout(x, training=self.training)
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
def inception_score(imgs, batch_size=32, splits=1):
|
||||
N = len(imgs)
|
||||
dtype = torch.FloatTensor
|
||||
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
|
||||
cm = Net()
|
||||
cm.load_state_dict(torch.load(model_path))
|
||||
cm.eval()
|
||||
up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
|
||||
|
||||
def get_pred(x):
|
||||
x = up(x)
|
||||
x = cm(x)
|
||||
return F.softmax(x).data.cpu().numpy()
|
||||
|
||||
preds = np.zeros((N, 10))
|
||||
for i, batch in enumerate(dataloader, 0):
|
||||
batch = batch.type(dtype)
|
||||
batchv = Variable(batch)
|
||||
batch_size_i = batch.size()[0]
|
||||
preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
|
||||
|
||||
# Now compute the mean kl-div
|
||||
split_scores = []
|
||||
for k in range(splits):
|
||||
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
|
||||
py = np.mean(part, axis=0)
|
||||
scores = []
|
||||
for i in range(part.shape[0]):
|
||||
pyx = part[i, :]
|
||||
scores.append(entropy(pyx, py))
|
||||
split_scores.append(np.exp(np.mean(scores)))
|
||||
|
||||
return np.mean(split_scores), np.std(split_scores)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
netD = Discriminator()
|
||||
netD.apply(weights_init)
|
||||
|
||||
netG = Generator()
|
||||
netG.apply(weights_init)
|
||||
return netD, netG
|
||||
|
||||
|
||||
def train(models, dataloader, criterion, optimizers, config):
|
||||
netD, netG = models
|
||||
optimD, optimG = optimizers
|
||||
real_label = 1
|
||||
fake_label = 0
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
for i, data in enumerate(dataloader, 0):
|
||||
if i >= TRAIN_BATCHES and config.get("test_mode"):
|
||||
break
|
||||
|
||||
netD.zero_grad()
|
||||
real_cpu = data[0].to(device)
|
||||
b_size = real_cpu.size(0)
|
||||
label = torch.full((b_size, ), real_label, device=device)
|
||||
output = netD(real_cpu).view(-1)
|
||||
errD_real = criterion(output, label)
|
||||
errD_real.backward()
|
||||
|
||||
noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device)
|
||||
fake = netG(noise)
|
||||
label.fill_(fake_label)
|
||||
output = netD(fake.detach()).view(-1)
|
||||
errD_fake = criterion(output, label)
|
||||
errD_fake.backward()
|
||||
errD = errD_real + errD_fake
|
||||
optimD.step()
|
||||
|
||||
netG.zero_grad()
|
||||
label.fill_(real_label)
|
||||
output = netD(fake).view(-1)
|
||||
errG = criterion(output, label)
|
||||
errG.backward()
|
||||
optimG.step()
|
||||
|
||||
is_score, is_std = inception_score(fake)
|
||||
|
||||
return {
|
||||
"loss_g": errG.item(),
|
||||
"loss_d": errD.item(),
|
||||
"inception": is_score
|
||||
}
|
||||
|
||||
|
||||
def optimizer_creator(models, config):
|
||||
net_d, net_g = models
|
||||
discriminator_opt = optim.Adam(
|
||||
net_d.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
|
||||
generator_opt = optim.Adam(
|
||||
net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
|
||||
return discriminator_opt, generator_opt
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
config = {"test_mode": test_mode}
|
||||
trainer = PyTorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
lambda config: nn.BCELoss(),
|
||||
train_function=train,
|
||||
validation_function=False,
|
||||
num_replicas=num_replicas,
|
||||
config=config,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo")
|
||||
for i in range(5):
|
||||
stats = trainer.train()
|
||||
print(stats)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
parser.add_argument(
|
||||
"--address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(address=args.address)
|
||||
|
||||
path = os.path.dirname(ray.__file__)
|
||||
model_path = os.path.join(
|
||||
path, "experimental/sgd/pytorch/examples/mnist_cnn.pt")
|
||||
# load the pretrained mnist classification model for inception_score
|
||||
|
||||
trainer = train_example(
|
||||
num_replicas=args.num_replicas,
|
||||
use_gpu=args.use_gpu,
|
||||
test_mode=args.smoke_test)
|
||||
models = trainer.get_model()
|
||||
Binary file not shown.
@@ -2,9 +2,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.pytorch import utils as pytorch_utils
|
||||
@@ -59,22 +63,46 @@ class PyTorchRunner(object):
|
||||
]
|
||||
}
|
||||
|
||||
self.models = None
|
||||
self.optimizers = None
|
||||
self.criterion = None
|
||||
self.train_loader = None
|
||||
self.validation_loader = None
|
||||
|
||||
def _validate_loaders(self, data_loaders):
|
||||
assert data_loaders, "Dataloaders need to be returned in data_creator."
|
||||
if isinstance(data_loaders, DataLoader):
|
||||
return data_loaders, None
|
||||
elif len(data_loaders) == 2 and isinstance(data_loaders[0],
|
||||
DataLoader):
|
||||
return data_loaders
|
||||
else:
|
||||
raise ValueError(
|
||||
"Dataloaders must be <= 2. Got {}".format(data_loaders))
|
||||
|
||||
def setup(self):
|
||||
"""Initializes the model."""
|
||||
logger.debug("Creating model")
|
||||
self.model = self.model_creator(self.config)
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, collections.Iterable):
|
||||
self.models = [self.models]
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
self.models = [model.cuda() for model in self.models]
|
||||
|
||||
logger.debug("Creating optimizer")
|
||||
self.optimizer = self.optimizer_creator(self.model, self.config)
|
||||
self.optimizers = self.optimizer_creator(self.given_models,
|
||||
self.config)
|
||||
if not isinstance(self.optimizers, collections.Iterable):
|
||||
self.optimizers = [self.optimizers]
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
if torch.cuda.is_available():
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
logger.debug("Creating dataset")
|
||||
self.train_loader, self.validation_loader = self.data_creator(
|
||||
self.batch_size, self.config)
|
||||
with FileLock(os.path.expanduser("~/.ray_data.lock")):
|
||||
dataloaders = self.data_creator(self.batch_size, self.config)
|
||||
self.train_loader, self.validation_loader = self._validate_loaders(
|
||||
dataloaders)
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
@@ -88,9 +116,9 @@ class PyTorchRunner(object):
|
||||
"""Runs a training epoch and updates the model parameters."""
|
||||
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
|
||||
with self._timers["training"]:
|
||||
train_stats = self.train_function(self.model, self.train_loader,
|
||||
self.criterion, self.optimizer,
|
||||
self.config)
|
||||
train_stats = self.train_function(
|
||||
self.given_models, self.train_loader, self.criterion,
|
||||
self.given_optimizers, self.config)
|
||||
train_stats["epoch"] = self.epoch
|
||||
|
||||
self.epoch += 1
|
||||
@@ -100,9 +128,11 @@ class PyTorchRunner(object):
|
||||
|
||||
def validate(self):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
if self.validation_loader is None:
|
||||
raise ValueError("No validation dataloader provided.")
|
||||
with self._timers["validation"]:
|
||||
validation_stats = self.validation_function(
|
||||
self.model, self.validation_loader, self.criterion,
|
||||
self.given_models, self.validation_loader, self.criterion,
|
||||
self.config)
|
||||
|
||||
validation_stats.update(self.stats())
|
||||
@@ -121,16 +151,18 @@ class PyTorchRunner(object):
|
||||
"""Returns the state of the runner."""
|
||||
return {
|
||||
"epoch": self.epoch,
|
||||
"model": self.model.cpu().state_dict(),
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"models": [model.cpu().state_dict() for model in self.models],
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers],
|
||||
"stats": self.stats()
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
"""Sets the state of the model."""
|
||||
# TODO: restore timer stats
|
||||
self.model.load_state_dict(state["model"])
|
||||
self.optimizer.load_state_dict(state["optimizer"])
|
||||
for model, state_dict in zip(self.models, state["models"]):
|
||||
model.load_state_dict(state_dict)
|
||||
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
|
||||
optimizer.load_state_dict(state_dict)
|
||||
self.epoch = state["stats"]["epoch"]
|
||||
|
||||
def apply_fn(self, fn):
|
||||
@@ -141,7 +173,21 @@ class PyTorchRunner(object):
|
||||
del self.validation_loader
|
||||
del self.train_loader
|
||||
del self.criterion
|
||||
del self.optimizer
|
||||
del self.model
|
||||
del self.optimizers
|
||||
del self.models
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def given_optimizers(self):
|
||||
if len(self.optimizers) > 1:
|
||||
return self.optimizers
|
||||
else:
|
||||
return self.optimizers[0]
|
||||
|
||||
@property
|
||||
def given_models(self):
|
||||
if len(self.models) > 1:
|
||||
return self.models
|
||||
else:
|
||||
return self.models[0]
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import logging
|
||||
import numbers
|
||||
|
||||
import ray
|
||||
|
||||
@@ -77,6 +78,8 @@ class PyTorchTrainer(object):
|
||||
"https://github.com/pytorch/examples/issues/467."))
|
||||
|
||||
self.model_creator = model_creator
|
||||
self.train_function = train_function
|
||||
self.validation_function = validation_function
|
||||
self.config = {} if config is None else config
|
||||
self.optimizer_timer = utils.TimerStat(window_size=1)
|
||||
|
||||
@@ -148,13 +151,20 @@ class PyTorchTrainer(object):
|
||||
])
|
||||
|
||||
def train(self):
|
||||
"""Runs a training epoch."""
|
||||
"""Runs a training epoch.
|
||||
|
||||
Runs an average over all values returned from workers.
|
||||
"""
|
||||
with self.optimizer_timer:
|
||||
worker_stats = ray.get([w.step.remote() for w in self.workers])
|
||||
|
||||
train_stats = worker_stats[0].copy()
|
||||
train_stats["train_loss"] = np.mean(
|
||||
[s["train_loss"] for s in worker_stats])
|
||||
train_stats = {}
|
||||
for stat_key in worker_stats[0]:
|
||||
if isinstance(worker_stats[0], numbers.Number):
|
||||
train_stats[stat_key] = np.nanmean(
|
||||
[s.get(stat_key, np.nan) for s in worker_stats])
|
||||
else:
|
||||
train_stats[stat_key] = worker_stats[0][stat_key]
|
||||
return train_stats
|
||||
|
||||
def apply_all_workers(self, fn):
|
||||
@@ -162,22 +172,29 @@ class PyTorchTrainer(object):
|
||||
|
||||
def validate(self):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
if self.validation_function is False:
|
||||
return {}
|
||||
worker_stats = ray.get([w.validate.remote() for w in self.workers])
|
||||
validation_stats = worker_stats[0].copy()
|
||||
if "validation_loss" in validation_stats:
|
||||
validation_stats["validation_loss"] = np.nanmean(
|
||||
[s.get("validation_loss", np.nan) for s in worker_stats])
|
||||
|
||||
validation_stats = {}
|
||||
for stat_key in worker_stats[0]:
|
||||
validation_stats[stat_key] = np.nanmean(
|
||||
[s.get(stat_key, np.nan) for s in worker_stats])
|
||||
return validation_stats
|
||||
|
||||
def get_model(self):
|
||||
"""Returns the learned model."""
|
||||
model = self.model_creator(self.config)
|
||||
"""Returns the learned model(s)."""
|
||||
models = self.model_creator(self.config)
|
||||
state = ray.get(self.workers[0].get_state.remote())
|
||||
model.load_state_dict(state["model"])
|
||||
return model
|
||||
if len(state["models"]) == 1:
|
||||
models.load_state_dict(state["models"][0])
|
||||
else:
|
||||
for model, state_dict in zip(models, state["models"]):
|
||||
model.load_state_dict(state_dict)
|
||||
return models
|
||||
|
||||
def save(self, checkpoint):
|
||||
"""Saves the model at the provided checkpoint.
|
||||
"""Saves the model(s) to the provided checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import time
|
||||
import torch
|
||||
|
||||
@@ -10,6 +11,12 @@ from ray.experimental.sgd.utils import TimerStat
|
||||
|
||||
def train(model, train_iterator, criterion, optimizer, config):
|
||||
"""Runs 1 training epoch"""
|
||||
if isinstance(model, collections.Iterable) or isinstance(
|
||||
optimizer, collections.Iterable):
|
||||
raise ValueError(
|
||||
"Need to provide custom training function if using multi-model "
|
||||
"or multi-optimizer training.")
|
||||
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
@@ -63,6 +70,10 @@ def train(model, train_iterator, criterion, optimizer, config):
|
||||
|
||||
|
||||
def validate(model, val_iterator, criterion, config):
|
||||
if isinstance(model, collections.Iterable):
|
||||
raise ValueError(
|
||||
"Need to provide custom validation function if using multi-model "
|
||||
"training.")
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch.distributed as dist
|
||||
from ray import tune
|
||||
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
|
||||
from ray.experimental.sgd.pytorch.utils import train
|
||||
|
||||
from ray.experimental.sgd.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
@@ -39,6 +40,64 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
assert validation_loss2 <= validation_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
def custom_train(models, dataloader, criterion, optimizers, config):
|
||||
result = {}
|
||||
for i, (model, optimizer) in enumerate(zip(models, optimizers)):
|
||||
result["model_{}".format(i)] = train(model, dataloader, criterion,
|
||||
optimizer, config)
|
||||
return result
|
||||
|
||||
def multi_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
def multi_optimizer_creator(models, config):
|
||||
opts = [
|
||||
torch.optim.SGD(model.parameters(), lr=0.0001) for model in models
|
||||
]
|
||||
return opts[0], opts[1]
|
||||
|
||||
trainer1 = PyTorchTrainer(
|
||||
multi_model_creator,
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
train_function=custom_train,
|
||||
num_replicas=num_replicas)
|
||||
trainer1.train()
|
||||
|
||||
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
|
||||
trainer1.save(filename)
|
||||
|
||||
models1 = trainer1.get_model()
|
||||
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = PyTorchTrainer(
|
||||
multi_model_creator,
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
trainer2.restore(filename)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
models2 = trainer2.get_model()
|
||||
|
||||
for model_1, model_2 in zip(models1, models2):
|
||||
|
||||
model1_state_dict = model_1.state_dict()
|
||||
model2_state_dict = model_2.state_dict()
|
||||
|
||||
assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())
|
||||
|
||||
for k in model1_state_dict:
|
||||
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import unittest
|
||||
|
||||
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
from unittest.mock import MagicMock
|
||||
else:
|
||||
from mock import MagicMock
|
||||
|
||||
|
||||
class LinearDataset(torch.utils.data.Dataset):
|
||||
"""y = a * x + b"""
|
||||
|
||||
def __init__(self, a, b, size=1000):
|
||||
x = np.random.random(size).astype(np.float32) * 10
|
||||
x = np.arange(0, 10, 10 / size, dtype=np.float32)
|
||||
self.x = torch.from_numpy(x)
|
||||
self.y = torch.from_numpy(a * x + b)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.x[index, None], self.y[index, None]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
return nn.Linear(1, 1)
|
||||
|
||||
|
||||
def optimizer_creator(models, config):
|
||||
"""Returns optimizer."""
|
||||
return torch.optim.SGD(models.parameters(), lr=0.1)
|
||||
|
||||
|
||||
def loss_creator(config):
|
||||
return nn.MSELoss()
|
||||
|
||||
|
||||
def single_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset)
|
||||
return train_loader
|
||||
|
||||
|
||||
def create_dataloaders(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(validation_dataset)
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
class TestPyTorchRunner(unittest.TestCase):
|
||||
def testValidate(self):
|
||||
mock_function = MagicMock(returns=dict(mean_accuracy=10))
|
||||
runner = PyTorchRunner(
|
||||
model_creator,
|
||||
create_dataloaders,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
validation_function=mock_function)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
runner.step()
|
||||
runner.step()
|
||||
self.assertEqual(mock_function.call_count, 0)
|
||||
runner.validate()
|
||||
self.assertTrue(mock_function.called)
|
||||
self.assertEqual(runner.stats()["epoch"], 3)
|
||||
|
||||
def testStep(self):
|
||||
mock_function = MagicMock(return_value=dict(mean_accuracy=10))
|
||||
runner = PyTorchRunner(
|
||||
model_creator,
|
||||
create_dataloaders,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
train_function=mock_function)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
runner.step()
|
||||
result = runner.step()
|
||||
self.assertEqual(mock_function.call_count, 3)
|
||||
self.assertEqual(result["epoch"], 3)
|
||||
self.assertEqual(runner.stats()["epoch"], 3)
|
||||
|
||||
def testGivens(self):
|
||||
def three_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
def three_optimizer_creator(models, config):
|
||||
opts = [
|
||||
torch.optim.SGD(model.parameters(), lr=0.1) for model in models
|
||||
]
|
||||
return opts[0], opts[1], opts[2]
|
||||
|
||||
runner = PyTorchRunner(three_model_creator, single_loader,
|
||||
three_optimizer_creator, loss_creator)
|
||||
runner.setup()
|
||||
|
||||
self.assertEqual(len(runner.given_models), 3)
|
||||
self.assertEqual(len(runner.given_optimizers), 3)
|
||||
|
||||
runner2 = PyTorchRunner(model_creator, single_loader,
|
||||
optimizer_creator, loss_creator)
|
||||
runner2.setup()
|
||||
|
||||
self.assertNotEqual(runner2.given_models, runner2.models)
|
||||
self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
|
||||
|
||||
def testMultiLoaders(self):
|
||||
def three_data_loader(batch_size, config):
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
validation_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset)
|
||||
validation_loader = torch.utils.data.DataLoader(validation_dataset)
|
||||
return train_loader, validation_loader, validation_loader
|
||||
|
||||
runner = PyTorchRunner(model_creator, three_data_loader,
|
||||
optimizer_creator, loss_creator)
|
||||
with self.assertRaises(ValueError):
|
||||
runner.setup()
|
||||
|
||||
runner2 = PyTorchRunner(model_creator, three_data_loader,
|
||||
optimizer_creator, loss_creator)
|
||||
with self.assertRaises(ValueError):
|
||||
runner2.setup()
|
||||
|
||||
def testSingleLoader(self):
|
||||
runner = PyTorchRunner(model_creator, single_loader, optimizer_creator,
|
||||
loss_creator)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
with self.assertRaises(ValueError):
|
||||
runner.validate()
|
||||
|
||||
def testMultiModel(self):
|
||||
def multi_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
def multi_optimizer_creator(models, config):
|
||||
opts = [
|
||||
torch.optim.SGD(model.parameters(), lr=0.1) for model in models
|
||||
]
|
||||
return opts[0], opts[1], opts[2]
|
||||
|
||||
runner = PyTorchRunner(multi_model_creator, single_loader,
|
||||
multi_optimizer_creator, loss_creator)
|
||||
runner.setup()
|
||||
with self.assertRaises(ValueError):
|
||||
runner.step()
|
||||
Reference in New Issue
Block a user