[raysgd] Custom training operator (#7211)

This commit is contained in:
Richard Liaw
2020-03-01 21:22:48 -08:00
committed by GitHub
parent 2d97650b1e
commit 48cdca843f
15 changed files with 1013 additions and 702 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

+182 -157
View File
@@ -10,6 +10,8 @@ Under the hood, ``PytorchTrainer`` will create *replicas* of your model (control
For end to end examples leveraging RaySGD PyTorchTrainer, jump to :ref:`raysgd-pytorch-examples`.
.. contents:: :local:
Setting up training
-------------------
@@ -81,6 +83,7 @@ for ``PyTorchTrainer(scheduler_creator=...)``.
:start-after: __torch_scheduler_start__
:end-before: __torch_scheduler_end__
.. _starting-pytorch-trainer:
Putting things together
@@ -115,30 +118,17 @@ You can also set the number of workers and whether the workers will use GPUs:
num_replicas=100,
use_gpu=True)
See the documentation on the PyTorchTrainer here: :ref:`ref-pytorch-trainer`. We'll look at the training APIs next.
Training APIs
-------------
Now that the trainer is constructed, you'll naturally want to train the model.
Now that the trainer is constructed, here's how to train the model.
.. code-block:: python
trainer.train()
This takes one pass over the training data.
To run the model on the validation data passed in by the ``data_creator``, you can simply call:
.. code-block:: python
trainer.validate()
You can customize the exact function that is called by using a customized training function (see :ref:`raysgd-custom-training`).
for i in range(10):
metrics = trainer.train()
val_metrics = trainer.validate()
Shutting down training
----------------------
Each ``train`` call makes one pass over the training data, and each ``validate`` call runs the model on the validation data passed in by the ``data_creator``. Provide a custom training operator (:ref:`raysgd-custom-training`) to calculate custom metrics or customize the training/validation process.
After training, you may want to reappropriate the Ray cluster. To release Ray resources obtained by the Trainer:
@@ -148,15 +138,131 @@ After training, you may want to reappropriate the Ray cluster. To release Ray re
.. note:: Be sure to call ``trainer.save()`` or ``trainer.get_model()`` before shutting down.
Initialization Functions
------------------------
See the documentation on the PyTorchTrainer here: :ref:`ref-pytorch-trainer`.
You may want to run some initializers on each worker when they are started. This may be something like setting an environment variable or downloading some data. You can do this via the ``initialization_hook`` parameter:
.. _raysgd-custom-training:
Custom Training and Validation (Operators)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``PyTorchTrainer`` allows you to run a custom training and validation loops in parallel on each worker, providing a flexible interface similar to using PyTorch natively.
This is done via the :ref:`ref-pytorch-operator` interface.
For both training and validation, there are two granularities that you can provide customization - per epoch and per batch. These correspond to ``train_batch``,
``train_epoch``, ``validate``, and ``validate_batch``. Other useful methods to override include ``setup``, ``save`` and ``restore``. You can use these
to manage state (like a classifier neural network for calculating inception score, or a heavy tokenizer).
Providing a custom operator is necessary if creator functions return multiple models, optimizers, or schedulers.
Below is a partial example of a custom ``TrainingOperator`` that provides a ``train_batch`` implementation for a Deep Convolutional GAN.
.. code-block:: python
import torch
from ray.util.sgd.pytorch import TrainingOperator
def initialization_hook(runner):
class GANOperator(TrainingOperator):
def setup(self, config):
"""Custom setup for this operator.
Args:
config (dict): Custom configuration value to be passed to
all creator and operator constructors. Same as ``self.config``.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train_batch(self, batch, batch_info):
"""Trains on one batch of data from the data creator.
Example taken from:
https://github.com/eriklindernoren/PyTorch-GAN/blob/
a163b82beff3d01688d8315a3fd39080400e7c01/implementations/dcgan/dcgan.py
Args:
batch: One item of the validation iterator.
batch_info (dict): Information dict passed in from ``train_epoch``.
Returns:
A dict of metrics. Defaults to "loss" and "num_samples",
corresponding to the total number of datapoints in the batch.
"""
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
discriminator, generator = self.models
optimizer_D, optimizer_G = self.optimizers
# Adversarial ground truths
valid = Variable(Tensor(batch.shape[0], 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(batch.shape[0], 1).fill_(0.0), requires_grad=False)
# Configure input
real_imgs = Variable(batch.type(Tensor))
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (
batch.shape[0], self.config["latent_dim"]))))
# Generate a batch of images
gen_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
return {
"loss_g": g_loss.item(),
"loss_d": d_loss.item(),
"num_samples": imgs.shape[0]
}
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
nn.BCELoss,
training_operator_cls=GANOperator,
num_replicas=num_replicas,
config=config,
use_gpu=True,
batch_size=128)
for i in range(5):
stats = trainer.train()
print(stats)
See the `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/pytorch/examples/dcgan.py>`__ for an end to end example. It constructs two models and two optimizers and uses a custom training operator to provide a non-standard training loop.
Initialization Functions
------------------------
Use the ``initialization_hook`` parameter to initialize state on each worker process when they are started. This is useful when setting an environment variable:
.. code-block:: python
def initialization_hook():
print("NCCL DEBUG SET")
# Need this for avoiding a connection restart issue
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
@@ -193,8 +299,8 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load``
trainer_2.restore(checkpoint_path)
Exporting a model for inference
-------------------------------
Retrieving the model
--------------------
The trained torch model can be extracted for use within the same Python program with ``trainer.get_model()``. This will load the state dictionary of the model(s).
@@ -242,22 +348,23 @@ To specify particular parameters for ``amp.initialize``, you can use the ``apex_
}
)
Note that if using a custom training function, you will need to manage loss scaling manually.
Note that if using a custom training operator (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually.
Distributed Multi-node Training
-------------------------------
You can scale out your training onto multiple nodes without making any modifications to your training code. To train across a cluster, simply make sure that the Ray cluster is started.
You can scale your training to multiple nodes without making any modifications to your training code.
You can start a Ray cluster `via the Ray cluster launcher <autoscaling.html>`_ or `manually <using-ray-on-a-cluster.html>`_.
To train across a cluster, first make sure that the Ray cluster is started. You can start a Ray cluster `via the Ray cluster launcher <autoscaling.html>`_ or `manually <using-ray-on-a-cluster.html>`_.
.. code-block:: bash
Then, in your program, you'll need to connect to this cluster via ``ray.init``:
ray up CLUSTER.yaml
ray submit train.py --args="--address='auto'"
.. code-block:: python
Then, within ``train.py`` you can scale up the number of workers seamlessly across multiple nodes:
ray.init(address="auto") # or a specific redis address of the form "ip-address:port"
After connecting, you can scale up the number of workers seamlessly across multiple nodes:
.. code-block:: python
@@ -266,7 +373,10 @@ Then, within ``train.py`` you can scale up the number of workers seamlessly acro
data_creator,
optimizer_creator,
loss_creator=nn.MSELoss,
num_replicas=100)
num_replicas=100
)
trainer.train()
model = trainer.get_model()
Advanced: Fault Tolerance
@@ -310,22 +420,37 @@ Advanced: Hyperparameter Tuning
Simultaneous Multi-model Training
---------------------------------
In certain scenarios such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by allowing the ``model_creator``, ``optimizer_creator``, and ``scheduler_creator`` to return multiple values.
If multiple models, optimizers, or schedulers are returned, you will need to provide a custom training function (and custom validation function if you plan to call ``validate``).
In certain scenarios, such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by allowing the ``model_creator``, ``optimizer_creator``, and ``scheduler_creator`` to return multiple values. Provide a custom TrainingOperator (:ref:`raysgd-custom-training`) to train across multiple models.
You can see the `DCGAN script <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/pytorch/examples/dcgan.py>`_ for an end-to-end example.
.. code-block:: python
from ray.util.sgd.pytorch import PyTorchTrainer, TrainingOperator
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(dataloader):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return {
"accuracy": correct / total,
"train_loss": train_loss / (batch_idx + 1)
}
def model_creator(config):
netD = Discriminator()
netD.apply(weights_init)
netG = Generator()
netG.apply(weights_init)
return netD, netG
return Discriminator(), Generator()
def optimizer_creator(models, config):
net_d, net_g = models
@@ -335,125 +460,27 @@ You can see the `DCGAN script <https://github.com/ray-project/ray/blob/master/py
net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
return discriminator_opt, generator_opt
def custom_train(models, dataloader, criterion, optimizers, config):
result = {}
for i, (model, optimizer) in enumerate(zip(models, optimizers)):
result["model_{}".format(i)] = train(model, dataloader, criterion,
optimizer, config)
return result
class CustomOperator(TrainingOperator):
def train_epoch(self, dataloader, info):
result = {}
for i, (model, optimizer) in enumerate(
zip(self.models, self.optimizers)):
result["model_{}".format(i)] = train(
model=model,
criterion=self.criterion,
optimizer=optimizer,
dataloader=dataloader)
return result
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
loss_creator=nn.BCELoss,
train_function=custom_train)
training_operator_cls=CustomOperator)
.. _raysgd-custom-training:
trainer.train()
Custom Training and Validation Functions
----------------------------------------
``PyTorchTrainer`` allows you to run a custom training and validation step in parallel on each worker, providing a flexibility similar to using PyTorch natively. This is done via the ``train_function`` and ``validation_function`` parameters.
Note that this is needed if the model creator returns multiple models, optimizers, or schedulers.
.. code-block:: python
def train(config, model, train_iterator, criterion, optimizer, scheduler=None):
"""Runs one standard training pass over the train_iterator.
Raises:
ValueError if multiple models/optimizers/schedulers are provided. You
are expected to have a custom training function if you wish
to use multiple models/optimizers/schedulers.
Args:
config: (dict): A user configuration provided into the Trainer
constructor.
model: The model(s) as created by the model_creator.
train_iterator: An iterator created from the DataLoader which
wraps the provided Dataset.
criterion: The loss object created by the loss_creator.
optimizer: The torch.optim.Optimizer(s) object
as created by the optimizer_creator.
scheduler (optional): The torch.optim.lr_scheduler(s) object
as created by the scheduler_creator.
Returns:
A dict of metrics from training.
"""
netD, netG = models
optimD, optimG = optimizers
real_label = 1
fake_label = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for i, data in enumerate(dataloader, 0):
netD.zero_grad()
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size, ), real_label, device=device)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
errD = errD_real + errD_fake
optimD.step()
netG.zero_grad()
label.fill_(real_label)
output = netD(fake).view(-1)
errG = criterion(output, label)
errG.backward()
optimG.step()
is_score, is_std = inception_score(fake)
return {
"loss_g": errG.item(),
"loss_d": errD.item(),
"inception": is_score
}
def custom_validate(config, model, val_iterator, criterion, scheduler=None):
"""Runs one standard validation pass over the val_iterator.
Args:
config: (dict): A user configuration provided into the Trainer
constructor.
model: The model(s) as created by the model_creator.
train_iterator: An iterator created from the DataLoader which
wraps the provided Dataset.
criterion: The loss object created by the loss_creator.
scheduler (optional): The torch.optim.lr_scheduler object(s)
as created by the scheduler_creator.
Returns:
A dict of metrics from the evaluation.
"""
...
return {"validation_accuracy": 0.5}
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
nn.BCELoss,
train_function=train,
validation_function=custom_validate,
...
)
Feature Requests
----------------
@@ -472,9 +499,7 @@ to contribute an example, feel free to create a `pull request here <https://gith
Simple example of using Ray's PyTorchTrainer.
- `CIFAR10 example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py>`__:
Training a ResNet18 model on CIFAR10. It uses a custom training
function, a custom validation function, and custom initialization code for each worker.
Training a ResNet18 model on CIFAR10.
- `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/pytorch/examples/dcgan.py>`__:
Training a Deep Convolutional GAN on MNIST. It constructs
two models and two optimizers and uses a custom training and validation function.
Training a Deep Convolutional GAN on MNIST. It constructs two models and two optimizers and uses a custom training operator.
+8
View File
@@ -11,6 +11,14 @@ PyTorchTrainer
.. automethod:: __init__
.. _ref-pytorch-operator:
PyTorch TrainingOperator
------------------------
.. autoclass:: ray.util.sgd.pytorch.TrainingOperator
:members:
PyTorchTrainable
----------------
+4 -1
View File
@@ -3,6 +3,7 @@ logger = logging.getLogger(__name__)
PyTorchTrainer = None
PyTorchTrainable = None
TrainingOperator = None
try:
import torch # noqa: F401
@@ -10,6 +11,8 @@ try:
from ray.util.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
PyTorchTrainable)
__all__ = ["PyTorchTrainer", "PyTorchTrainable"]
from ray.util.sgd.pytorch.training_operator import TrainingOperator
__all__ = ["PyTorchTrainer", "PyTorchTrainable", "TrainingOperator"]
except ImportError:
logger.warning("PyTorch not found. PyTorchTrainer will not be available")
+7
View File
@@ -0,0 +1,7 @@
USE_FP16 = "__use_fp16__"
BATCH_COUNT = "batch_count"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
@@ -95,15 +95,22 @@ class DistributedPyTorchRunner(PyTorchRunner):
self.validation_loader = torch.utils.data.DataLoader(
val_set, batch_size=self.batch_size, **self.dataloader_config)
def step(self):
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
def train_epoch(self, **kwargs):
"""Runs a training epoch and updates the model parameters.
Automatically sets epoch of sampler if possible.
"""
logger.debug("Starting step")
if hasattr(self.train_loader.sampler, "set_epoch"):
self.train_loader.sampler.set_epoch(self.epoch)
return super(DistributedPyTorchRunner, self).step()
self.train_loader.sampler.set_epoch(self.epochs)
return super(DistributedPyTorchRunner, self).train_epoch(**kwargs)
def _get_model_state_dicts(self):
"""Fetch state from ``model.module`` instead of ``model``.
@@ -10,10 +10,9 @@ import torchvision.transforms as transforms
import ray
from ray.util.sgd.pytorch import (PyTorchTrainer, PyTorchTrainable)
from ray.util.sgd.pytorch.resnet import ResNet18
from ray.util.sgd.pytorch.utils import TEST_MODE
def initialization_hook(runner):
def initialization_hook():
print("NCCL DEBUG SET")
# Need this for avoiding a connection restart issue
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
@@ -40,6 +39,11 @@ def cifar_creator(config):
validation_dataset = torchvision.datasets.CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
if config.get("test_mode"):
train_dataset = torch.utils.data.Subset(train_dataset, list(range(64)))
validation_dataset = torch.utils.data.Subset(validation_dataset,
list(range(64)))
return train_dataset, validation_dataset
@@ -58,7 +62,6 @@ def train_example(num_replicas=1,
use_gpu=False,
use_fp16=False,
test_mode=False):
config = {TEST_MODE: test_mode}
trainer1 = PyTorchTrainer(
ResNet18,
cifar_creator,
@@ -67,7 +70,10 @@ def train_example(num_replicas=1,
scheduler_creator=scheduler_creator,
initialization_hook=initialization_hook,
num_replicas=num_replicas,
config=config,
config={
"lr": 0.01,
"test_mode": test_mode
},
use_gpu=use_gpu,
batch_size=16 if test_mode else 512,
backend="nccl" if use_gpu else "gloo",
@@ -88,14 +94,14 @@ def tune_example(num_replicas=1, use_gpu=False, test_mode=False):
"model_creator": ResNet18,
"data_creator": cifar_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": lambda config: nn.CrossEntropyLoss(),
"loss_creator": nn.CrossEntropyLoss,
"num_replicas": num_replicas,
"initialization_hook": initialization_hook,
"use_gpu": use_gpu,
"batch_size": 16 if test_mode else 512,
"config": {
"lr": tune.choice([1e-4, 1e-3, 5e-3, 1e-2]),
TEST_MODE: test_mode
"lr": tune.choice([1e-4, 1e-3]),
"test_mode": test_mode
},
"backend": "nccl" if use_gpu else "gloo"
}
+148 -141
View File
@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
@@ -16,25 +16,12 @@ from scipy.stats import entropy
import ray
from ray.util.sgd import PyTorchTrainer
from ray.util.sgd.pytorch.utils import TEST_MODE
# Training parameters
TRAIN_BATCHES = 5
# Number of channels in the training images. For color images this is 3
num_channels = 1
# Size of z latent vector (i.e. size of generator input)
latent_vector_size = 100
# Size of feature maps in generator
features_g = 32
# Size of feature maps in discriminator
features_d = 32
from ray.util.sgd.utils import override
from ray.util.sgd.pytorch import TrainingOperator
def data_creator(config):
return dset.MNIST(
dataset = datasets.MNIST(
root="~/mnist/",
download=True,
transform=transforms.Compose([
@@ -42,62 +29,56 @@ def data_creator(config):
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, )),
]))
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
if config.get("test_mode"):
dataset = torch.utils.data.Subset(dataset, list(range(64)))
return dataset
class Generator(nn.Module):
def __init__(self):
def __init__(self, latent_vector_size, features=32, num_channels=1):
super(Generator, self).__init__()
self.latent_vector_size = latent_vector_size
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d(
latent_vector_size, features_g * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(features_g * 4),
latent_vector_size, features * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(features * 4),
nn.ReLU(True),
nn.ConvTranspose2d(
features_g * 4, features_g * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g * 2),
features * 4, features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 2),
nn.ReLU(True),
nn.ConvTranspose2d(
features_g * 2, features_g, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_g),
nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
nn.BatchNorm2d(features),
nn.ReLU(True),
nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False),
nn.ConvTranspose2d(features, num_channels, 4, 2, 1, bias=False),
nn.Tanh())
def forward(self, input):
return self.main(input)
def forward(self, x):
return self.main(x)
class Discriminator(nn.Module):
def __init__(self):
def __init__(self, features=32, num_channels=1):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(num_channels, features_d, 4, 2, 1, bias=False),
nn.Conv2d(num_channels, features, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features_d * 4), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features_d * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 4), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
def forward(self, input):
return self.main(input)
def forward(self, x):
return self.main(x)
class Net(nn.Module):
class LeNet(nn.Module):
"""LeNet for MNist classification, used for inception_score."""
def __init__(self):
super(Net, self).__init__()
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
@@ -114,92 +95,22 @@ class Net(nn.Module):
return F.log_softmax(x, dim=1)
def inception_score(imgs, batch_size=32, splits=1):
N = len(imgs)
dtype = torch.FloatTensor
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
cm = Net()
cm.load_state_dict(torch.load(model_path))
cm.eval()
up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
def get_pred(x):
x = up(x)
x = cm(x)
return F.softmax(x).data.cpu().numpy()
preds = np.zeros((N, 10))
for i, batch in enumerate(dataloader, 0):
batch = batch.type(dtype)
batchv = Variable(batch)
batch_size_i = batch.size()[0]
preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
# Now compute the mean kl-div
split_scores = []
for k in range(splits):
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
def model_creator(config):
netD = Discriminator()
netD.apply(weights_init)
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
netG = Generator()
netG.apply(weights_init)
return netD, netG
discriminator = Discriminator()
discriminator.apply(weights_init)
def train(config, models, dataloader, criterion, optimizers, **kwargs):
netD, netG = models
optimD, optimG = optimizers
real_label = 1
fake_label = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for i, data in enumerate(dataloader, 0):
if i >= TRAIN_BATCHES and config.get(TEST_MODE):
break
netD.zero_grad()
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size, ), real_label, device=device)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
errD = errD_real + errD_fake
optimD.step()
netG.zero_grad()
label.fill_(real_label)
output = netD(fake).view(-1)
errG = criterion(output, label)
errG.backward()
optimG.step()
is_score, is_std = inception_score(fake)
return {
"loss_g": errG.item(),
"loss_d": errD.item(),
"inception": is_score
}
generator = Generator(
latent_vector_size=config.get("latent_vector_size", 100))
generator.apply(weights_init)
return discriminator, generator
def optimizer_creator(models, config):
@@ -211,22 +122,122 @@ def optimizer_creator(models, config):
return discriminator_opt, generator_opt
class GANOperator(TrainingOperator):
def setup(self, config):
self.device = torch.device("cuda"
if torch.cuda.is_available() else "cpu")
self.classifier = LeNet()
self.classifier.load_state_dict(
torch.load(config["classification_model_path"]))
self.classifier.eval()
def inception_score(self, imgs, batch_size=32, splits=1):
"""Calculate the inception score of the generated images."""
N = len(imgs)
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
up = nn.Upsample(
size=(28, 28), mode="bilinear").type(torch.FloatTensor)
def get_pred(x):
x = up(x)
x = self.classifier(x)
return F.softmax(x).data.cpu().numpy()
# Obtain predictions for the fake provided images
preds = np.zeros((N, 10))
for i, batch in enumerate(dataloader, 0):
batch = batch.type(torch.FloatTensor)
batchv = Variable(batch)
batch_size_i = batch.size()[0]
preds[i * batch_size:i * batch_size +
batch_size_i] = get_pred(batchv)
# Now compute the mean kl-div
split_scores = []
for k in range(splits):
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
@override(TrainingOperator)
def train_batch(self, batch, batch_info):
"""Trains on one batch of data from the data creator."""
real_label = 1
fake_label = 0
discriminator, generator = self.models
optimD, optimG = self.optimizers
# Compute a discriminator update for real images
discriminator.zero_grad()
real_cpu = batch[0].to(self.device)
batch_size = real_cpu.size(0)
label = torch.full((batch_size, ), real_label, device=self.device)
output = discriminator(real_cpu).view(-1)
errD_real = self.criterion(output, label)
errD_real.backward()
# Compute a discriminator update for fake images
noise = torch.randn(
batch_size,
self.config.get("latent_vector_size", 100),
1,
1,
device=self.device)
fake = generator(noise)
label.fill_(fake_label)
output = discriminator(fake.detach()).view(-1)
errD_fake = self.criterion(output, label)
errD_fake.backward()
errD = errD_real + errD_fake
# Update the discriminator
optimD.step()
# Update the generator
generator.zero_grad()
label.fill_(real_label)
output = discriminator(fake).view(-1)
errG = self.criterion(output, label)
errG.backward()
optimG.step()
is_score, is_std = self.inception_score(fake)
return {
"loss_g": errG.item(),
"loss_d": errD.item(),
"inception": is_score,
"num_samples": batch_size
}
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
config = {TEST_MODE: test_mode}
config = {
"test_mode": test_mode,
"classification_model_path": os.path.join(
os.path.dirname(ray.__file__),
"util/sgd/pytorch/examples/mnist_cnn.pt")
}
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
nn.BCELoss,
train_function=train,
validation_function=False,
training_operator_cls=GANOperator,
num_replicas=num_replicas,
config=config,
use_gpu=use_gpu,
batch_size=16 if test_mode else 512,
backend="nccl" if use_gpu else "gloo")
for i in range(10):
stats = trainer.train(max_retries=3)
for i in range(5):
stats = trainer.train()
print(stats)
return trainer
@@ -240,7 +251,7 @@ if __name__ == "__main__":
"--address",
required=False,
type=str,
help="the address to use for Redis")
help="the address to use to connect to a cluster.")
parser.add_argument(
"--num-replicas",
"-n",
@@ -255,10 +266,6 @@ if __name__ == "__main__":
args, _ = parser.parse_known_args()
ray.init(address=args.address)
path = os.path.dirname(ray.__file__)
model_path = os.path.join(path, "util/sgd/pytorch/examples/mnist_cnn.pt")
# load the pretrained mnist classification model for inception_score
trainer = train_example(
num_replicas=args.num_replicas,
use_gpu=args.use_gpu,
+44 -37
View File
@@ -2,13 +2,15 @@ import collections
from filelock import FileLock
import logging
import inspect
import itertools
import os
import torch
import torch.utils.data
from torch.utils.data import Dataset
import ray
from ray.util.sgd.pytorch import utils as pytorch_utils
from ray.util.sgd.pytorch.constants import USE_FP16, SCHEDULER_STEP
from ray.util.sgd.pytorch.training_operator import TrainingOperator
from ray.util.sgd import utils
logger = logging.getLogger(__name__)
@@ -31,8 +33,7 @@ class PyTorchRunner:
loss_creator (dict -> loss | Loss class): see pytorch_trainer.py.
scheduler_creator (optimizers, dict -> schedulers): see
pytorch_trainer.py.
train_function: see pytorch_trainer.py
validation_function: see pytorch_trainer.py
training_operator_cls: see pytorch_trainer.py
config (dict): see pytorch_trainer.py.
dataloader_config (dict): See pytorch_trainer.py.
batch_size (int): see pytorch_trainer.py.
@@ -47,8 +48,7 @@ class PyTorchRunner:
optimizer_creator,
loss_creator,
scheduler_creator=None,
train_function=None,
validation_function=None,
training_operator_cls=None,
config=None,
dataloader_config=None,
batch_size=16,
@@ -60,17 +60,15 @@ class PyTorchRunner:
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls or TrainingOperator
self.config = {} if config is None else config
self.dataloader_config = {
"num_workers": 2
} if dataloader_config is None else dataloader_config
self.train_function = train_function or pytorch_utils.train
self.validation_function = (validation_function
or pytorch_utils.validate)
self.batch_size = batch_size
self.verbose = True
self.epoch = 0
self.epochs = 0
self._timers = {
k: utils.TimerStat(window_size=1)
for k in [
@@ -160,6 +158,14 @@ class PyTorchRunner:
self.validation_loader = torch.utils.data.DataLoader(
val_set, batch_size=self.batch_size, **self.dataloader_config)
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
def get_node_ip(self):
"""Returns the IP address of the current node."""
return ray.services.get_node_ip_address()
@@ -168,47 +174,42 @@ class PyTorchRunner:
"""Finds a free port on the current node."""
return utils.find_free_port()
def step(self):
def train_epoch(self, num_steps=None, info=None):
"""Runs a training epoch and updates the model parameters."""
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
training_config = self.config.copy()
training_config.update({
pytorch_utils.USE_FP16: self.use_fp16,
pytorch_utils.SCHEDULER_STEP: self.scheduler_step_freq
logger.debug("Begin Training Step {}".format(self.epochs + 1))
info = info or {}
info.update({
USE_FP16: self.use_fp16,
SCHEDULER_STEP: self.scheduler_step_freq
})
with self._timers["training"]:
train_stats = self.train_function(
training_config,
self.given_models,
self.train_loader,
self.criterion,
self.given_optimizers,
scheduler=self.given_schedulers)
train_stats["epoch"] = self.epoch
self.epoch += 1
iterator = self.train_loader
if num_steps:
iterator = itertools.islice(iter(self.train_loader), num_steps)
train_stats = self.training_operator.train_epoch(iterator, info)
self.epochs += 1
train_stats.update(self.stats())
return train_stats
def validate(self):
def validate(self, num_steps=None, info=None):
"""Evaluates the model on the validation data set."""
if self.validation_loader is None:
raise ValueError("No validation dataloader provided.")
info = info or {}
with self._timers["validation"]:
validation_stats = self.validation_function(
self.config,
self.given_models,
self.validation_loader,
self.criterion,
scheduler=self.given_schedulers)
iterator = self.validation_loader
if num_steps:
iterator = itertools.islice(
iter(self.validation_loader), num_steps)
validation_stats = self.training_operator.validate(iterator, info)
validation_stats.update(self.stats())
return validation_stats
def stats(self):
"""Returns a dictionary of statistics collected."""
stats = {"epoch": self.epoch}
stats = {"epoch": self.epochs}
for k, t in self._timers.items():
stats[k + "_time_mean"] = t.mean
stats[k + "_time_total"] = t.sum
@@ -233,7 +234,8 @@ class PyTorchRunner:
"""Returns the state of the runner."""
state = {
"epoch": self.epoch,
"epoch": self.epochs,
"operator": self.training_operator.state_dict(),
"models": self._get_model_state_dicts(),
"optimizers": [opt.state_dict() for opt in self.optimizers],
"stats": self.stats()
@@ -262,13 +264,18 @@ class PyTorchRunner:
if self.use_fp16 and "amp" in state and amp:
amp.load_state_dict(state["amp"])
self.epoch = state["stats"]["epoch"]
self.epochs = state["stats"]["epoch"]
self.training_operator.load_state_dict(state_dict)
def apply_fn(self, fn):
return fn(self)
def apply(self, fn):
return fn()
def apply_operator(self, fn):
return fn(self.training_operator)
def shutdown(self):
"""Attempts to shut down the worker."""
del self.training_operator
del self.validation_loader
del self.train_loader
del self.criterion
+99 -46
View File
@@ -15,12 +15,20 @@ from ray.util.sgd.pytorch.distributed_pytorch_runner import (
DistributedPyTorchRunner)
from ray.util.sgd import utils
from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner
from ray.util.sgd.pytorch import utils as pytorch_utils
from ray.util.sgd.pytorch.constants import VALID_SCHEDULER_STEP
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
def _validate_scheduler_step_freq(scheduler_step_freq):
if scheduler_step_freq:
if scheduler_step_freq not in VALID_SCHEDULER_STEP:
raise ValueError(
"Scheduler step freq must be in {}. Got {}".format(
VALID_SCHEDULER_STEP, scheduler_step_freq))
class PyTorchTrainer:
"""Train a PyTorch model using distributed PyTorch.
@@ -48,14 +56,15 @@ class PyTorchTrainer:
loss_creator=nn.MSELoss,
use_gpu=True
)
trainer.train()
for i in range(4):
trainer.train()
Args:
model_creator (dict -> Model(s)): Constructor function that takes in
config and returns the model(s) to be optimized. These must be
``torch.nn.Module`` objects. If multiple models are returned,
a ``train_function`` must be specified. You do not need to
a ``training_operator_cls`` must be specified. You do not need to
handle GPU/devices in this function; RaySGD will do that under
the hood.
data_creator (dict -> Dataset(s)): Constructor function
@@ -75,22 +84,18 @@ class PyTorchTrainer:
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
scheduler_creator (optimizers, dict -> loss):
A constructor function for the scheduler loss. This is
A constructor function for the torch scheduler. This is
a function that takes in the generated optimizers (from
``optimizer_creator``) provided config for customization.
Be sure to set ``scheduler_step_freq`` to increment the
scheduler correctly.
train_function: Custom function for training. This function
will be executed in parallel across all workers at once. The
function needs to take in (models, train_dataloader, criterion,
optimizers, config), and return a dict of training stats.
validation_function: Custom function for validation. This function
will be executed in parallel across all workers at once.
This takes in (model, val_dataloader, criterion, config)
and returns a dict of validation stats.
training_operator_cls (type): Custom training operator class
that subclasses the TrainingOperator class. This class
will be copied onto all remote workers and used to specify
custom training and validation operations. Defaults to
TrainingOperator.
config (dict): Custom configuration value to be passed to
"model_creator", "data_creator", "optimizer_creator", and
"loss_creator".
all creator and operator constructors.
dataloader_config (dict): Configuration values to be passed into
the ``torch.utils.data.DataLoader`` object that wraps
the dataset on each parallel worker for both training
@@ -130,8 +135,7 @@ class PyTorchTrainer:
optimizer_creator,
loss_creator,
scheduler_creator=None,
train_function=None,
validation_function=None,
training_operator_cls=None,
initialization_hook=None,
config=None,
dataloader_config=None,
@@ -151,11 +155,10 @@ class PyTorchTrainer:
self.model_creator = model_creator
self.data_creator = data_creator
self.train_function = train_function
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.scheduler_creator = scheduler_creator
self.validation_function = validation_function
self.training_operator_cls = training_operator_cls
self.initialization_hook = initialization_hook
self.config = {} if config is None else config
self.dataloader_config = dataloader_config
@@ -166,6 +169,8 @@ class PyTorchTrainer:
logger.info("Using {} as backend.".format(backend))
self.backend = backend
# TODO: Have an auto "use_gpu" option to detect and use GPUs.
self.use_gpu = use_gpu
self.batch_size = batch_size
self.max_replicas = num_replicas
@@ -180,12 +185,7 @@ class PyTorchTrainer:
self._num_failures = 0
self._last_resize = float("-inf")
if scheduler_step_freq and (
scheduler_step_freq not in pytorch_utils.VALID_SCHEDULER_STEP):
raise ValueError(
"Scheduler step freq must be in {}. Got {}".format(
pytorch_utils.VALID_SCHEDULER_STEP, scheduler_step_freq))
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
self._start_workers(self.max_replicas)
@@ -204,8 +204,7 @@ class PyTorchTrainer:
self.optimizer_creator,
self.loss_creator,
self.scheduler_creator,
train_function=self.train_function,
validation_function=self.validation_function,
training_operator_cls=self.training_operator_cls,
config=self.config,
dataloader_config=self.dataloader_config,
batch_size=self.batch_size,
@@ -243,8 +242,7 @@ class PyTorchTrainer:
self.loss_creator,
self.scheduler_creator,
backend=self.backend,
train_function=self.train_function,
validation_function=self.validation_function,
training_operator_cls=self.training_operator_cls,
config=self.config,
dataloader_config=self.dataloader_config,
batch_size=batch_size_per_replica,
@@ -266,21 +264,35 @@ class PyTorchTrainer:
for i, worker in enumerate(self.workers)
])
def train(self, max_retries=0, checkpoint="auto"):
def train(self,
num_steps=None,
max_retries=0,
checkpoint="auto",
info=None):
"""Runs a training epoch.
Runs an average over all values returned from workers. Set
`max_retries` to enable fault handling in case of instance preemption.
Args:
num_steps (int): Number of batches to compute update steps on.
This corresponds also to the number of times
``TrainingOperator.train_batch`` is called.
max_retries (int): Must be non-negative. If set to N, will
kill all current workers, query the Ray global state for
total available resources, and re-launch up to the
available resources. Behavior is not well-defined
in case of shared cluster usage.
checkpoint (str): Path to checkpoint to restore from if retrying.
If max_retries is set and checkpoint == "auto", PyTorchTrainer
will save a checkpoint before starting to train.
If max_retries is set and ``checkpoint == "auto"``,
PyTorchTrainer will save a checkpoint before starting to train.
info (dict): Optional dictionary passed to the training
operator for ``train_epoch`` and ``train_batch``.
Returns:
A dictionary of metrics for training.
You can provide custom metrics by passing in a custom
``training_operator_cls``.
"""
assert max_retries >= 0, "`max_retries` must be non-negative."
if max_retries:
@@ -296,7 +308,8 @@ class PyTorchTrainer:
self._resize_workers(checkpoint=checkpoint)
with self.optimizer_timer:
success, worker_stats = self._train_step()
success, worker_stats = self._train_epoch(
num_steps=num_steps, info=info)
# Fault handling
for i in range(max_retries):
if success:
@@ -306,7 +319,8 @@ class PyTorchTrainer:
self._resize_workers(checkpoint=checkpoint)
logger.info("Retrying training step with %d workers." % len(
self.workers))
success, worker_stats = self._train_step()
success, worker_stats = self._train_epoch(
num_steps=num_steps, info=info)
if not success:
raise RuntimeError("Training run failed.")
@@ -321,19 +335,58 @@ class PyTorchTrainer:
train_stats[stat_key] = worker_stats[0][stat_key]
return train_stats
def _train_step(self):
worker_stats = [w.step.remote() for w in self.workers]
def _train_epoch(self, num_steps=None, info=None):
worker_stats = [
w.train_epoch.remote(num_steps=num_steps, info=info)
for w in self.workers
]
success = utils.check_for_failure(worker_stats)
return success, worker_stats
def apply_all_workers(self, fn):
return ray.get([w.apply_fn.remote(fn) for w in self.workers])
"""Run a function on all operators on the workers.
def validate(self):
"""Evaluates the model on the validation data set."""
if self.validation_function is False:
return {}
worker_stats = ray.get([w.validate.remote() for w in self.workers])
Args:
fn (Callable): A function that takes in no arguments.
Returns:
A list of objects returned by ``fn`` on each worker.
"""
return ray.get([w.apply.remote(fn) for w in self.workers])
def apply_all_operators(self, fn):
"""Run a function on all operators on the workers.
Args:
fn (Callable[TrainingOperator]): A function that takes in a
TrainingOperator.
Returns:
A list of objects returned by ``fn`` on each operator.
"""
return ray.get([w.apply_operator.remote(fn) for w in self.workers])
def validate(self, num_steps=None, info=None):
"""Evaluates the model on the validation data set.
Args:
num_steps (int): Number of batches to compute update steps on.
This corresponds also to the number of times
``TrainingOperator.validate_batch`` is called.
info (dict): Optional dictionary passed to the training
operator for `validate` and `validate_batch`.
Returns:
A dictionary of metrics for validation.
You can provide custom metrics by passing in a custom
``training_operator_cls``.
"""
worker_stats = ray.get([
w.validate.remote(num_steps=num_steps, info=info)
for w in self.workers
])
validation_stats = {}
for stat_key in worker_stats[0]:
@@ -346,8 +399,8 @@ class PyTorchTrainer:
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
"""
self.apply_all_workers(
lambda runner: [sched.step(metric) for sched in runner.schedulers])
self.apply_all_operators(
lambda op: [sched.step(metric) for sched in op.schedulers])
def get_model(self):
"""Returns the learned model(s)."""
@@ -366,17 +419,18 @@ class PyTorchTrainer:
Args:
checkpoint (str): Path to target checkpoint file.
Returns:
checkpoint (str): Path to target checkpoint file.
"""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, checkpoint)
return checkpoint
def restore(self, checkpoint):
"""Restores the model from the provided checkpoint.
"""Restores the Trainer and all workers from the provided checkpoint.
Args:
checkpoint (str): Path to target checkpoint file.
"""
state = torch.load(checkpoint)
state_id = ray.put(state)
@@ -450,7 +504,6 @@ class PyTorchTrainable(Trainable):
validation_stats = self._trainer.validate()
train_stats.update(validation_stats)
# output {"mean_loss": test_loss, "mean_accuracy": accuracy}
return train_stats
@@ -0,0 +1,343 @@
import collections
import torch
from ray.util.sgd.utils import TimerStat, AverageMeter
from ray.util.sgd.pytorch.constants import (
SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_BATCH, SCHEDULER_STEP, BATCH_COUNT)
amp = None
try:
from apex import amp
except ImportError:
# Apex library is not installed, so we cannot enable mixed precision.
# We don't log here because logging happens in the pytorch_runner,
# where amp is initialized.
pass
def _is_multiple(component):
"""Checks if a component (optimizer, model, etc) is not singular."""
return isinstance(component, collections.Iterable) and len(component) > 1
class TrainingOperator:
"""Abstract class for custom training or validation loops.
The scheduler will only be called at a batch or epoch frequency, depending
on the user parameter. Be sure to set ``scheduler_step_freq`` in
``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler
correctly during training. If using a learning rate scheduler
that depends on validation loss, you can use ``trainer.update_scheduler``.
For both training and validation, there are two granularities that
you can provide customization: per epoch or per batch.
You do not need to override both.
.. image:: raysgd-custom.jpg
:scale: 80%
:align: center
Raises:
ValueError if multiple models/optimizers/schedulers are provided.
You are expected to subclass this class if you wish
to train over multiple models/optimizers/schedulers.
"""
def __init__(self,
config,
models,
optimizers,
criterion,
schedulers=None,
use_fp16=False):
# You are not expected to override this method.
self.timers = {
k: TimerStat()
for k in ["fwd", "grad", "apply", "epoch_time"]
}
self._validated_customization = False
self._models = models # List of models
assert isinstance(models, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(models)))
self._optimizers = optimizers # List of optimizers
assert isinstance(optimizers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(optimizers)))
self._criterion = criterion
self._schedulers = schedulers
if schedulers:
assert isinstance(schedulers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(
type(schedulers)))
self._config = config
self._use_fp16 = use_fp16
self.global_step = 0
if type(self) is TrainingOperator:
for component in (models, schedulers, optimizers):
if _is_multiple(component):
raise ValueError(
"Need to provide a custom operator subclassing "
"TrainingOperator if using multi-scheduler, "
"multi-model or multi-optimizer training/validation.")
self.setup(config)
def setup(self, config):
"""Override this method to implement custom operator setup.
Args:
config (dict): Custom configuration value to be passed to
all creator and operator constructors. Same as ``self.config``.
"""
pass
def train_epoch(self, iterator, info):
"""Runs one standard training pass over the train_iterator.
By default, this method will iterate over the given iterator and
call ``self.train_batch`` over each batch.
If ``scheduler_step_freq`` is set, this class will also step the
scheduler accordingly.
You do not need to call ``train_batch`` in this method if you plan
to implement a custom optimization/training routine here.
Args:
iterator (iter): Iterator over the training data for the entire
epoch. This iterator is expected to be entirely consumed.
info (dict): Dictionary for information to be used for custom
training operations.
Returns:
A dict of metrics from training.
"""
self._losses = AverageMeter()
self.model.train()
with self.timers["epoch_time"]:
for batch_idx, batch in enumerate(iterator):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info)
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
self.scheduler.step()
if "loss" in metrics:
self._losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
self.global_step += 1
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
self.scheduler.step()
stats = {
BATCH_COUNT: batch_idx + 1,
"mean_train_loss": self._losses.avg,
"last_train_loss": self._losses.val,
"epoch_time": self.timers["epoch_time"].last
}
stats.update({
timer_tag: timer.mean
for timer_tag, timer in self.timers.items()
})
return stats
def train_batch(self, batch, batch_info):
"""Computes loss and updates the model over one batch.
This method is responsible for computing the loss and gradient and
updating the model.
By default, this method implementation assumes that batches
are in (features, labels) format. If using amp/fp16
training, it will also scale the loss automatically.
You can provide custom loss metrics and training operations if you
override this method. If overriding this method, you can access model,
optimizer, criterion via ``self.model``, ``self.optimizer``,
and ``self.criterion``.
You do not need to override this method if you plan to
override ``train_epoch``.
Args:
batch: One item of the validation iterator.
batch_info (dict): Information dict passed in from ``train_epoch``.
Returns:
A dictionary of metrics.
By default, this dictionary contains "loss" and "num_samples".
"num_samples" corresponds to number of datapoints in the batch.
However, you can provide any number of other values.
"""
features, target = batch
# Create non_blocking tensors for distributed training
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# Compute output.
with self.timers["fwd"]:
output = self.model(features)
loss = self.criterion(output, target)
# Compute gradients in a backward pass.
with self.timers["grad"]:
self.optimizer.zero_grad()
if self.use_fp16:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# Call step of optimizer to update model params.
with self.timers["apply"]:
self.optimizer.step()
return {"loss": loss.item(), "num_samples": features.size(0)}
def validate(self, val_iterator, info):
"""Runs one standard validation pass over the val_iterator.
This will call ``model.eval()`` and ``torch.no_grad`` when iterating
over the validation dataset.
If overriding this method, you can access model, criterion via
``self.model`` and ``self.criterion``. You also do not need to call
``validate_batch`` if overriding this method.
Args:
val_iterator (iter): Iterable constructed over the
validation dataset.
info: (dict): Dictionary for information to be used for custom
validation operations.
Returns:
A dict of metrics from the evaluation.
By default, returns "mean_accuracy" and "mean_validation_loss"
which is computed by aggregating "loss" and "correct" values
from ``validate_batch`` and dividing it by the sum of
``num_samples`` from all calls to ``self.validate_batch``.
"""
losses = AverageMeter()
total_correct = 0
# switch to evaluate mode
self.model.eval()
with torch.no_grad():
for batch_idx, batch in enumerate(val_iterator):
batch_info = {"batch_idx": batch_idx}
batch_info.update(info)
metrics = self.validate_batch(batch, batch_info)
if "loss" in metrics:
losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
if "num_correct" in metrics:
total_correct += metrics["num_correct"]
stats = {
"batch_count": batch_idx + 1,
"mean_validation_loss": losses.avg,
"mean_accuracy": total_correct / losses.count
}
return stats
def validate_batch(self, batch, batch_info):
"""Calcuates the loss and accuracy over a given batch.
You can override this method to provide arbitrary metrics.
Args:
batch: One item of the validation iterator.
batch_info (dict): Contains information per batch from
``validate()``.
Returns:
A dict of metrics.
By default, returns "loss", "num_correct", and "num_samples".
"""
features, target = batch
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = self.model(features)
loss = self.criterion(output, target)
_, predicted = torch.max(output.data, 1)
return {
"loss": loss.item(),
"num_correct": (predicted == target).sum().item(),
"num_samples": target.size(0)
}
def state_dict(self):
"""Returns a serializable representation of the operator state."""
pass
def load_state_dict(self, state_dict):
"""Loads a serializable representation of the operator state."""
pass
@property
def config(self):
"""Dictionary as provided into PyTorchTrainer."""
return self._config
@property
def model(self):
"""First or only model created by the provided ``model_creator``."""
return self._models[0]
@property
def models(self):
"""List of models created by the provided ``model_creator``."""
return self._models
@property
def optimizer(self):
"""First or only optimizer(s) created by the ``optimizer_creator``."""
return self._optimizers[0]
@property
def optimizers(self):
"""List of optimizers created by the ``optimizer_creator``."""
return self._optimizers
@property
def criterion(self):
"""Criterion created by the provided ``loss_creator``."""
return self._criterion
@property
def scheduler(self):
"""First or only scheduler(s) created by the ``scheduler_creator``."""
if self._schedulers:
return self._schedulers[0]
@property
def schedulers(self):
"""List of schedulers created by the ``scheduler_creator``."""
return self._schedulers
@property
def use_fp16(self):
"""Whether the model and optimizer have been FP16 enabled."""
return self._use_fp16
class _TestingOperator(TrainingOperator):
def train_epoch(self, iterator, info):
func = self.config.get("custom_func")
if callable(func):
return func(self, iterator, info)
return {"done": 1}
-229
View File
@@ -1,229 +0,0 @@
import collections
import time
import torch
from ray.util.sgd.utils import TimerStat
amp = None
try:
from apex import amp
except ImportError:
# Apex library is not installed, so we cannot enable mixed precision.
# We don't log here because logging happens in the pytorch_runner,
# where amp is initialized.
pass
USE_FP16 = "__use_fp16__"
TEST_MODE = "__test_mode__"
BATCH_COUNT = "batch_processed"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
def train(config, model, train_iterator, criterion, optimizer, scheduler=None):
"""Runs one standard training pass over the train_iterator.
This function automatically measures timing for various operations such
as host to device transfer, gradient calculation, and gradient application.
It also automatically detects and places the data on the given GPU device
if available.
The scheduler will only be called at a batch or epoch frequency, depending
on the user parameter. Be sure to set ``scheduler_step_freq`` in
``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler
correctly during training. If using a learning rate scheduler
that depends on validation loss, you can use ``trainer.update_scheduler``.
Raises:
ValueError if multiple models/optimizers/schedulers are provided. You
are expected to have a custom training function if you wish
to use multiple models/optimizers/schedulers.
Args:
config: (dict): A user configuration provided into the Trainer
constructor.
model: The model as created by the model_creator.
train_iterator: An iterator created from the DataLoader which
wraps the provided Dataset.
criterion: The loss object created by the loss_creator.
optimizer: The torch.optim.Optimizer object as created by the
optimizer_creator.
scheduler (optional): The torch.optim.lr_scheduler object
as created by the scheduler_creator. Be sure to set
``scheduler_step_freq`` in ``PyTorchTrainer``
to increment the scheduler correctly.
Returns:
A dict of metrics from training.
"""
if isinstance(model, collections.Iterable) or isinstance(
optimizer, collections.Iterable) or isinstance(
scheduler, collections.Iterable):
raise ValueError(
"Need to provide custom training function if using multi-model "
"or multi-scheduler or multi-optimizer training.")
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
timers = {k: TimerStat() for k in ["h2d", "fwd", "grad", "apply"]}
# switch to train mode
model.train()
end = time.time()
for batch_idx, (features, target) in enumerate(train_iterator):
# measure data loading time
data_time.update(time.time() - end)
# Create non_blocking tensors for distributed training
with timers["h2d"]:
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
with timers["fwd"]:
output = model(features)
loss = criterion(output, target)
# measure accuracy and record loss
losses.update(loss.item(), features.size(0))
with timers["grad"]:
# compute gradients in a backward pass
optimizer.zero_grad()
if config.get(USE_FP16):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
with timers["apply"]:
# Call step of optimizer to update model params
optimizer.step()
if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
scheduler.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if config.get(TEST_MODE) and batch_idx == 0:
break
if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
scheduler.step()
stats = {
"batch_time": batch_time.avg,
BATCH_COUNT: batch_idx + 1,
"train_loss": losses.avg,
"data_time": data_time.avg,
}
stats.update({k: t.mean for k, t in timers.items()})
return stats
def validate(config, model, val_iterator, criterion, scheduler=None):
"""Runs one standard validation pass over the val_iterator.
This function automatically measures timing for various operations such
as host to device transfer and processing time for the batch.
It also automatically detects and places the data on the given GPU device
if available.
Raises:
ValueError if multiple models/schedulers are provided. You
are expected to have a custom validation function if you wish
to use multiple models/schedulers.
Args:
config: (dict): A user configuration provided into the Trainer
constructor.
model: The model as created by the model_creator.
train_iterator: An iterator created from the DataLoader which
wraps the provided Dataset.
criterion: The loss object created by the loss_creator.
scheduler (optional): The torch.optim.lr_scheduler object
as created by the scheduler_creator. By default,
this is not used in this function.
Returns:
A dict of metrics from the evaluation.
"""
if isinstance(model, collections.Iterable) or isinstance(
scheduler, collections.Iterable):
raise ValueError(
"Need to provide custom validation function if using multi-model "
"or multi-scheduler training.")
batch_time = AverageMeter()
losses = AverageMeter()
# switch to evaluate mode
model.eval()
correct = 0
total = 0
batch_idx = 0
with torch.no_grad():
end = time.time()
for batch_idx, (features, target) in enumerate(val_iterator):
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = model(features)
loss = criterion(output, target)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
# measure accuracy and record loss
losses.update(loss.item(), features.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if config.get(TEST_MODE) and batch_idx == 0:
break
stats = {
BATCH_COUNT: batch_idx + 1,
"batch_time": batch_time.avg,
"validation_loss": losses.avg,
"mean_accuracy": correct / total,
"mean_loss": losses.sum / total,
}
return stats
class AverageMeter:
"""Computes and stores the average and current value."""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
+104 -59
View File
@@ -10,28 +10,34 @@ import torch.distributed as dist
import ray
from ray import tune
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.util.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
from ray.util.sgd.pytorch.utils import (train, BATCH_COUNT, TEST_MODE,
SCHEDULER_STEP)
from ray.util.sgd.pytorch.training_operator import _TestingOperator
from ray.util.sgd.pytorch.constants import BATCH_COUNT, SCHEDULER_STEP
from ray.util.sgd.utils import check_for_failure
from ray.util.sgd.pytorch.examples.train_example import (
model_creator, optimizer_creator, data_creator, LinearDataset)
def test_test_mode(ray_start_2_cpus): # noqa: F811
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
def test_single_step(ray_start_2_cpus): # noqa: F811
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={TEST_MODE: True},
num_replicas=1)
metrics = trainer.train()
metrics = trainer.train(num_steps=1)
assert metrics[BATCH_COUNT] == 1
val_metrics = trainer.validate()
val_metrics = trainer.validate(num_steps=1)
assert val_metrics[BATCH_COUNT] == 1
@@ -45,29 +51,51 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
loss_creator=lambda config: nn.MSELoss(),
num_replicas=num_replicas)
for i in range(3):
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["validation_loss"]
train_loss1 = trainer.train()["mean_train_loss"]
validation_loss1 = trainer.validate()["mean_validation_loss"]
for i in range(3):
train_loss2 = trainer.train()["train_loss"]
validation_loss2 = trainer.validate()["validation_loss"]
train_loss2 = trainer.train()["mean_train_loss"]
validation_loss2 = trainer.validate()["mean_validation_loss"]
print(train_loss1, train_loss2)
print(validation_loss1, validation_loss2)
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
assert validation_loss2 <= validation_loss1, (validation_loss2,
validation_loss1)
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
def custom_train(config, models, dataloader, criterion, optimizers,
**kwargs):
def test_multi_model(ray_start_2_cpus, num_replicas):
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(dataloader):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return {
"accuracy": correct / total,
"train_loss": train_loss / (batch_idx + 1)
}
def train_epoch(self, iterator, info):
result = {}
for i, (model, optimizer) in enumerate(zip(models, optimizers)):
result["model_{}".format(i)] = train(config, model, dataloader,
criterion, optimizer)
for i, (model, optimizer) in enumerate(
zip(self.models, self.optimizers)):
result["model_{}".format(i)] = train(
model=model,
criterion=self.criterion,
optimizer=optimizer,
dataloader=iterator)
return result
def multi_model_creator(config):
@@ -84,7 +112,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
data_creator,
multi_optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
train_function=custom_train,
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
num_replicas=num_replicas)
trainer1.train()
@@ -100,6 +129,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
data_creator,
multi_optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
num_replicas=num_replicas)
trainer2.restore(filename)
@@ -123,16 +154,17 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
def custom_train(config, model, dataloader, criterion, optimizer,
scheduler):
if config.get("models", 1) > 1:
assert len(model) == config["models"], config
def train_epoch(self, iterator, info):
if self.config.get("models", 1) > 1:
assert len(self.models) == self.config["models"], self.config
if config.get("optimizers", 1) > 1:
assert len(optimizer) == config["optimizers"], config
if self.config.get("optimizers", 1) > 1:
assert len(
self.optimizers) == self.config["optimizers"], self.config
if config.get("schedulers", 1) > 1:
assert len(scheduler) == config["schedulers"], config
if self.config.get("schedulers", 1) > 1:
assert len(
self.schedulers) == self.config["schedulers"], self.config
return {"done": 1}
def multi_model_creator(config):
@@ -167,12 +199,13 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
multi_optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=multi_scheduler_creator,
train_function=custom_train,
training_operator_cls=_TestingOperator,
num_replicas=num_replicas,
config={
"models": model_count,
"optimizers": optimizer_count,
"schedulers": scheduler_count
"schedulers": scheduler_count,
"custom_func": train_epoch
})
trainer.train()
trainer.shutdown()
@@ -180,9 +213,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"])
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
def custom_train(config, model, dataloader, criterion, optimizer,
scheduler):
assert config[SCHEDULER_STEP] == scheduler_freq
def train_epoch(self, iterator, info):
assert info[SCHEDULER_STEP] == scheduler_freq
return {"done": 1}
def scheduler_creator(optimizer, config):
@@ -194,18 +226,17 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
data_creator,
optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=scheduler_creator)
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
scheduler_creator=scheduler_creator,
scheduler_step_freq=scheduler_freq)
for i in range(3):
trainer.train()["train_loss"]
trainer.train()
trainer.shutdown()
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
def custom_train(config, model, dataloader, criterion, optimizer,
scheduler):
return {"done": 1}
from torch.optim.lr_scheduler import ReduceLROnPlateau
trainer = PyTorchTrainer(
@@ -213,11 +244,13 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
data_creator,
optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer))
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
training_operator_cls=_TestingOperator)
trainer.update_scheduler(0.5)
trainer.update_scheduler(0.5)
assert all(
trainer.apply_all_workers(lambda r: r.schedulers[0].last_epoch == 2))
trainer.apply_all_operators(
lambda op: op.schedulers[0].last_epoch == 2))
trainer.shutdown()
@@ -248,13 +281,13 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
# checks loss decreasing for every trials
for path, df in analysis.trial_dataframes.items():
train_loss1 = df.loc[0, "train_loss"]
train_loss2 = df.loc[1, "train_loss"]
validation_loss1 = df.loc[0, "validation_loss"]
validation_loss2 = df.loc[1, "validation_loss"]
mean_train_loss1 = df.loc[0, "mean_train_loss"]
mean_train_loss2 = df.loc[1, "mean_train_loss"]
mean_validation_loss1 = df.loc[0, "mean_validation_loss"]
mean_validation_loss2 = df.loc[1, "mean_validation_loss"]
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
assert mean_train_loss2 <= mean_train_loss1
assert mean_validation_loss2 <= mean_validation_loss1
@pytest.mark.parametrize("num_replicas", [1, 2]
@@ -303,15 +336,17 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
def single_loader(config):
return LinearDataset(2, 5, size=1000000)
def step_with_fail(self):
worker_stats = [w.step.remote() for w in self.workers]
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
]
if self._num_failures < 3:
time.sleep(1) # Make the batch will fail correctly.
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
trainer1 = PyTorchTrainer(
model_creator,
single_loader,
@@ -331,15 +366,17 @@ def test_resize(ray_start_2_cpus): # noqa: F811
def single_loader(config):
return LinearDataset(2, 5, size=1000000)
def step_with_fail(self):
worker_stats = [w.step.remote() for w in self.workers]
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
]
if self._num_failures < 1:
time.sleep(1) # Make the batch will fail correctly.
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
trainer1 = PyTorchTrainer(
model_creator,
single_loader,
@@ -365,15 +402,17 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
def single_loader(config):
return LinearDataset(2, 5, size=1000000)
def step_with_fail(self):
worker_stats = [w.step.remote() for w in self.workers]
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
]
if self._num_failures < 2:
time.sleep(1)
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
trainer1 = PyTorchTrainer(
model_creator,
single_loader,
@@ -383,3 +422,9 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
num_replicas=2)
trainer1.train(max_retries=2)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", "-x", __file__]))
@@ -4,6 +4,7 @@ import torch.nn as nn
import unittest
from unittest.mock import MagicMock
from ray.util.sgd.pytorch.training_operator import TrainingOperator
from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner
@@ -46,39 +47,55 @@ def create_dataloaders(config):
class TestPyTorchRunner(unittest.TestCase):
def testValidate(self):
mock_function = MagicMock(returns=dict(mean_accuracy=10))
class MockOperator(TrainingOperator):
def setup(self, config):
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
runner = PyTorchRunner(
model_creator,
create_dataloaders,
optimizer_creator,
loss_creator,
validation_function=mock_function)
training_operator_cls=MockOperator)
runner.setup()
runner.step()
runner.step()
runner.step()
self.assertEqual(mock_function.call_count, 0)
runner.train_epoch()
runner.train_epoch()
runner.train_epoch()
self.assertEqual(runner.training_operator.validate.call_count, 0)
runner.validate()
self.assertTrue(mock_function.called)
self.assertTrue(runner.training_operator.validate.called)
self.assertEqual(runner.stats()["epoch"], 3)
def testStep(self):
mock_function = MagicMock(return_value=dict(mean_accuracy=10))
def testtrain_epoch(self):
class MockOperator(TrainingOperator):
def setup(self, config):
self.count = 0
def train_epoch(self, *args, **kwargs):
self.count += 1
return {"count": self.count}
runner = PyTorchRunner(
model_creator,
create_dataloaders,
optimizer_creator,
loss_creator,
train_function=mock_function)
training_operator_cls=MockOperator)
runner.setup()
runner.step()
runner.step()
result = runner.step()
self.assertEqual(mock_function.call_count, 3)
self.assertEqual(result["epoch"], 3)
runner.train_epoch(num_steps=1)
runner.train_epoch(num_steps=1)
result = runner.train_epoch()
self.assertEqual(runner.training_operator.count, 3)
self.assertEqual(result["count"], 3)
self.assertEqual(runner.stats()["epoch"], 3)
def testGivens(self):
class MockOperator(TrainingOperator):
def setup(self, config):
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
def three_model_creator(config):
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
@@ -88,8 +105,12 @@ class TestPyTorchRunner(unittest.TestCase):
]
return opts[0], opts[1], opts[2]
runner = PyTorchRunner(three_model_creator, single_loader,
three_optimizer_creator, loss_creator)
runner = PyTorchRunner(
three_model_creator,
single_loader,
three_optimizer_creator,
loss_creator,
training_operator_cls=MockOperator)
runner.setup()
self.assertEqual(len(runner.given_models), 3)
@@ -121,7 +142,7 @@ class TestPyTorchRunner(unittest.TestCase):
runner = PyTorchRunner(model_creator, single_loader, optimizer_creator,
loss_creator)
runner.setup()
runner.step()
runner.train_epoch()
with self.assertRaises(ValueError):
runner.validate()
@@ -132,7 +153,7 @@ class TestPyTorchRunner(unittest.TestCase):
optimizer_creator,
loss_creator=nn.MSELoss)
runner.setup()
runner.step()
runner.train_epoch()
def testMultiModel(self):
def multi_model_creator(config):
@@ -146,6 +167,6 @@ class TestPyTorchRunner(unittest.TestCase):
runner = PyTorchRunner(multi_model_creator, single_loader,
multi_optimizer_creator, loss_creator)
runner.setup()
with self.assertRaises(ValueError):
runner.step()
runner.setup()
+8
View File
@@ -149,3 +149,11 @@ def check_for_failure(remote_values):
except RayActorError as exc:
logger.exception(str(exc))
return False
def override(interface_class):
def overrider(method):
assert (method.__name__ in dir(interface_class))
return method
return overrider