[raysgd] Custom training operator (#7211)

This commit is contained in:
Richard Liaw
2020-03-01 21:22:48 -08:00
committed by GitHub
parent 2d97650b1e
commit 48cdca843f
15 changed files with 1013 additions and 702 deletions
+4 -1
View File
@@ -3,6 +3,7 @@ logger = logging.getLogger(__name__)
PyTorchTrainer = None
PyTorchTrainable = None
TrainingOperator = None
try:
import torch # noqa: F401
@@ -10,6 +11,8 @@ try:
from ray.util.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
PyTorchTrainable)
__all__ = ["PyTorchTrainer", "PyTorchTrainable"]
from ray.util.sgd.pytorch.training_operator import TrainingOperator
__all__ = ["PyTorchTrainer", "PyTorchTrainable", "TrainingOperator"]
except ImportError:
logger.warning("PyTorch not found. PyTorchTrainer will not be available")
+7
View File
@@ -0,0 +1,7 @@
# Key under which the runner passes its mixed-precision flag in ``info``.
USE_FP16 = "__use_fp16__"
# Stats key holding the number of batches processed during an epoch.
BATCH_COUNT = "batch_count"
# Key under which the scheduler stepping frequency is passed in ``info``.
SCHEDULER_STEP = "scheduler_step"
# Supported scheduler stepping frequencies: after every batch or every epoch.
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
# Set of accepted values for the scheduler stepping frequency.
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
@@ -95,15 +95,22 @@ class DistributedPyTorchRunner(PyTorchRunner):
self.validation_loader = torch.utils.data.DataLoader(
val_set, batch_size=self.batch_size, **self.dataloader_config)
def step(self):
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
def train_epoch(self, **kwargs):
"""Runs a training epoch and updates the model parameters.
Automatically sets epoch of sampler if possible.
"""
logger.debug("Starting step")
if hasattr(self.train_loader.sampler, "set_epoch"):
self.train_loader.sampler.set_epoch(self.epoch)
return super(DistributedPyTorchRunner, self).step()
self.train_loader.sampler.set_epoch(self.epochs)
return super(DistributedPyTorchRunner, self).train_epoch(**kwargs)
def _get_model_state_dicts(self):
"""Fetch state from ``model.module`` instead of ``model``.
@@ -10,10 +10,9 @@ import torchvision.transforms as transforms
import ray
from ray.util.sgd.pytorch import (PyTorchTrainer, PyTorchTrainable)
from ray.util.sgd.pytorch.resnet import ResNet18
from ray.util.sgd.pytorch.utils import TEST_MODE
def initialization_hook(runner):
def initialization_hook():
print("NCCL DEBUG SET")
# Need this for avoiding a connection restart issue
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
@@ -40,6 +39,11 @@ def cifar_creator(config):
validation_dataset = torchvision.datasets.CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
if config.get("test_mode"):
train_dataset = torch.utils.data.Subset(train_dataset, list(range(64)))
validation_dataset = torch.utils.data.Subset(validation_dataset,
list(range(64)))
return train_dataset, validation_dataset
@@ -58,7 +62,6 @@ def train_example(num_replicas=1,
use_gpu=False,
use_fp16=False,
test_mode=False):
config = {TEST_MODE: test_mode}
trainer1 = PyTorchTrainer(
ResNet18,
cifar_creator,
@@ -67,7 +70,10 @@ def train_example(num_replicas=1,
scheduler_creator=scheduler_creator,
initialization_hook=initialization_hook,
num_replicas=num_replicas,
config=config,
config={
"lr": 0.01,
"test_mode": test_mode
},
use_gpu=use_gpu,
batch_size=16 if test_mode else 512,
backend="nccl" if use_gpu else "gloo",
@@ -88,14 +94,14 @@ def tune_example(num_replicas=1, use_gpu=False, test_mode=False):
"model_creator": ResNet18,
"data_creator": cifar_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": lambda config: nn.CrossEntropyLoss(),
"loss_creator": nn.CrossEntropyLoss,
"num_replicas": num_replicas,
"initialization_hook": initialization_hook,
"use_gpu": use_gpu,
"batch_size": 16 if test_mode else 512,
"config": {
"lr": tune.choice([1e-4, 1e-3, 5e-3, 1e-2]),
TEST_MODE: test_mode
"lr": tune.choice([1e-4, 1e-3]),
"test_mode": test_mode
},
"backend": "nccl" if use_gpu else "gloo"
}
+148 -141
View File
@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
@@ -16,25 +16,12 @@ from scipy.stats import entropy
import ray
from ray.util.sgd import PyTorchTrainer
from ray.util.sgd.pytorch.utils import TEST_MODE
# Training parameters
TRAIN_BATCHES = 5
# Number of channels in the training images. For color images this is 3
num_channels = 1
# Size of z latent vector (i.e. size of generator input)
latent_vector_size = 100
# Size of feature maps in generator
features_g = 32
# Size of feature maps in discriminator
features_d = 32
from ray.util.sgd.utils import override
from ray.util.sgd.pytorch import TrainingOperator
def data_creator(config):
return dset.MNIST(
dataset = datasets.MNIST(
root="~/mnist/",
download=True,
transform=transforms.Compose([
@@ -42,62 +29,56 @@ def data_creator(config):
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, )),
]))
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
if config.get("test_mode"):
dataset = torch.utils.data.Subset(dataset, list(range(64)))
return dataset
class Generator(nn.Module):
def __init__(self):
def __init__(self, latent_vector_size, features=32, num_channels=1):
super(Generator, self).__init__()
self.latent_vector_size = latent_vector_size
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d(
latent_vector_size, features_g * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(features_g * 4),
latent_vector_size, features * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(features * 4),
nn.ReLU(True),
nn.ConvTranspose2d(
features_g * 4, features_g * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g * 2),
features * 4, features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 2),
nn.ReLU(True),
nn.ConvTranspose2d(
features_g * 2, features_g, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g),
nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
nn.BatchNorm2d(features),
nn.ReLU(True),
nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False),
nn.ConvTranspose2d(features, num_channels, 4, 2, 1, bias=False),
nn.Tanh())
def forward(self, input):
return self.main(input)
def forward(self, x):
return self.main(x)
class Discriminator(nn.Module):
def __init__(self):
def __init__(self, features=32, num_channels=1):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(num_channels, features_d, 4, 2, 1, bias=False),
nn.Conv2d(num_channels, features, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 4), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 4), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
def forward(self, input):
return self.main(input)
def forward(self, x):
return self.main(x)
class Net(nn.Module):
class LeNet(nn.Module):
"""LeNet for MNist classification, used for inception_score."""
def __init__(self):
super(Net, self).__init__()
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
@@ -114,92 +95,22 @@ class Net(nn.Module):
return F.log_softmax(x, dim=1)
def inception_score(imgs, batch_size=32, splits=1):
N = len(imgs)
dtype = torch.FloatTensor
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
cm = Net()
cm.load_state_dict(torch.load(model_path))
cm.eval()
up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
def get_pred(x):
x = up(x)
x = cm(x)
return F.softmax(x).data.cpu().numpy()
preds = np.zeros((N, 10))
for i, batch in enumerate(dataloader, 0):
batch = batch.type(dtype)
batchv = Variable(batch)
batch_size_i = batch.size()[0]
preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
# Now compute the mean kl-div
split_scores = []
for k in range(splits):
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
def model_creator(config):
netD = Discriminator()
netD.apply(weights_init)
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
netG = Generator()
netG.apply(weights_init)
return netD, netG
discriminator = Discriminator()
discriminator.apply(weights_init)
def train(config, models, dataloader, criterion, optimizers, **kwargs):
netD, netG = models
optimD, optimG = optimizers
real_label = 1
fake_label = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for i, data in enumerate(dataloader, 0):
if i >= TRAIN_BATCHES and config.get(TEST_MODE):
break
netD.zero_grad()
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size, ), real_label, device=device)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
errD = errD_real + errD_fake
optimD.step()
netG.zero_grad()
label.fill_(real_label)
output = netD(fake).view(-1)
errG = criterion(output, label)
errG.backward()
optimG.step()
is_score, is_std = inception_score(fake)
return {
"loss_g": errG.item(),
"loss_d": errD.item(),
"inception": is_score
}
generator = Generator(
latent_vector_size=config.get("latent_vector_size", 100))
generator.apply(weights_init)
return discriminator, generator
def optimizer_creator(models, config):
@@ -211,22 +122,122 @@ def optimizer_creator(models, config):
return discriminator_opt, generator_opt
class GANOperator(TrainingOperator):
    """Training operator that updates a discriminator and a generator.

    ``self.models`` and ``self.optimizers`` are each unpacked as a
    (discriminator, generator) pair. A pretrained ``LeNet`` classifier is
    loaded in ``setup`` and used to report an inception score per batch.
    """

    def setup(self, config):
        """Select the training device and load the scoring classifier.

        Args:
            config (dict): Must contain ``classification_model_path``, the
                path of the saved ``LeNet`` state dict.
        """
        self.device = torch.device("cuda"
                                   if torch.cuda.is_available() else "cpu")
        self.classifier = LeNet()
        # The classifier is never moved off the CPU (inception_score casts
        # its inputs to CPU float tensors), so map the checkpoint to CPU.
        # This also lets a checkpoint saved from a GPU load on CPU-only hosts.
        self.classifier.load_state_dict(
            torch.load(
                config["classification_model_path"], map_location="cpu"))
        self.classifier.eval()

    def inception_score(self, imgs, batch_size=32, splits=1):
        """Calculate the inception score of the generated images.

        Args:
            imgs: Sized, indexable collection of images (e.g. a tensor of
                generated samples) that ``DataLoader`` can batch.
            batch_size (int): Batch size used when running the classifier.
            splits (int): Number of equally sized chunks the predictions
                are divided into; one score is computed per chunk.

        Returns:
            Tuple of (mean, standard deviation) of the per-split scores.
        """
        N = len(imgs)
        dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
        up = nn.Upsample(
            size=(28, 28), mode="bilinear").type(torch.FloatTensor)

        def get_pred(x):
            x = up(x)
            x = self.classifier(x)
            # The classifier output is (batch, classes); normalize over the
            # class dimension explicitly rather than relying on the
            # deprecated implicit ``dim``.
            return F.softmax(x, dim=1).detach().cpu().numpy()

        # Obtain predictions for the fake provided images. This is only a
        # metric, so no autograd graph is needed.
        preds = np.zeros((N, 10))
        with torch.no_grad():
            for i, batch in enumerate(dataloader, 0):
                batch = batch.type(torch.FloatTensor)
                batch_size_i = batch.size()[0]
                preds[i * batch_size:i * batch_size +
                      batch_size_i] = get_pred(batch)

        # Now compute the mean kl-div
        split_scores = []
        for k in range(splits):
            part = preds[k * (N // splits):(k + 1) * (N // splits), :]
            py = np.mean(part, axis=0)
            scores = [entropy(part[i, :], py) for i in range(part.shape[0])]
            split_scores.append(np.exp(np.mean(scores)))

        return np.mean(split_scores), np.std(split_scores)

    @override(TrainingOperator)
    def train_batch(self, batch, batch_info):
        """Trains on one batch of data from the data creator.

        Args:
            batch: Batch from the data loader; ``batch[0]`` is used as the
                real images.
            batch_info (dict): Per-batch information (unused here).

        Returns:
            dict with ``loss_g``, ``loss_d``, ``inception`` and
            ``num_samples``.
        """
        real_label = 1
        fake_label = 0
        discriminator, generator = self.models
        optimD, optimG = self.optimizers

        # Compute a discriminator update for real images
        discriminator.zero_grad()
        real_cpu = batch[0].to(self.device)
        batch_size = real_cpu.size(0)
        # Request a float tensor explicitly: with an integer fill value
        # newer PyTorch versions infer an integer dtype, which the
        # criterion cannot compare against the float outputs.
        label = torch.full(
            (batch_size, ),
            real_label,
            dtype=torch.float,
            device=self.device)
        output = discriminator(real_cpu).view(-1)
        errD_real = self.criterion(output, label)
        errD_real.backward()

        # Compute a discriminator update for fake images
        noise = torch.randn(
            batch_size,
            self.config.get("latent_vector_size", 100),
            1,
            1,
            device=self.device)
        fake = generator(noise)
        label.fill_(fake_label)
        # Detach so this pass does not backpropagate into the generator.
        output = discriminator(fake.detach()).view(-1)
        errD_fake = self.criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake

        # Update the discriminator
        optimD.step()

        # Update the generator
        generator.zero_grad()
        label.fill_(real_label)
        output = discriminator(fake).view(-1)
        errG = self.criterion(output, label)
        errG.backward()
        optimG.step()

        is_score, is_std = self.inception_score(fake)

        return {
            "loss_g": errG.item(),
            "loss_d": errD.item(),
            "inception": is_score,
            "num_samples": batch_size
        }
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
config = {TEST_MODE: test_mode}
config = {
"test_mode": test_mode,
"classification_model_path": os.path.join(
os.path.dirname(ray.__file__),
"util/sgd/pytorch/examples/mnist_cnn.pt")
}
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
nn.BCELoss,
train_function=train,
validation_function=False,
training_operator_cls=GANOperator,
num_replicas=num_replicas,
config=config,
use_gpu=use_gpu,
batch_size=16 if test_mode else 512,
backend="nccl" if use_gpu else "gloo")
for i in range(10):
stats = trainer.train(max_retries=3)
for i in range(5):
stats = trainer.train()
print(stats)
return trainer
@@ -240,7 +251,7 @@ if __name__ == "__main__":
"--address",
required=False,
type=str,
help="the address to use for Redis")
help="the address to use to connect to a cluster.")
parser.add_argument(
"--num-replicas",
"-n",
@@ -255,10 +266,6 @@ if __name__ == "__main__":
args, _ = parser.parse_known_args()
ray.init(address=args.address)
path = os.path.dirname(ray.__file__)
model_path = os.path.join(path, "util/sgd/pytorch/examples/mnist_cnn.pt")
# load the pretrained mnist classification model for inception_score
trainer = train_example(
num_replicas=args.num_replicas,
use_gpu=args.use_gpu,
+44 -37
View File
@@ -2,13 +2,15 @@ import collections
from filelock import FileLock
import logging
import inspect
import itertools
import os
import torch
import torch.utils.data
from torch.utils.data import Dataset
import ray
from ray.util.sgd.pytorch import utils as pytorch_utils
from ray.util.sgd.pytorch.constants import USE_FP16, SCHEDULER_STEP
from ray.util.sgd.pytorch.training_operator import TrainingOperator
from ray.util.sgd import utils
logger = logging.getLogger(__name__)
@@ -31,8 +33,7 @@ class PyTorchRunner:
loss_creator (dict -> loss | Loss class): see pytorch_trainer.py.
scheduler_creator (optimizers, dict -> schedulers): see
pytorch_trainer.py.
train_function: see pytorch_trainer.py
validation_function: see pytorch_trainer.py
training_operator_cls: see pytorch_trainer.py
config (dict): see pytorch_trainer.py.
dataloader_config (dict): See pytorch_trainer.py.
batch_size (int): see pytorch_trainer.py.
@@ -47,8 +48,7 @@ class PyTorchRunner:
optimizer_creator,
loss_creator,
scheduler_creator=None,
train_function=None,
validation_function=None,
training_operator_cls=None,
config=None,
dataloader_config=None,
batch_size=16,
@@ -60,17 +60,15 @@ class PyTorchRunner:
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls or TrainingOperator
self.config = {} if config is None else config
self.dataloader_config = {
"num_workers": 2
} if dataloader_config is None else dataloader_config
self.train_function = train_function or pytorch_utils.train
self.validation_function = (validation_function
or pytorch_utils.validate)
self.batch_size = batch_size
self.verbose = True
self.epoch = 0
self.epochs = 0
self._timers = {
k: utils.TimerStat(window_size=1)
for k in [
@@ -160,6 +158,14 @@ class PyTorchRunner:
self.validation_loader = torch.utils.data.DataLoader(
val_set, batch_size=self.batch_size, **self.dataloader_config)
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
def get_node_ip(self):
"""Returns the IP address of the current node."""
return ray.services.get_node_ip_address()
@@ -168,47 +174,42 @@ class PyTorchRunner:
"""Finds a free port on the current node."""
return utils.find_free_port()
def step(self):
def train_epoch(self, num_steps=None, info=None):
"""Runs a training epoch and updates the model parameters."""
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
training_config = self.config.copy()
training_config.update({
pytorch_utils.USE_FP16: self.use_fp16,
pytorch_utils.SCHEDULER_STEP: self.scheduler_step_freq
logger.debug("Begin Training Step {}".format(self.epochs + 1))
info = info or {}
info.update({
USE_FP16: self.use_fp16,
SCHEDULER_STEP: self.scheduler_step_freq
})
with self._timers["training"]:
train_stats = self.train_function(
training_config,
self.given_models,
self.train_loader,
self.criterion,
self.given_optimizers,
scheduler=self.given_schedulers)
train_stats["epoch"] = self.epoch
self.epoch += 1
iterator = self.train_loader
if num_steps:
iterator = itertools.islice(iter(self.train_loader), num_steps)
train_stats = self.training_operator.train_epoch(iterator, info)
self.epochs += 1
train_stats.update(self.stats())
return train_stats
def validate(self):
def validate(self, num_steps=None, info=None):
"""Evaluates the model on the validation data set."""
if self.validation_loader is None:
raise ValueError("No validation dataloader provided.")
info = info or {}
with self._timers["validation"]:
validation_stats = self.validation_function(
self.config,
self.given_models,
self.validation_loader,
self.criterion,
scheduler=self.given_schedulers)
iterator = self.validation_loader
if num_steps:
iterator = itertools.islice(
iter(self.validation_loader), num_steps)
validation_stats = self.training_operator.validate(iterator, info)
validation_stats.update(self.stats())
return validation_stats
def stats(self):
"""Returns a dictionary of statistics collected."""
stats = {"epoch": self.epoch}
stats = {"epoch": self.epochs}
for k, t in self._timers.items():
stats[k + "_time_mean"] = t.mean
stats[k + "_time_total"] = t.sum
@@ -233,7 +234,8 @@ class PyTorchRunner:
"""Returns the state of the runner."""
state = {
"epoch": self.epoch,
"epoch": self.epochs,
"operator": self.training_operator.state_dict(),
"models": self._get_model_state_dicts(),
"optimizers": [opt.state_dict() for opt in self.optimizers],
"stats": self.stats()
@@ -262,13 +264,18 @@ class PyTorchRunner:
if self.use_fp16 and "amp" in state and amp:
amp.load_state_dict(state["amp"])
self.epoch = state["stats"]["epoch"]
self.epochs = state["stats"]["epoch"]
self.training_operator.load_state_dict(state_dict)
def apply_fn(self, fn):
return fn(self)
def apply(self, fn):
return fn()
def apply_operator(self, fn):
return fn(self.training_operator)
def shutdown(self):
"""Attempts to shut down the worker."""
del self.training_operator
del self.validation_loader
del self.train_loader
del self.criterion
+99 -46
View File
@@ -15,12 +15,20 @@ from ray.util.sgd.pytorch.distributed_pytorch_runner import (
DistributedPyTorchRunner)
from ray.util.sgd import utils
from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner
from ray.util.sgd.pytorch import utils as pytorch_utils
from ray.util.sgd.pytorch.constants import VALID_SCHEDULER_STEP
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
def _validate_scheduler_step_freq(scheduler_step_freq):
    """Check that a scheduler step frequency is a supported value.

    A falsy value (such as ``None``) is treated as "not set" and accepted.

    Raises:
        ValueError: If a truthy value is not in ``VALID_SCHEDULER_STEP``.
    """
    if not scheduler_step_freq:
        return
    if scheduler_step_freq in VALID_SCHEDULER_STEP:
        return
    message = "Scheduler step freq must be in {}. Got {}".format(
        VALID_SCHEDULER_STEP, scheduler_step_freq)
    raise ValueError(message)
class PyTorchTrainer:
"""Train a PyTorch model using distributed PyTorch.
@@ -48,14 +56,15 @@ class PyTorchTrainer:
loss_creator=nn.MSELoss,
use_gpu=True
)
trainer.train()
for i in range(4):
trainer.train()
Args:
model_creator (dict -> Model(s)): Constructor function that takes in
config and returns the model(s) to be optimized. These must be
``torch.nn.Module`` objects. If multiple models are returned,
a ``train_function`` must be specified. You do not need to
a ``training_operator_cls`` must be specified. You do not need to
handle GPU/devices in this function; RaySGD will do that under
the hood.
data_creator (dict -> Dataset(s)): Constructor function
@@ -75,22 +84,18 @@ class PyTorchTrainer:
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
scheduler_creator (optimizers, dict -> loss):
A constructor function for the scheduler loss. This is
A constructor function for the torch scheduler. This is
a function that takes in the generated optimizers (from
``optimizer_creator``) provided config for customization.
Be sure to set ``scheduler_step_freq`` to increment the
scheduler correctly.
train_function: Custom function for training. This function
will be executed in parallel across all workers at once. The
function needs to take in (models, train_dataloader, criterion,
optimizers, config), and return a dict of training stats.
validation_function: Custom function for validation. This function
will be executed in parallel across all workers at once.
This takes in (model, val_dataloader, criterion, config)
and returns a dict of validation stats.
training_operator_cls (type): Custom training operator class
that subclasses the TrainingOperator class. This class
will be copied onto all remote workers and used to specify
custom training and validation operations. Defaults to
TrainingOperator.
config (dict): Custom configuration value to be passed to
"model_creator", "data_creator", "optimizer_creator", and
"loss_creator".
all creator and operator constructors.
dataloader_config (dict): Configuration values to be passed into
the ``torch.utils.data.DataLoader`` object that wraps
the dataset on each parallel worker for both training
@@ -130,8 +135,7 @@ class PyTorchTrainer:
optimizer_creator,
loss_creator,
scheduler_creator=None,
train_function=None,
validation_function=None,
training_operator_cls=None,
initialization_hook=None,
config=None,
dataloader_config=None,
@@ -151,11 +155,10 @@ class PyTorchTrainer:
self.model_creator = model_creator
self.data_creator = data_creator
self.train_function = train_function
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.scheduler_creator = scheduler_creator
self.validation_function = validation_function
self.training_operator_cls = training_operator_cls
self.initialization_hook = initialization_hook
self.config = {} if config is None else config
self.dataloader_config = dataloader_config
@@ -166,6 +169,8 @@ class PyTorchTrainer:
logger.info("Using {} as backend.".format(backend))
self.backend = backend
# TODO: Have an auto "use_gpu" option to detect and use GPUs.
self.use_gpu = use_gpu
self.batch_size = batch_size
self.max_replicas = num_replicas
@@ -180,12 +185,7 @@ class PyTorchTrainer:
self._num_failures = 0
self._last_resize = float("-inf")
if scheduler_step_freq and (
scheduler_step_freq not in pytorch_utils.VALID_SCHEDULER_STEP):
raise ValueError(
"Scheduler step freq must be in {}. Got {}".format(
pytorch_utils.VALID_SCHEDULER_STEP, scheduler_step_freq))
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
self._start_workers(self.max_replicas)
@@ -204,8 +204,7 @@ class PyTorchTrainer:
self.optimizer_creator,
self.loss_creator,
self.scheduler_creator,
train_function=self.train_function,
validation_function=self.validation_function,
training_operator_cls=self.training_operator_cls,
config=self.config,
dataloader_config=self.dataloader_config,
batch_size=self.batch_size,
@@ -243,8 +242,7 @@ class PyTorchTrainer:
self.loss_creator,
self.scheduler_creator,
backend=self.backend,
train_function=self.train_function,
validation_function=self.validation_function,
training_operator_cls=self.training_operator_cls,
config=self.config,
dataloader_config=self.dataloader_config,
batch_size=batch_size_per_replica,
@@ -266,21 +264,35 @@ class PyTorchTrainer:
for i, worker in enumerate(self.workers)
])
def train(self, max_retries=0, checkpoint="auto"):
def train(self,
num_steps=None,
max_retries=0,
checkpoint="auto",
info=None):
"""Runs a training epoch.
Runs an average over all values returned from workers. Set
`max_retries` to enable fault handling in case of instance preemption.
Args:
num_steps (int): Number of batches to compute update steps on.
This corresponds also to the number of times
``TrainingOperator.train_batch`` is called.
max_retries (int): Must be non-negative. If set to N, will
kill all current workers, query the Ray global state for
total available resources, and re-launch up to the
available resources. Behavior is not well-defined
in case of shared cluster usage.
checkpoint (str): Path to checkpoint to restore from if retrying.
If max_retries is set and checkpoint == "auto", PyTorchTrainer
will save a checkpoint before starting to train.
If max_retries is set and ``checkpoint == "auto"``,
PyTorchTrainer will save a checkpoint before starting to train.
info (dict): Optional dictionary passed to the training
operator for ``train_epoch`` and ``train_batch``.
Returns:
A dictionary of metrics for training.
You can provide custom metrics by passing in a custom
``training_operator_cls``.
"""
assert max_retries >= 0, "`max_retries` must be non-negative."
if max_retries:
@@ -296,7 +308,8 @@ class PyTorchTrainer:
self._resize_workers(checkpoint=checkpoint)
with self.optimizer_timer:
success, worker_stats = self._train_step()
success, worker_stats = self._train_epoch(
num_steps=num_steps, info=info)
# Fault handling
for i in range(max_retries):
if success:
@@ -306,7 +319,8 @@ class PyTorchTrainer:
self._resize_workers(checkpoint=checkpoint)
logger.info("Retrying training step with %d workers." % len(
self.workers))
success, worker_stats = self._train_step()
success, worker_stats = self._train_epoch(
num_steps=num_steps, info=info)
if not success:
raise RuntimeError("Training run failed.")
@@ -321,19 +335,58 @@ class PyTorchTrainer:
train_stats[stat_key] = worker_stats[0][stat_key]
return train_stats
def _train_step(self):
worker_stats = [w.step.remote() for w in self.workers]
def _train_epoch(self, num_steps=None, info=None):
    """Start one training epoch on every worker and check for failures.

    Args:
        num_steps (int): Forwarded to each worker's ``train_epoch``.
        info (dict): Forwarded to each worker's ``train_epoch``.

    Returns:
        Tuple of (success, worker_stats), where ``worker_stats`` is the
        list of results of the remote ``train_epoch`` calls and
        ``success`` is the outcome of ``utils.check_for_failure`` on them.
    """
    worker_stats = []
    for worker in self.workers:
        pending = worker.train_epoch.remote(num_steps=num_steps, info=info)
        worker_stats.append(pending)
    return utils.check_for_failure(worker_stats), worker_stats
def apply_all_workers(self, fn):
return ray.get([w.apply_fn.remote(fn) for w in self.workers])
"""Run a function on all workers.
def validate(self):
"""Evaluates the model on the validation data set."""
if self.validation_function is False:
return {}
worker_stats = ray.get([w.validate.remote() for w in self.workers])
Args:
fn (Callable): A function that takes in no arguments.
Returns:
A list of objects returned by ``fn`` on each worker.
"""
return ray.get([w.apply.remote(fn) for w in self.workers])
def apply_all_operators(self, fn):
    """Apply ``fn`` to the training operator held by each worker.

    Args:
        fn (Callable[TrainingOperator]): A function that takes in a
            TrainingOperator.

    Returns:
        A list with the value returned by ``fn`` for each operator.
    """
    pending = [worker.apply_operator.remote(fn) for worker in self.workers]
    results = ray.get(pending)
    return results
def validate(self, num_steps=None, info=None):
"""Evaluates the model on the validation data set.
Args:
num_steps (int): Number of batches to compute update steps on.
This corresponds also to the number of times
``TrainingOperator.validate_batch`` is called.
info (dict): Optional dictionary passed to the training
operator for `validate` and `validate_batch`.
Returns:
A dictionary of metrics for validation.
You can provide custom metrics by passing in a custom
``training_operator_cls``.
"""
worker_stats = ray.get([
w.validate.remote(num_steps=num_steps, info=info)
for w in self.workers
])
validation_stats = {}
for stat_key in worker_stats[0]:
@@ -346,8 +399,8 @@ class PyTorchTrainer:
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
"""
self.apply_all_workers(
lambda runner: [sched.step(metric) for sched in runner.schedulers])
self.apply_all_operators(
lambda op: [sched.step(metric) for sched in op.schedulers])
def get_model(self):
"""Returns the learned model(s)."""
@@ -366,17 +419,18 @@ class PyTorchTrainer:
Args:
checkpoint (str): Path to target checkpoint file.
Returns:
checkpoint (str): Path to target checkpoint file.
"""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, checkpoint)
return checkpoint
def restore(self, checkpoint):
"""Restores the model from the provided checkpoint.
"""Restores the Trainer and all workers from the provided checkpoint.
Args:
checkpoint (str): Path to target checkpoint file.
"""
state = torch.load(checkpoint)
state_id = ray.put(state)
@@ -450,7 +504,6 @@ class PyTorchTrainable(Trainable):
validation_stats = self._trainer.validate()
train_stats.update(validation_stats)
# output {"mean_loss": test_loss, "mean_accuracy": accuracy}
return train_stats
@@ -0,0 +1,343 @@
import collections
import torch
from ray.util.sgd.utils import TimerStat, AverageMeter
from ray.util.sgd.pytorch.constants import (
SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_BATCH, SCHEDULER_STEP, BATCH_COUNT)
amp = None
try:
from apex import amp
except ImportError:
# Apex library is not installed, so we cannot enable mixed precision.
# We don't log here because logging happens in the pytorch_runner,
# where amp is initialized.
pass
def _is_multiple(component):
"""Checks if a component (optimizer, model, etc) is not singular."""
return isinstance(component, collections.Iterable) and len(component) > 1
class TrainingOperator:
"""Abstract class for custom training or validation loops.
The scheduler will only be called at a batch or epoch frequency, depending
on the user parameter. Be sure to set ``scheduler_step_freq`` in
``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler
correctly during training. If using a learning rate scheduler
that depends on validation loss, you can use ``trainer.update_scheduler``.
For both training and validation, there are two granularities that
you can provide customization: per epoch or per batch.
You do not need to override both.
.. image:: raysgd-custom.jpg
:scale: 80%
:align: center
Raises:
ValueError if multiple models/optimizers/schedulers are provided.
You are expected to subclass this class if you wish
to train over multiple models/optimizers/schedulers.
"""
def __init__(self,
             config,
             models,
             optimizers,
             criterion,
             schedulers=None,
             use_fp16=False):
    """Store the training components and run the ``setup`` hook.

    You are not expected to override this method.

    Args:
        config (dict): Custom configuration, also passed to ``setup``.
        models: Iterable of models.
        optimizers: Iterable of optimizers.
        criterion: Loss object.
        schedulers: Optional iterable of schedulers.
        use_fp16 (bool): Stored as the mixed-precision flag.

    Raises:
        ValueError: If this exact class (not a subclass) is given more
            than one model, optimizer or scheduler.
    """
    # The ABC aliases in ``collections`` were removed in Python 3.10; the
    # classes live in ``collections.abc``.
    from collections.abc import Iterable

    self.timers = {
        k: TimerStat()
        for k in ["fwd", "grad", "apply", "epoch_time"]
    }
    self._validated_customization = False
    self._models = models  # List of models
    assert isinstance(models, Iterable), (
        "Components need to be iterable. Got: {}".format(type(models)))
    self._optimizers = optimizers  # List of optimizers
    assert isinstance(optimizers, Iterable), (
        "Components need to be iterable. Got: {}".format(type(optimizers)))
    self._criterion = criterion
    self._schedulers = schedulers
    if schedulers:
        assert isinstance(schedulers, Iterable), (
            "Components need to be iterable. Got: {}".format(
                type(schedulers)))
    self._config = config
    self._use_fp16 = use_fp16
    self.global_step = 0

    # Exact type check on purpose: only the base operator is restricted
    # to single components; subclasses may handle several.
    if type(self) is TrainingOperator:
        for component in (models, schedulers, optimizers):
            if _is_multiple(component):
                raise ValueError(
                    "Need to provide a custom operator subclassing "
                    "TrainingOperator if using multi-scheduler, "
                    "multi-model or multi-optimizer training/validation.")

    self.setup(config)
def setup(self, config):
"""Override this method to implement custom operator setup.
Args:
config (dict): Custom configuration value to be passed to
all creator and operator constructors. Same as ``self.config``.
"""
pass
def train_epoch(self, iterator, info):
"""Runs one standard training pass over the train_iterator.
By default, this method will iterate over the given iterator and
call ``self.train_batch`` over each batch.
If ``scheduler_step_freq`` is set, this class will also step the
scheduler accordingly.
You do not need to call ``train_batch`` in this method if you plan
to implement a custom optimization/training routine here.
Args:
iterator (iter): Iterator over the training data for the entire
epoch. This iterator is expected to be entirely consumed.
info (dict): Dictionary for information to be used for custom
training operations.
Returns:
A dict of metrics from training.
"""
self._losses = AverageMeter()
self.model.train()
with self.timers["epoch_time"]:
for batch_idx, batch in enumerate(iterator):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info)
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
self.scheduler.step()
if "loss" in metrics:
self._losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
self.global_step += 1
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
self.scheduler.step()
stats = {
BATCH_COUNT: batch_idx + 1,
"mean_train_loss": self._losses.avg,
"last_train_loss": self._losses.val,
"epoch_time": self.timers["epoch_time"].last
}
stats.update({
timer_tag: timer.mean
for timer_tag, timer in self.timers.items()
})
return stats
def train_batch(self, batch, batch_info):
"""Computes loss and updates the model over one batch.
This method is responsible for computing the loss and gradient and
updating the model.
By default, this method implementation assumes that batches
are in (features, labels) format. If using amp/fp16
training, it will also scale the loss automatically.
You can provide custom loss metrics and training operations if you
override this method. If overriding this method, you can access model,
optimizer, criterion via ``self.model``, ``self.optimizer``,
and ``self.criterion``.
You do not need to override this method if you plan to
override ``train_epoch``.
Args:
batch: One item of the validation iterator.
batch_info (dict): Information dict passed in from ``train_epoch``.
Returns:
A dictionary of metrics.
By default, this dictionary contains "loss" and "num_samples".
"num_samples" corresponds to number of datapoints in the batch.
However, you can provide any number of other values.
"""
features, target = batch
# Create non_blocking tensors for distributed training
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# Compute output.
with self.timers["fwd"]:
output = self.model(features)
loss = self.criterion(output, target)
# Compute gradients in a backward pass.
with self.timers["grad"]:
self.optimizer.zero_grad()
if self.use_fp16:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# Call step of optimizer to update model params.
with self.timers["apply"]:
self.optimizer.step()
return {"loss": loss.item(), "num_samples": features.size(0)}
def validate(self, val_iterator, info):
"""Runs one standard validation pass over the val_iterator.
This will call ``model.eval()`` and ``torch.no_grad`` when iterating
over the validation dataset.
If overriding this method, you can access model, criterion via
``self.model`` and ``self.criterion``. You also do not need to call
``validate_batch`` if overriding this method.
Args:
val_iterator (iter): Iterable constructed over the
validation dataset.
info: (dict): Dictionary for information to be used for custom
validation operations.
Returns:
A dict of metrics from the evaluation.
By default, returns "mean_accuracy" and "mean_validation_loss"
which is computed by aggregating "loss" and "correct" values
from ``validate_batch`` and dividing it by the sum of
``num_samples`` from all calls to ``self.validate_batch``.
"""
losses = AverageMeter()
total_correct = 0
# switch to evaluate mode
self.model.eval()
with torch.no_grad():
for batch_idx, batch in enumerate(val_iterator):
batch_info = {"batch_idx": batch_idx}
batch_info.update(info)
metrics = self.validate_batch(batch, batch_info)
if "loss" in metrics:
losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
if "num_correct" in metrics:
total_correct += metrics["num_correct"]
stats = {
"batch_count": batch_idx + 1,
"mean_validation_loss": losses.avg,
"mean_accuracy": total_correct / losses.count
}
return stats
def validate_batch(self, batch, batch_info):
"""Calcuates the loss and accuracy over a given batch.
You can override this method to provide arbitrary metrics.
Args:
batch: One item of the validation iterator.
batch_info (dict): Contains information per batch from
``validate()``.
Returns:
A dict of metrics.
By default, returns "loss", "num_correct", and "num_samples".
"""
features, target = batch
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = self.model(features)
loss = self.criterion(output, target)
_, predicted = torch.max(output.data, 1)
return {
"loss": loss.item(),
"num_correct": (predicted == target).sum().item(),
"num_samples": target.size(0)
}
def state_dict(self):
"""Returns a serializable representation of the operator state."""
pass
def load_state_dict(self, state_dict):
"""Loads a serializable representation of the operator state."""
pass
@property
def config(self):
"""Dictionary as provided into PyTorchTrainer."""
return self._config
@property
def model(self):
"""First or only model created by the provided ``model_creator``."""
return self._models[0]
@property
def models(self):
"""List of models created by the provided ``model_creator``."""
return self._models
@property
def optimizer(self):
"""First or only optimizer(s) created by the ``optimizer_creator``."""
return self._optimizers[0]
@property
def optimizers(self):
"""List of optimizers created by the ``optimizer_creator``."""
return self._optimizers
@property
def criterion(self):
"""Criterion created by the provided ``loss_creator``."""
return self._criterion
@property
def scheduler(self):
"""First or only scheduler(s) created by the ``scheduler_creator``."""
if self._schedulers:
return self._schedulers[0]
@property
def schedulers(self):
"""List of schedulers created by the ``scheduler_creator``."""
return self._schedulers
@property
def use_fp16(self):
"""Whether the model and optimizer have been FP16 enabled."""
return self._use_fp16
class _TestingOperator(TrainingOperator):
    """Operator whose ``train_epoch`` can be supplied through the config.

    If ``config["custom_func"]`` is callable, it is invoked as
    ``custom_func(operator, iterator, info)`` and its result is returned.
    Otherwise a fixed ``{"done": 1}`` result is returned.
    """

    def train_epoch(self, iterator, info):
        custom_epoch = self.config.get("custom_func")
        if not callable(custom_epoch):
            return {"done": 1}
        return custom_epoch(self, iterator, info)
-229
View File
@@ -1,229 +0,0 @@
import collections
import time
import torch
from ray.util.sgd.utils import TimerStat
amp = None
try:
from apex import amp
except ImportError:
# Apex library is not installed, so we cannot enable mixed precision.
# We don't log here because logging happens in the pytorch_runner,
# where amp is initialized.
pass
USE_FP16 = "__use_fp16__"
TEST_MODE = "__test_mode__"
BATCH_COUNT = "batch_processed"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
def train(config, model, train_iterator, criterion, optimizer, scheduler=None):
    """Runs one standard training pass over the train_iterator.

    This function automatically measures timing for various operations such
    as host to device transfer, gradient calculation, and gradient application.

    It also automatically detects and places the data on the given GPU device
    if available.

    The scheduler will only be called at a batch or epoch frequency, depending
    on the user parameter. Be sure to set ``scheduler_step_freq`` in
    ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler
    correctly during training. If using a learning rate scheduler
    that depends on validation loss, you can use ``trainer.update_scheduler``.

    Raises:
        ValueError if multiple models/optimizers/schedulers are provided. You
            are expected to have a custom training function if you wish
            to use multiple models/optimizers/schedulers.

    Args:
        config: (dict): A user configuration provided into the Trainer
            constructor.
        model: The model as created by the model_creator.
        train_iterator: An iterator created from the DataLoader which
            wraps the provided Dataset.
        criterion: The loss object created by the loss_creator.
        optimizer: The torch.optim.Optimizer object as created by the
            optimizer_creator.
        scheduler (optional): The torch.optim.lr_scheduler object
            as created by the scheduler_creator. Be sure to set
            ``scheduler_step_freq`` in ``PyTorchTrainer``
            to increment the scheduler correctly.

    Returns:
        A dict of metrics from training. The batch count is 0 when the
        iterator yields nothing.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Iterable

    # Module containers such as ``nn.Sequential`` are iterable but are still
    # a single model, so they must not trigger the multi-model error.
    multi_model = isinstance(model, Iterable) and not isinstance(
        model, torch.nn.Module)
    if multi_model or isinstance(optimizer, Iterable) or isinstance(
            scheduler, Iterable):
        raise ValueError(
            "Need to provide custom training function if using multi-model "
            "or multi-scheduler or multi-optimizer training.")

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    timers = {k: TimerStat() for k in ["h2d", "fwd", "grad", "apply"]}

    # switch to train mode
    model.train()

    end = time.time()

    # Start at -1 so that an empty iterator reports zero batches instead of
    # raising UnboundLocalError when building the stats.
    batch_idx = -1
    for batch_idx, (features, target) in enumerate(train_iterator):
        # measure data loading time
        data_time.update(time.time() - end)

        # Create non_blocking tensors for distributed training
        with timers["h2d"]:
            if torch.cuda.is_available():
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

        # compute output
        with timers["fwd"]:
            output = model(features)
            loss = criterion(output, target)

            # measure accuracy and record loss
            losses.update(loss.item(), features.size(0))

        with timers["grad"]:
            # compute gradients in a backward pass
            optimizer.zero_grad()

            if config.get(USE_FP16):
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        with timers["apply"]:
            # Call step of optimizer to update model params
            optimizer.step()

        if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
            scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # In test mode only the first batch is processed.
        if config.get(TEST_MODE) and batch_idx == 0:
            break

    if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
        scheduler.step()

    stats = {
        "batch_time": batch_time.avg,
        BATCH_COUNT: batch_idx + 1,
        "train_loss": losses.avg,
        "data_time": data_time.avg,
    }
    stats.update({k: t.mean for k, t in timers.items()})
    return stats
def validate(config, model, val_iterator, criterion, scheduler=None):
    """Runs one standard validation pass over the val_iterator.

    This function automatically measures timing for various operations such
    as host to device transfer and processing time for the batch.

    It also automatically detects and places the data on the given GPU device
    if available.

    Raises:
        ValueError if multiple models/schedulers are provided. You
            are expected to have a custom validation function if you wish
            to use multiple models/schedulers.

    Args:
        config: (dict): A user configuration provided into the Trainer
            constructor.
        model: The model as created by the model_creator.
        val_iterator: An iterator created from the DataLoader which
            wraps the provided Dataset.
        criterion: The loss object created by the loss_creator.
        scheduler (optional): The torch.optim.lr_scheduler object
            as created by the scheduler_creator. By default,
            this is not used in this function.

    Returns:
        A dict of metrics from the evaluation. When the iterator yields
        nothing, the batch count, accuracy and mean loss are all 0.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Iterable

    # Module containers such as ``nn.Sequential`` are iterable but are still
    # a single model, so they must not trigger the multi-model error.
    multi_model = isinstance(model, Iterable) and not isinstance(
        model, torch.nn.Module)
    if multi_model or isinstance(scheduler, Iterable):
        raise ValueError(
            "Need to provide custom validation function if using multi-model "
            "or multi-scheduler training.")

    batch_time = AverageMeter()
    losses = AverageMeter()

    # switch to evaluate mode
    model.eval()
    correct = 0
    total = 0
    # Start at -1 so that an empty iterator reports zero batches.
    batch_idx = -1
    with torch.no_grad():
        end = time.time()
        for batch_idx, (features, target) in enumerate(val_iterator):
            if torch.cuda.is_available():
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(features)
            loss = criterion(output, target)
            # Index of the largest output along dimension 1 is the prediction.
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # measure accuracy and record loss
            losses.update(loss.item(), features.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # In test mode only the first batch is processed.
            if config.get(TEST_MODE) and batch_idx == 0:
                break

    stats = {
        BATCH_COUNT: batch_idx + 1,
        "batch_time": batch_time.avg,
        "validation_loss": losses.avg,
        # Guard the divisions: ``total`` stays 0 when no batch was seen.
        "mean_accuracy": correct / total if total else 0,
        "mean_loss": losses.sum / total if total else 0,
    }
    return stats
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes set by ``reset`` and ``update``:
        val: The most recent value passed to ``update``.
        sum: Running sum of ``val * n`` over all updates.
        count: Running sum of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clears all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Records ``val`` as observed ``n`` times.

        Args:
            val: The new value (typically a per-batch mean).
            n: Weight of the value (typically the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. only empty batches so far) would
        # otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
+104 -59
View File
@@ -10,28 +10,34 @@ import torch.distributed as dist
import ray
from ray import tune
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.util.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
from ray.util.sgd.pytorch.utils import (train, BATCH_COUNT, TEST_MODE,
SCHEDULER_STEP)
from ray.util.sgd.pytorch.training_operator import _TestingOperator
from ray.util.sgd.pytorch.constants import BATCH_COUNT, SCHEDULER_STEP
from ray.util.sgd.utils import check_for_failure
from ray.util.sgd.pytorch.examples.train_example import (
model_creator, optimizer_creator, data_creator, LinearDataset)
def test_test_mode(ray_start_2_cpus): # noqa: F811
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
def test_single_step(ray_start_2_cpus): # noqa: F811
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={TEST_MODE: True},
num_replicas=1)
metrics = trainer.train()
metrics = trainer.train(num_steps=1)
assert metrics[BATCH_COUNT] == 1
val_metrics = trainer.validate()
val_metrics = trainer.validate(num_steps=1)
assert val_metrics[BATCH_COUNT] == 1
@@ -45,29 +51,51 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
loss_creator=lambda config: nn.MSELoss(),
num_replicas=num_replicas)
for i in range(3):
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["validation_loss"]
train_loss1 = trainer.train()["mean_train_loss"]
validation_loss1 = trainer.validate()["mean_validation_loss"]
for i in range(3):
train_loss2 = trainer.train()["train_loss"]
validation_loss2 = trainer.validate()["validation_loss"]
train_loss2 = trainer.train()["mean_train_loss"]
validation_loss2 = trainer.validate()["mean_validation_loss"]
print(train_loss1, train_loss2)
print(validation_loss1, validation_loss2)
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
assert validation_loss2 <= validation_loss1, (validation_loss2,
validation_loss1)
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
def custom_train(config, models, dataloader, criterion, optimizers,
**kwargs):
def test_multi_model(ray_start_2_cpus, num_replicas):
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(dataloader):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return {
"accuracy": correct / total,
"train_loss": train_loss / (batch_idx + 1)
}
def train_epoch(self, iterator, info):
result = {}
for i, (model, optimizer) in enumerate(zip(models, optimizers)):
result["model_{}".format(i)] = train(config, model, dataloader,
criterion, optimizer)
for i, (model, optimizer) in enumerate(
zip(self.models, self.optimizers)):
result["model_{}".format(i)] = train(
model=model,
criterion=self.criterion,
optimizer=optimizer,
dataloader=iterator)
return result
def multi_model_creator(config):
@@ -84,7 +112,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
data_creator,
multi_optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
train_function=custom_train,
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
num_replicas=num_replicas)
trainer1.train()
@@ -100,6 +129,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
data_creator,
multi_optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
num_replicas=num_replicas)
trainer2.restore(filename)
@@ -123,16 +154,17 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
def custom_train(config, model, dataloader, criterion, optimizer,
scheduler):
if config.get("models", 1) > 1:
assert len(model) == config["models"], config
def train_epoch(self, iterator, info):
if self.config.get("models", 1) > 1:
assert len(self.models) == self.config["models"], self.config
if config.get("optimizers", 1) > 1:
assert len(optimizer) == config["optimizers"], config
if self.config.get("optimizers", 1) > 1:
assert len(
self.optimizers) == self.config["optimizers"], self.config
if config.get("schedulers", 1) > 1:
assert len(scheduler) == config["schedulers"], config
if self.config.get("schedulers", 1) > 1:
assert len(
self.schedulers) == self.config["schedulers"], self.config
return {"done": 1}
def multi_model_creator(config):
@@ -167,12 +199,13 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
multi_optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=multi_scheduler_creator,
train_function=custom_train,
training_operator_cls=_TestingOperator,
num_replicas=num_replicas,
config={
"models": model_count,
"optimizers": optimizer_count,
"schedulers": scheduler_count
"schedulers": scheduler_count,
"custom_func": train_epoch
})
trainer.train()
trainer.shutdown()
@@ -180,9 +213,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"])
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
def custom_train(config, model, dataloader, criterion, optimizer,
scheduler):
assert config[SCHEDULER_STEP] == scheduler_freq
def train_epoch(self, iterator, info):
assert info[SCHEDULER_STEP] == scheduler_freq
return {"done": 1}
def scheduler_creator(optimizer, config):
@@ -194,18 +226,17 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
data_creator,
optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=scheduler_creator)
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
scheduler_creator=scheduler_creator,
scheduler_step_freq=scheduler_freq)
for i in range(3):
trainer.train()["train_loss"]
trainer.train()
trainer.shutdown()
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
def custom_train(config, model, dataloader, criterion, optimizer,
scheduler):
return {"done": 1}
from torch.optim.lr_scheduler import ReduceLROnPlateau
trainer = PyTorchTrainer(
@@ -213,11 +244,13 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
data_creator,
optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer))
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
training_operator_cls=_TestingOperator)
trainer.update_scheduler(0.5)
trainer.update_scheduler(0.5)
assert all(
trainer.apply_all_workers(lambda r: r.schedulers[0].last_epoch == 2))
trainer.apply_all_operators(
lambda op: op.schedulers[0].last_epoch == 2))
trainer.shutdown()
@@ -248,13 +281,13 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
# checks loss decreasing for every trials
for path, df in analysis.trial_dataframes.items():
train_loss1 = df.loc[0, "train_loss"]
train_loss2 = df.loc[1, "train_loss"]
validation_loss1 = df.loc[0, "validation_loss"]
validation_loss2 = df.loc[1, "validation_loss"]
mean_train_loss1 = df.loc[0, "mean_train_loss"]
mean_train_loss2 = df.loc[1, "mean_train_loss"]
mean_validation_loss1 = df.loc[0, "mean_validation_loss"]
mean_validation_loss2 = df.loc[1, "mean_validation_loss"]
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
assert mean_train_loss2 <= mean_train_loss1
assert mean_validation_loss2 <= mean_validation_loss1
@pytest.mark.parametrize("num_replicas", [1, 2]
@@ -303,15 +336,17 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
def single_loader(config):
return LinearDataset(2, 5, size=1000000)
def step_with_fail(self):
worker_stats = [w.step.remote() for w in self.workers]
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
]
if self._num_failures < 3:
time.sleep(1) # Make the batch will fail correctly.
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
trainer1 = PyTorchTrainer(
model_creator,
single_loader,
@@ -331,15 +366,17 @@ def test_resize(ray_start_2_cpus): # noqa: F811
def single_loader(config):
return LinearDataset(2, 5, size=1000000)
def step_with_fail(self):
worker_stats = [w.step.remote() for w in self.workers]
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
]
if self._num_failures < 1:
time.sleep(1) # Make the batch will fail correctly.
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
trainer1 = PyTorchTrainer(
model_creator,
single_loader,
@@ -365,15 +402,17 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
def single_loader(config):
return LinearDataset(2, 5, size=1000000)
def step_with_fail(self):
worker_stats = [w.step.remote() for w in self.workers]
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
]
if self._num_failures < 2:
time.sleep(1)
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
trainer1 = PyTorchTrainer(
model_creator,
single_loader,
@@ -383,3 +422,9 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
num_replicas=2)
trainer1.train(max_retries=2)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", "-x", __file__]))
@@ -4,6 +4,7 @@ import torch.nn as nn
import unittest
from unittest.mock import MagicMock
from ray.util.sgd.pytorch.training_operator import TrainingOperator
from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner
@@ -46,39 +47,55 @@ def create_dataloaders(config):
class TestPyTorchRunner(unittest.TestCase):
def testValidate(self):
mock_function = MagicMock(returns=dict(mean_accuracy=10))
class MockOperator(TrainingOperator):
def setup(self, config):
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
runner = PyTorchRunner(
model_creator,
create_dataloaders,
optimizer_creator,
loss_creator,
validation_function=mock_function)
training_operator_cls=MockOperator)
runner.setup()
runner.step()
runner.step()
runner.step()
self.assertEqual(mock_function.call_count, 0)
runner.train_epoch()
runner.train_epoch()
runner.train_epoch()
self.assertEqual(runner.training_operator.validate.call_count, 0)
runner.validate()
self.assertTrue(mock_function.called)
self.assertTrue(runner.training_operator.validate.called)
self.assertEqual(runner.stats()["epoch"], 3)
def testStep(self):
mock_function = MagicMock(return_value=dict(mean_accuracy=10))
def testtrain_epoch(self):
class MockOperator(TrainingOperator):
def setup(self, config):
self.count = 0
def train_epoch(self, *args, **kwargs):
self.count += 1
return {"count": self.count}
runner = PyTorchRunner(
model_creator,
create_dataloaders,
optimizer_creator,
loss_creator,
train_function=mock_function)
training_operator_cls=MockOperator)
runner.setup()
runner.step()
runner.step()
result = runner.step()
self.assertEqual(mock_function.call_count, 3)
self.assertEqual(result["epoch"], 3)
runner.train_epoch(num_steps=1)
runner.train_epoch(num_steps=1)
result = runner.train_epoch()
self.assertEqual(runner.training_operator.count, 3)
self.assertEqual(result["count"], 3)
self.assertEqual(runner.stats()["epoch"], 3)
def testGivens(self):
class MockOperator(TrainingOperator):
def setup(self, config):
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
def three_model_creator(config):
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
@@ -88,8 +105,12 @@ class TestPyTorchRunner(unittest.TestCase):
]
return opts[0], opts[1], opts[2]
runner = PyTorchRunner(three_model_creator, single_loader,
three_optimizer_creator, loss_creator)
runner = PyTorchRunner(
three_model_creator,
single_loader,
three_optimizer_creator,
loss_creator,
training_operator_cls=MockOperator)
runner.setup()
self.assertEqual(len(runner.given_models), 3)
@@ -121,7 +142,7 @@ class TestPyTorchRunner(unittest.TestCase):
runner = PyTorchRunner(model_creator, single_loader, optimizer_creator,
loss_creator)
runner.setup()
runner.step()
runner.train_epoch()
with self.assertRaises(ValueError):
runner.validate()
@@ -132,7 +153,7 @@ class TestPyTorchRunner(unittest.TestCase):
optimizer_creator,
loss_creator=nn.MSELoss)
runner.setup()
runner.step()
runner.train_epoch()
def testMultiModel(self):
def multi_model_creator(config):
@@ -146,6 +167,6 @@ class TestPyTorchRunner(unittest.TestCase):
runner = PyTorchRunner(multi_model_creator, single_loader,
multi_optimizer_creator, loss_creator)
runner.setup()
with self.assertRaises(ValueError):
runner.step()
runner.setup()
+8
View File
@@ -149,3 +149,11 @@ def check_for_failure(remote_values):
except RayActorError as exc:
logger.exception(str(exc))
return False
def override(interface_class):
    """Returns a decorator that checks a method overrides ``interface_class``.

    The decorator asserts that the decorated method's name appears in
    ``dir(interface_class)`` and then returns the method unchanged.

    Args:
        interface_class: Class expected to declare the decorated method.

    Returns:
        A decorator that takes a method and returns that same method.
    """

    def overrider(method):
        known_names = dir(interface_class)
        assert (method.__name__ in known_names)
        return method

    return overrider