diff --git a/doc/source/raysgd/raysgd-custom.jpg b/doc/source/raysgd/raysgd-custom.jpg new file mode 100644 index 000000000..0145b099d Binary files /dev/null and b/doc/source/raysgd/raysgd-custom.jpg differ diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index e85fc10a4..4de9c90b6 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -10,6 +10,8 @@ Under the hood, ``PytorchTrainer`` will create *replicas* of your model (control For end to end examples leveraging RaySGD PyTorchTrainer, jump to :ref:`raysgd-pytorch-examples`. +.. contents:: :local: + Setting up training ------------------- @@ -81,6 +83,7 @@ for ``PyTorchTrainer(scheduler_creator=...)``. :start-after: __torch_scheduler_start__ :end-before: __torch_scheduler_end__ + .. _starting-pytorch-trainer: Putting things together @@ -115,30 +118,17 @@ You can also set the number of workers and whether the workers will use GPUs: num_replicas=100, use_gpu=True) -See the documentation on the PyTorchTrainer here: :ref:`ref-pytorch-trainer`. We'll look at the training APIs next. -Training APIs -------------- - -Now that the trainer is constructed, you'll naturally want to train the model. +Now that the trainer is constructed, here's how to train the model. .. code-block:: python - trainer.train() - -This takes one pass over the training data. - -To run the model on the validation data passed in by the ``data_creator``, you can simply call: - -.. code-block:: python - - trainer.validate() - -You can customize the exact function that is called by using a customized training function (see :ref:`raysgd-custom-training`). + for i in range(10): + metrics = trainer.train() + val_metrics = trainer.validate() -Shutting down training ----------------------- +Each ``train`` call makes one pass over the training data, and each ``validate`` call runs the model on the validation data passed in by the ``data_creator``. Provide a custom training operator (:ref:`raysgd-custom-training`) to calculate custom metrics or customize the training/validation process. After training, you may want to reappropriate the Ray cluster. To release Ray resources obtained by the Trainer: @@ -148,15 +138,131 @@ After training, you may want to reappropriate the Ray cluster. To release Ray re .. note:: Be sure to call ``trainer.save()`` or ``trainer.get_model()`` before shutting down. -Initialization Functions ------------------------- +See the documentation on the PyTorchTrainer here: :ref:`ref-pytorch-trainer`. -You may want to run some initializers on each worker when they are started. This may be something like setting an environment variable or downloading some data. You can do this via the ``initialization_hook`` parameter: + +.. _raysgd-custom-training: + +Custom Training and Validation (Operators) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``PyTorchTrainer`` allows you to run a custom training and validation loops in parallel on each worker, providing a flexible interface similar to using PyTorch natively. +This is done via the :ref:`ref-pytorch-operator` interface. + +For both training and validation, there are two granularities that you can provide customization - per epoch and per batch. These correspond to ``train_batch``, +``train_epoch``, ``validate``, and ``validate_batch``. Other useful methods to override include ``setup``, ``save`` and ``restore``. You can use these +to manage state (like a classifier neural network for calculating inception score, or a heavy tokenizer). + +Providing a custom operator is necessary if creator functions return multiple models, optimizers, or schedulers. + +Below is a partial example of a custom ``TrainingOperator`` that provides a ``train_batch`` implementation for a Deep Convolutional GAN. .. code-block:: python + import torch + from ray.util.sgd.pytorch import TrainingOperator - def initialization_hook(runner): + class GANOperator(TrainingOperator): + def setup(self, config): + """Custom setup for this operator. + + Args: + config (dict): Custom configuration value to be passed to + all creator and operator constructors. Same as ``self.config``. + """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def train_batch(self, batch, batch_info): + """Trains on one batch of data from the data creator. + + Example taken from: + https://github.com/eriklindernoren/PyTorch-GAN/blob/ + a163b82beff3d01688d8315a3fd39080400e7c01/implementations/dcgan/dcgan.py + + Args: + batch: One item of the validation iterator. + batch_info (dict): Information dict passed in from ``train_epoch``. + + Returns: + A dict of metrics. Defaults to "loss" and "num_samples", + corresponding to the total number of datapoints in the batch. + """ + Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor + discriminator, generator = self.models + optimizer_D, optimizer_G = self.optimizers + + # Adversarial ground truths + valid = Variable(Tensor(batch.shape[0], 1).fill_(1.0), requires_grad=False) + fake = Variable(Tensor(batch.shape[0], 1).fill_(0.0), requires_grad=False) + + # Configure input + real_imgs = Variable(batch.type(Tensor)) + + # ----------------- + # Train Generator + # ----------------- + + optimizer_G.zero_grad() + + # Sample noise as generator input + z = Variable(Tensor(np.random.normal(0, 1, ( + batch.shape[0], self.config["latent_dim"])))) + + # Generate a batch of images + gen_imgs = generator(z) + + # Loss measures generator's ability to fool the discriminator + g_loss = adversarial_loss(discriminator(gen_imgs), valid) + + g_loss.backward() + optimizer_G.step() + + # --------------------- + # Train Discriminator + # --------------------- + + optimizer_D.zero_grad() + + # Measure discriminator's ability to classify real from generated samples + real_loss = adversarial_loss(discriminator(real_imgs), valid) + fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) + d_loss = (real_loss + fake_loss) / 2 + + d_loss.backward() + optimizer_D.step() + + return { + "loss_g": g_loss.item(), + "loss_d": d_loss.item(), + "num_samples": imgs.shape[0] + } + + trainer = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + nn.BCELoss, + training_operator_cls=GANOperator, + num_replicas=num_replicas, + config=config, + use_gpu=True, + batch_size=128) + + for i in range(5): + stats = trainer.train() + print(stats) + +See the `DCGAN example `__ for an end to end example. It constructs two models and two optimizers and uses a custom training operator to provide a non-standard training loop. + + +Initialization Functions +------------------------ + +Use the ``initialization_hook`` parameter to initialize state on each worker process when they are started. This is useful when setting an environment variable: + +.. code-block:: python + + def initialization_hook(): print("NCCL DEBUG SET") # Need this for avoiding a connection restart issue os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" @@ -193,8 +299,8 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load`` trainer_2.restore(checkpoint_path) -Exporting a model for inference -------------------------------- +Retrieving the model +-------------------- The trained torch model can be extracted for use within the same Python program with ``trainer.get_model()``. This will load the state dictionary of the model(s). @@ -242,22 +348,23 @@ To specify particular parameters for ``amp.initialize``, you can use the ``apex_ } ) -Note that if using a custom training function, you will need to manage loss scaling manually. +Note that if using a custom training operator (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually. Distributed Multi-node Training ------------------------------- -You can scale out your training onto multiple nodes without making any modifications to your training code. To train across a cluster, simply make sure that the Ray cluster is started. +You can scale your training to multiple nodes without making any modifications to your training code. -You can start a Ray cluster `via the Ray cluster launcher `_ or `manually `_. +To train across a cluster, first make sure that the Ray cluster is started. You can start a Ray cluster `via the Ray cluster launcher `_ or `manually `_. -.. code-block:: bash +Then, in your program, you'll need to connect to this cluster via ``ray.init``: - ray up CLUSTER.yaml - ray submit train.py --args="--address='auto'" +.. code-block:: python -Then, within ``train.py`` you can scale up the number of workers seamlessly across multiple nodes: + ray.init(address="auto") # or a specific redis address of the form "ip-address:port" + +After connecting, you can scale up the number of workers seamlessly across multiple nodes: .. code-block:: python @@ -266,7 +373,10 @@ Then, within ``train.py`` you can scale up the number of workers seamlessly acro data_creator, optimizer_creator, loss_creator=nn.MSELoss, - num_replicas=100) + num_replicas=100 + ) + trainer.train() + model = trainer.get_model() Advanced: Fault Tolerance @@ -310,22 +420,37 @@ Advanced: Hyperparameter Tuning Simultaneous Multi-model Training --------------------------------- -In certain scenarios such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by allowing the ``model_creator``, ``optimizer_creator``, and ``scheduler_creator`` to return multiple values. - -If multiple models, optimizers, or schedulers are returned, you will need to provide a custom training function (and custom validation function if you plan to call ``validate``). +In certain scenarios, such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by allowing the ``model_creator``, ``optimizer_creator``, and ``scheduler_creator`` to return multiple values. Provide a custom TrainingOperator (:ref:`raysgd-custom-training`) to train across multiple models. You can see the `DCGAN script `_ for an end-to-end example. .. code-block:: python + from ray.util.sgd.pytorch import PyTorchTrainer, TrainingOperator + + def train(*, model=None, criterion=None, optimizer=None, dataloader=None): + model.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(dataloader): + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + return { + "accuracy": correct / total, + "train_loss": train_loss / (batch_idx + 1) + } + def model_creator(config): - netD = Discriminator() - netD.apply(weights_init) - - netG = Generator() - netG.apply(weights_init) - return netD, netG - + return Discriminator(), Generator() def optimizer_creator(models, config): net_d, net_g = models @@ -335,125 +460,27 @@ You can see the `DCGAN script `__: - Training a ResNet18 model on CIFAR10. It uses a custom training - function, a custom validation function, and custom initialization code for each worker. + Training a ResNet18 model on CIFAR10. - `DCGAN example `__: - Training a Deep Convolutional GAN on MNIST. It constructs - two models and two optimizers and uses a custom training and validation function. + Training a Deep Convolutional GAN on MNIST. It constructs two models and two optimizers and uses a custom training operator. diff --git a/doc/source/raysgd/raysgd_ref.rst b/doc/source/raysgd/raysgd_ref.rst index d7ac5020b..15e99d62c 100644 --- a/doc/source/raysgd/raysgd_ref.rst +++ b/doc/source/raysgd/raysgd_ref.rst @@ -11,6 +11,14 @@ PyTorchTrainer .. automethod:: __init__ +.. _ref-pytorch-operator: + +PyTorch TrainingOperator +------------------------ + +.. autoclass:: ray.util.sgd.pytorch.TrainingOperator + :members: + PyTorchTrainable ---------------- diff --git a/python/ray/util/sgd/pytorch/__init__.py b/python/ray/util/sgd/pytorch/__init__.py index dd284cb9e..bdcc100d3 100644 --- a/python/ray/util/sgd/pytorch/__init__.py +++ b/python/ray/util/sgd/pytorch/__init__.py @@ -3,6 +3,7 @@ logger = logging.getLogger(__name__) PyTorchTrainer = None PyTorchTrainable = None +TrainingOperator = None try: import torch # noqa: F401 @@ -10,6 +11,8 @@ try: from ray.util.sgd.pytorch.pytorch_trainer import (PyTorchTrainer, PyTorchTrainable) - __all__ = ["PyTorchTrainer", "PyTorchTrainable"] + from ray.util.sgd.pytorch.training_operator import TrainingOperator + + __all__ = ["PyTorchTrainer", "PyTorchTrainable", "TrainingOperator"] except ImportError: logger.warning("PyTorch not found. PyTorchTrainer will not be available") diff --git a/python/ray/util/sgd/pytorch/constants.py b/python/ray/util/sgd/pytorch/constants.py new file mode 100644 index 000000000..0d7421a42 --- /dev/null +++ b/python/ray/util/sgd/pytorch/constants.py @@ -0,0 +1,7 @@ +USE_FP16 = "__use_fp16__" +BATCH_COUNT = "batch_count" +SCHEDULER_STEP = "scheduler_step" +SCHEDULER_STEP_BATCH = "batch" +SCHEDULER_STEP_EPOCH = "epoch" + +VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH} diff --git a/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py b/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py index 8db45dfe4..f3c8a8180 100644 --- a/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py +++ b/python/ray/util/sgd/pytorch/distributed_pytorch_runner.py @@ -95,15 +95,22 @@ class DistributedPyTorchRunner(PyTorchRunner): self.validation_loader = torch.utils.data.DataLoader( val_set, batch_size=self.batch_size, **self.dataloader_config) - def step(self): + self.training_operator = self.training_operator_cls( + self.config, + models=self.models, + optimizers=self.optimizers, + criterion=self.criterion, + schedulers=self.schedulers, + use_fp16=self.use_fp16) + + def train_epoch(self, **kwargs): """Runs a training epoch and updates the model parameters. Automatically sets epoch of sampler if possible. """ - logger.debug("Starting step") if hasattr(self.train_loader.sampler, "set_epoch"): - self.train_loader.sampler.set_epoch(self.epoch) - return super(DistributedPyTorchRunner, self).step() + self.train_loader.sampler.set_epoch(self.epochs) + return super(DistributedPyTorchRunner, self).train_epoch(**kwargs) def _get_model_state_dicts(self): """Fetch state from ``model.module`` instead of ``model``. diff --git a/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py index e039ec839..2007dfc80 100644 --- a/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/pytorch/examples/cifar_pytorch_example.py @@ -10,10 +10,9 @@ import torchvision.transforms as transforms import ray from ray.util.sgd.pytorch import (PyTorchTrainer, PyTorchTrainable) from ray.util.sgd.pytorch.resnet import ResNet18 -from ray.util.sgd.pytorch.utils import TEST_MODE -def initialization_hook(runner): +def initialization_hook(): print("NCCL DEBUG SET") # Need this for avoiding a connection restart issue os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" @@ -40,6 +39,11 @@ def cifar_creator(config): validation_dataset = torchvision.datasets.CIFAR10( root="~/data", train=False, download=False, transform=transform_test) + if config.get("test_mode"): + train_dataset = torch.utils.data.Subset(train_dataset, list(range(64))) + validation_dataset = torch.utils.data.Subset(validation_dataset, + list(range(64))) + return train_dataset, validation_dataset @@ -58,7 +62,6 @@ def train_example(num_replicas=1, use_gpu=False, use_fp16=False, test_mode=False): - config = {TEST_MODE: test_mode} trainer1 = PyTorchTrainer( ResNet18, cifar_creator, @@ -67,7 +70,10 @@ def train_example(num_replicas=1, scheduler_creator=scheduler_creator, initialization_hook=initialization_hook, num_replicas=num_replicas, - config=config, + config={ + "lr": 0.01, + "test_mode": test_mode + }, use_gpu=use_gpu, batch_size=16 if test_mode else 512, backend="nccl" if use_gpu else "gloo", @@ -88,14 +94,14 @@ def tune_example(num_replicas=1, use_gpu=False, test_mode=False): "model_creator": ResNet18, "data_creator": cifar_creator, "optimizer_creator": optimizer_creator, - "loss_creator": lambda config: nn.CrossEntropyLoss(), + "loss_creator": nn.CrossEntropyLoss, "num_replicas": num_replicas, "initialization_hook": initialization_hook, "use_gpu": use_gpu, "batch_size": 16 if test_mode else 512, "config": { - "lr": tune.choice([1e-4, 1e-3, 5e-3, 1e-2]), - TEST_MODE: test_mode + "lr": tune.choice([1e-4, 1e-3]), + "test_mode": test_mode }, "backend": "nccl" if use_gpu else "gloo" } diff --git a/python/ray/util/sgd/pytorch/examples/dcgan.py b/python/ray/util/sgd/pytorch/examples/dcgan.py index 110697d10..ad627f533 100644 --- a/python/ray/util/sgd/pytorch/examples/dcgan.py +++ b/python/ray/util/sgd/pytorch/examples/dcgan.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.optim as optim import torch.utils.data -import torchvision.datasets as dset +import torchvision.datasets as datasets import torchvision.transforms as transforms import numpy as np @@ -16,25 +16,12 @@ from scipy.stats import entropy import ray from ray.util.sgd import PyTorchTrainer -from ray.util.sgd.pytorch.utils import TEST_MODE - -# Training parameters -TRAIN_BATCHES = 5 -# Number of channels in the training images. For color images this is 3 -num_channels = 1 - -# Size of z latent vector (i.e. size of generator input) -latent_vector_size = 100 - -# Size of feature maps in generator -features_g = 32 - -# Size of feature maps in discriminator -features_d = 32 +from ray.util.sgd.utils import override +from ray.util.sgd.pytorch import TrainingOperator def data_creator(config): - return dset.MNIST( + dataset = datasets.MNIST( root="~/mnist/", download=True, transform=transforms.Compose([ @@ -42,62 +29,56 @@ def data_creator(config): transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, )), ])) - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) + if config.get("test_mode"): + dataset = torch.utils.data.Subset(dataset, list(range(64))) + return dataset class Generator(nn.Module): - def __init__(self): + def __init__(self, latent_vector_size, features=32, num_channels=1): super(Generator, self).__init__() + self.latent_vector_size = latent_vector_size self.main = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d( - latent_vector_size, features_g * 4, 4, 1, 0, bias=False), - nn.BatchNorm2d(features_g * 4), + latent_vector_size, features * 4, 4, 1, 0, bias=False), + nn.BatchNorm2d(features * 4), nn.ReLU(True), nn.ConvTranspose2d( - features_g * 4, features_g * 2, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_g * 2), + features * 4, features * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(features * 2), nn.ReLU(True), - nn.ConvTranspose2d( - features_g * 2, features_g, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_g), + nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False), + nn.BatchNorm2d(features), nn.ReLU(True), - nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False), + nn.ConvTranspose2d(features, num_channels, 4, 2, 1, bias=False), nn.Tanh()) - def forward(self, input): - return self.main(input) + def forward(self, x): + return self.main(x) class Discriminator(nn.Module): - def __init__(self): + def __init__(self, features=32, num_channels=1): super(Discriminator, self).__init__() self.main = nn.Sequential( - nn.Conv2d(num_channels, features_d, 4, 2, 1, bias=False), + nn.Conv2d(num_channels, features, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_d * 2), nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False), - nn.BatchNorm2d(features_d * 4), nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(features_d * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid()) + nn.Conv2d(features, features * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(features * 2), nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(features * 4), nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid()) - def forward(self, input): - return self.main(input) + def forward(self, x): + return self.main(x) -class Net(nn.Module): +class LeNet(nn.Module): """LeNet for MNist classification, used for inception_score.""" def __init__(self): - super(Net, self).__init__() + super(LeNet, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() @@ -114,92 +95,22 @@ class Net(nn.Module): return F.log_softmax(x, dim=1) -def inception_score(imgs, batch_size=32, splits=1): - N = len(imgs) - dtype = torch.FloatTensor - dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) - cm = Net() - cm.load_state_dict(torch.load(model_path)) - cm.eval() - up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype) - - def get_pred(x): - x = up(x) - x = cm(x) - return F.softmax(x).data.cpu().numpy() - - preds = np.zeros((N, 10)) - for i, batch in enumerate(dataloader, 0): - batch = batch.type(dtype) - batchv = Variable(batch) - batch_size_i = batch.size()[0] - preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv) - - # Now compute the mean kl-div - split_scores = [] - for k in range(splits): - part = preds[k * (N // splits):(k + 1) * (N // splits), :] - py = np.mean(part, axis=0) - scores = [] - for i in range(part.shape[0]): - pyx = part[i, :] - scores.append(entropy(pyx, py)) - split_scores.append(np.exp(np.mean(scores))) - - return np.mean(split_scores), np.std(split_scores) - - def model_creator(config): - netD = Discriminator() - netD.apply(weights_init) + def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) - netG = Generator() - netG.apply(weights_init) - return netD, netG + discriminator = Discriminator() + discriminator.apply(weights_init) - -def train(config, models, dataloader, criterion, optimizers, **kwargs): - netD, netG = models - optimD, optimG = optimizers - real_label = 1 - fake_label = 0 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - for i, data in enumerate(dataloader, 0): - if i >= TRAIN_BATCHES and config.get(TEST_MODE): - break - - netD.zero_grad() - real_cpu = data[0].to(device) - b_size = real_cpu.size(0) - label = torch.full((b_size, ), real_label, device=device) - output = netD(real_cpu).view(-1) - errD_real = criterion(output, label) - errD_real.backward() - - noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device) - fake = netG(noise) - label.fill_(fake_label) - output = netD(fake.detach()).view(-1) - errD_fake = criterion(output, label) - errD_fake.backward() - errD = errD_real + errD_fake - optimD.step() - - netG.zero_grad() - label.fill_(real_label) - output = netD(fake).view(-1) - errG = criterion(output, label) - errG.backward() - optimG.step() - - is_score, is_std = inception_score(fake) - - return { - "loss_g": errG.item(), - "loss_d": errD.item(), - "inception": is_score - } + generator = Generator( + latent_vector_size=config.get("latent_vector_size", 100)) + generator.apply(weights_init) + return discriminator, generator def optimizer_creator(models, config): @@ -211,22 +122,122 @@ def optimizer_creator(models, config): return discriminator_opt, generator_opt +class GANOperator(TrainingOperator): + def setup(self, config): + self.device = torch.device("cuda" + if torch.cuda.is_available() else "cpu") + + self.classifier = LeNet() + self.classifier.load_state_dict( + torch.load(config["classification_model_path"])) + self.classifier.eval() + + def inception_score(self, imgs, batch_size=32, splits=1): + """Calculate the inception score of the generated images.""" + N = len(imgs) + dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) + up = nn.Upsample( + size=(28, 28), mode="bilinear").type(torch.FloatTensor) + + def get_pred(x): + x = up(x) + x = self.classifier(x) + return F.softmax(x).data.cpu().numpy() + + # Obtain predictions for the fake provided images + preds = np.zeros((N, 10)) + for i, batch in enumerate(dataloader, 0): + batch = batch.type(torch.FloatTensor) + batchv = Variable(batch) + batch_size_i = batch.size()[0] + preds[i * batch_size:i * batch_size + + batch_size_i] = get_pred(batchv) + + # Now compute the mean kl-div + split_scores = [] + for k in range(splits): + part = preds[k * (N // splits):(k + 1) * (N // splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + + return np.mean(split_scores), np.std(split_scores) + + @override(TrainingOperator) + def train_batch(self, batch, batch_info): + """Trains on one batch of data from the data creator.""" + real_label = 1 + fake_label = 0 + discriminator, generator = self.models + optimD, optimG = self.optimizers + + # Compute a discriminator update for real images + discriminator.zero_grad() + real_cpu = batch[0].to(self.device) + batch_size = real_cpu.size(0) + label = torch.full((batch_size, ), real_label, device=self.device) + output = discriminator(real_cpu).view(-1) + errD_real = self.criterion(output, label) + errD_real.backward() + + # Compute a discriminator update for fake images + noise = torch.randn( + batch_size, + self.config.get("latent_vector_size", 100), + 1, + 1, + device=self.device) + fake = generator(noise) + label.fill_(fake_label) + output = discriminator(fake.detach()).view(-1) + errD_fake = self.criterion(output, label) + errD_fake.backward() + errD = errD_real + errD_fake + + # Update the discriminator + optimD.step() + + # Update the generator + generator.zero_grad() + label.fill_(real_label) + output = discriminator(fake).view(-1) + errG = self.criterion(output, label) + errG.backward() + optimG.step() + + is_score, is_std = self.inception_score(fake) + + return { + "loss_g": errG.item(), + "loss_d": errD.item(), + "inception": is_score, + "num_samples": batch_size + } + + def train_example(num_replicas=1, use_gpu=False, test_mode=False): - config = {TEST_MODE: test_mode} + config = { + "test_mode": test_mode, + "classification_model_path": os.path.join( + os.path.dirname(ray.__file__), + "util/sgd/pytorch/examples/mnist_cnn.pt") + } trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, nn.BCELoss, - train_function=train, - validation_function=False, + training_operator_cls=GANOperator, num_replicas=num_replicas, config=config, use_gpu=use_gpu, batch_size=16 if test_mode else 512, backend="nccl" if use_gpu else "gloo") - for i in range(10): - stats = trainer.train(max_retries=3) + for i in range(5): + stats = trainer.train() print(stats) return trainer @@ -240,7 +251,7 @@ if __name__ == "__main__": "--address", required=False, type=str, - help="the address to use for Redis") + help="the address to use to connect to a cluster.") parser.add_argument( "--num-replicas", "-n", @@ -255,10 +266,6 @@ if __name__ == "__main__": args, _ = parser.parse_known_args() ray.init(address=args.address) - path = os.path.dirname(ray.__file__) - model_path = os.path.join(path, "util/sgd/pytorch/examples/mnist_cnn.pt") - # load the pretrained mnist classification model for inception_score - trainer = train_example( num_replicas=args.num_replicas, use_gpu=args.use_gpu, diff --git a/python/ray/util/sgd/pytorch/pytorch_runner.py b/python/ray/util/sgd/pytorch/pytorch_runner.py index 7c8803896..0be7ba9db 100644 --- a/python/ray/util/sgd/pytorch/pytorch_runner.py +++ b/python/ray/util/sgd/pytorch/pytorch_runner.py @@ -2,13 +2,15 @@ import collections from filelock import FileLock import logging import inspect +import itertools import os import torch import torch.utils.data from torch.utils.data import Dataset import ray -from ray.util.sgd.pytorch import utils as pytorch_utils +from ray.util.sgd.pytorch.constants import USE_FP16, SCHEDULER_STEP +from ray.util.sgd.pytorch.training_operator import TrainingOperator from ray.util.sgd import utils logger = logging.getLogger(__name__) @@ -31,8 +33,7 @@ class PyTorchRunner: loss_creator (dict -> loss | Loss class): see pytorch_trainer.py. scheduler_creator (optimizers, dict -> schedulers): see pytorch_trainer.py. - train_function: see pytorch_trainer.py - validation_function: see pytorch_trainer.py + training_operator_cls: see pytorch_trainer.py config (dict): see pytorch_trainer.py. dataloader_config (dict): See pytorch_trainer.py. batch_size (int): see pytorch_trainer.py. @@ -47,8 +48,7 @@ class PyTorchRunner: optimizer_creator, loss_creator, scheduler_creator=None, - train_function=None, - validation_function=None, + training_operator_cls=None, config=None, dataloader_config=None, batch_size=16, @@ -60,17 +60,15 @@ class PyTorchRunner: self.optimizer_creator = optimizer_creator self.loss_creator = loss_creator self.scheduler_creator = scheduler_creator + self.training_operator_cls = training_operator_cls or TrainingOperator self.config = {} if config is None else config self.dataloader_config = { "num_workers": 2 } if dataloader_config is None else dataloader_config - self.train_function = train_function or pytorch_utils.train - self.validation_function = (validation_function - or pytorch_utils.validate) self.batch_size = batch_size self.verbose = True - self.epoch = 0 + self.epochs = 0 self._timers = { k: utils.TimerStat(window_size=1) for k in [ @@ -160,6 +158,14 @@ class PyTorchRunner: self.validation_loader = torch.utils.data.DataLoader( val_set, batch_size=self.batch_size, **self.dataloader_config) + self.training_operator = self.training_operator_cls( + self.config, + models=self.models, + optimizers=self.optimizers, + criterion=self.criterion, + schedulers=self.schedulers, + use_fp16=self.use_fp16) + def get_node_ip(self): """Returns the IP address of the current node.""" return ray.services.get_node_ip_address() @@ -168,47 +174,42 @@ class PyTorchRunner: """Finds a free port on the current node.""" return utils.find_free_port() - def step(self): + def train_epoch(self, num_steps=None, info=None): """Runs a training epoch and updates the model parameters.""" - logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) - training_config = self.config.copy() - training_config.update({ - pytorch_utils.USE_FP16: self.use_fp16, - pytorch_utils.SCHEDULER_STEP: self.scheduler_step_freq + logger.debug("Begin Training Step {}".format(self.epochs + 1)) + info = info or {} + info.update({ + USE_FP16: self.use_fp16, + SCHEDULER_STEP: self.scheduler_step_freq }) with self._timers["training"]: - train_stats = self.train_function( - training_config, - self.given_models, - self.train_loader, - self.criterion, - self.given_optimizers, - scheduler=self.given_schedulers) - train_stats["epoch"] = self.epoch - - self.epoch += 1 + iterator = self.train_loader + if num_steps: + iterator = itertools.islice(iter(self.train_loader), num_steps) + train_stats = self.training_operator.train_epoch(iterator, info) + self.epochs += 1 train_stats.update(self.stats()) return train_stats - def validate(self): + def validate(self, num_steps=None, info=None): """Evaluates the model on the validation data set.""" if self.validation_loader is None: raise ValueError("No validation dataloader provided.") + info = info or {} with self._timers["validation"]: - validation_stats = self.validation_function( - self.config, - self.given_models, - self.validation_loader, - self.criterion, - scheduler=self.given_schedulers) + iterator = self.validation_loader + if num_steps: + iterator = itertools.islice( + iter(self.validation_loader), num_steps) + validation_stats = self.training_operator.validate(iterator, info) validation_stats.update(self.stats()) return validation_stats def stats(self): """Returns a dictionary of statistics collected.""" - stats = {"epoch": self.epoch} + stats = {"epoch": self.epochs} for k, t in self._timers.items(): stats[k + "_time_mean"] = t.mean stats[k + "_time_total"] = t.sum @@ -233,7 +234,8 @@ class PyTorchRunner: """Returns the state of the runner.""" state = { - "epoch": self.epoch, + "epoch": self.epochs, + "operator": self.training_operator.state_dict(), "models": self._get_model_state_dicts(), "optimizers": [opt.state_dict() for opt in self.optimizers], "stats": self.stats() @@ -262,13 +264,18 @@ class PyTorchRunner: if self.use_fp16 and "amp" in state and amp: amp.load_state_dict(state["amp"]) - self.epoch = state["stats"]["epoch"] + self.epochs = state["stats"]["epoch"] + self.training_operator.load_state_dict(state_dict) - def apply_fn(self, fn): - return fn(self) + def apply(self, fn): + return fn() + + def apply_operator(self, fn): + return fn(self.training_operator) def shutdown(self): """Attempts to shut down the worker.""" + del self.training_operator del self.validation_loader del self.train_loader del self.criterion diff --git a/python/ray/util/sgd/pytorch/pytorch_trainer.py b/python/ray/util/sgd/pytorch/pytorch_trainer.py index 6eec29ba3..0316cd11a 100644 --- a/python/ray/util/sgd/pytorch/pytorch_trainer.py +++ b/python/ray/util/sgd/pytorch/pytorch_trainer.py @@ -15,12 +15,20 @@ from ray.util.sgd.pytorch.distributed_pytorch_runner import ( DistributedPyTorchRunner) from ray.util.sgd import utils from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner -from ray.util.sgd.pytorch import utils as pytorch_utils +from ray.util.sgd.pytorch.constants import VALID_SCHEDULER_STEP logger = logging.getLogger(__name__) RESIZE_COOLDOWN_S = 10 +def _validate_scheduler_step_freq(scheduler_step_freq): + if scheduler_step_freq: + if scheduler_step_freq not in VALID_SCHEDULER_STEP: + raise ValueError( + "Scheduler step freq must be in {}. Got {}".format( + VALID_SCHEDULER_STEP, scheduler_step_freq)) + + class PyTorchTrainer: """Train a PyTorch model using distributed PyTorch. @@ -48,14 +56,15 @@ class PyTorchTrainer: loss_creator=nn.MSELoss, use_gpu=True ) - trainer.train() + for i in range(4): + trainer.train() Args: model_creator (dict -> Model(s)): Constructor function that takes in config and returns the model(s) to be optimized. These must be ``torch.nn.Module`` objects. If multiple models are returned, - a ``train_function`` must be specified. You do not need to + a ``training_operator_cls`` must be specified. You do not need to handle GPU/devices in this function; RaySGD will do that under the hood. data_creator (dict -> Dataset(s)): Constructor function @@ -75,22 +84,18 @@ class PyTorchTrainer: of ``torch.nn.modules.loss._Loss``, which is most Pytorch loss classes. For example, ``loss_creator=torch.nn.BCELoss``. scheduler_creator (optimizers, dict -> loss): - A constructor function for the scheduler loss. This is + A constructor function for the torch scheduler. This is a function that takes in the generated optimizers (from ``optimizer_creator``) provided config for customization. Be sure to set ``scheduler_step_freq`` to increment the scheduler correctly. - train_function: Custom function for training. This function - will be executed in parallel across all workers at once. The - function needs to take in (models, train_dataloader, criterion, - optimizers, config), and return a dict of training stats. - validation_function: Custom function for validation. This function - will be executed in parallel across all workers at once. - This takes in (model, val_dataloader, criterion, config) - and returns a dict of validation stats. + training_operator_cls (type): Custom training operator class + that subclasses the TrainingOperator class. This class + will be copied onto all remote workers and used to specify + custom training and validation operations. Defaults to + TrainingOperator. config (dict): Custom configuration value to be passed to - "model_creator", "data_creator", "optimizer_creator", and - "loss_creator". + all creator and operator constructors. dataloader_config (dict): Configuration values to be passed into the ``torch.utils.data.DataLoader`` object that wraps the dataset on each parallel worker for both training @@ -130,8 +135,7 @@ class PyTorchTrainer: optimizer_creator, loss_creator, scheduler_creator=None, - train_function=None, - validation_function=None, + training_operator_cls=None, initialization_hook=None, config=None, dataloader_config=None, @@ -151,11 +155,10 @@ class PyTorchTrainer: self.model_creator = model_creator self.data_creator = data_creator - self.train_function = train_function self.optimizer_creator = optimizer_creator self.loss_creator = loss_creator self.scheduler_creator = scheduler_creator - self.validation_function = validation_function + self.training_operator_cls = training_operator_cls self.initialization_hook = initialization_hook self.config = {} if config is None else config self.dataloader_config = dataloader_config @@ -166,6 +169,8 @@ class PyTorchTrainer: logger.info("Using {} as backend.".format(backend)) self.backend = backend + + # TODO: Have an auto "use_gpu" option to detect and use GPUs. self.use_gpu = use_gpu self.batch_size = batch_size self.max_replicas = num_replicas @@ -180,12 +185,7 @@ class PyTorchTrainer: self._num_failures = 0 self._last_resize = float("-inf") - if scheduler_step_freq and ( - scheduler_step_freq not in pytorch_utils.VALID_SCHEDULER_STEP): - raise ValueError( - "Scheduler step freq must be in {}. Got {}".format( - pytorch_utils.VALID_SCHEDULER_STEP, scheduler_step_freq)) - + _validate_scheduler_step_freq(scheduler_step_freq) self.scheduler_step_freq = scheduler_step_freq self._start_workers(self.max_replicas) @@ -204,8 +204,7 @@ class PyTorchTrainer: self.optimizer_creator, self.loss_creator, self.scheduler_creator, - train_function=self.train_function, - validation_function=self.validation_function, + training_operator_cls=self.training_operator_cls, config=self.config, dataloader_config=self.dataloader_config, batch_size=self.batch_size, @@ -243,8 +242,7 @@ class PyTorchTrainer: self.loss_creator, self.scheduler_creator, backend=self.backend, - train_function=self.train_function, - validation_function=self.validation_function, + training_operator_cls=self.training_operator_cls, config=self.config, dataloader_config=self.dataloader_config, batch_size=batch_size_per_replica, @@ -266,21 +264,35 @@ class PyTorchTrainer: for i, worker in enumerate(self.workers) ]) - def train(self, max_retries=0, checkpoint="auto"): + def train(self, + num_steps=None, + max_retries=0, + checkpoint="auto", + info=None): """Runs a training epoch. Runs an average over all values returned from workers. Set `max_retries` to enable fault handling in case of instance preemption. Args: + num_steps (int): Number of batches to compute update steps on. + This corresponds also to the number of times + ``TrainingOperator.train_batch`` is called. max_retries (int): Must be non-negative. If set to N, will kill all current workers, query the Ray global state for total available resources, and re-launch up to the available resources. Behavior is not well-defined in case of shared cluster usage. checkpoint (str): Path to checkpoint to restore from if retrying. - If max_retries is set and checkpoint == "auto", PyTorchTrainer - will save a checkpoint before starting to train. + If max_retries is set and ``checkpoint == "auto"``, + PyTorchTrainer will save a checkpoint before starting to train. + info (dict): Optional dictionary passed to the training + operator for ``train_epoch`` and ``train_batch``. + + Returns: + A dictionary of metrics for training. + You can provide custom metrics by passing in a custom + ``training_operator_cls``. """ assert max_retries >= 0, "`max_retries` must be non-negative." if max_retries: @@ -296,7 +308,8 @@ class PyTorchTrainer: self._resize_workers(checkpoint=checkpoint) with self.optimizer_timer: - success, worker_stats = self._train_step() + success, worker_stats = self._train_epoch( + num_steps=num_steps, info=info) # Fault handling for i in range(max_retries): if success: @@ -306,7 +319,8 @@ class PyTorchTrainer: self._resize_workers(checkpoint=checkpoint) logger.info("Retrying training step with %d workers." % len( self.workers)) - success, worker_stats = self._train_step() + success, worker_stats = self._train_epoch( + num_steps=num_steps, info=info) if not success: raise RuntimeError("Training run failed.") @@ -321,19 +335,58 @@ class PyTorchTrainer: train_stats[stat_key] = worker_stats[0][stat_key] return train_stats - def _train_step(self): - worker_stats = [w.step.remote() for w in self.workers] + def _train_epoch(self, num_steps=None, info=None): + worker_stats = [ + w.train_epoch.remote(num_steps=num_steps, info=info) + for w in self.workers + ] success = utils.check_for_failure(worker_stats) return success, worker_stats def apply_all_workers(self, fn): - return ray.get([w.apply_fn.remote(fn) for w in self.workers]) + """Run a function on all operators on the workers. - def validate(self): - """Evaluates the model on the validation data set.""" - if self.validation_function is False: - return {} - worker_stats = ray.get([w.validate.remote() for w in self.workers]) + Args: + fn (Callable): A function that takes in no arguments. + + Returns: + A list of objects returned by ``fn`` on each worker. + + """ + return ray.get([w.apply.remote(fn) for w in self.workers]) + + def apply_all_operators(self, fn): + """Run a function on all operators on the workers. + + Args: + fn (Callable[TrainingOperator]): A function that takes in a + TrainingOperator. + + Returns: + A list of objects returned by ``fn`` on each operator. + + """ + return ray.get([w.apply_operator.remote(fn) for w in self.workers]) + + def validate(self, num_steps=None, info=None): + """Evaluates the model on the validation data set. + + Args: + num_steps (int): Number of batches to compute update steps on. + This corresponds also to the number of times + ``TrainingOperator.validate_batch`` is called. + info (dict): Optional dictionary passed to the training + operator for `validate` and `validate_batch`. + + Returns: + A dictionary of metrics for validation. + You can provide custom metrics by passing in a custom + ``training_operator_cls``. + """ + worker_stats = ray.get([ + w.validate.remote(num_steps=num_steps, info=info) + for w in self.workers + ]) validation_stats = {} for stat_key in worker_stats[0]: @@ -346,8 +399,8 @@ class PyTorchTrainer: This is useful for lr_schedulers such as ``ReduceLROnPlateau``. """ - self.apply_all_workers( - lambda runner: [sched.step(metric) for sched in runner.schedulers]) + self.apply_all_operators( + lambda op: [sched.step(metric) for sched in op.schedulers]) def get_model(self): """Returns the learned model(s).""" @@ -366,17 +419,18 @@ class PyTorchTrainer: Args: checkpoint (str): Path to target checkpoint file. + Returns: + checkpoint (str): Path to target checkpoint file. """ state = ray.get(self.workers[0].get_state.remote()) torch.save(state, checkpoint) return checkpoint def restore(self, checkpoint): - """Restores the model from the provided checkpoint. + """Restores the Trainer and all workers from the provided checkpoint. Args: checkpoint (str): Path to target checkpoint file. - """ state = torch.load(checkpoint) state_id = ray.put(state) @@ -450,7 +504,6 @@ class PyTorchTrainable(Trainable): validation_stats = self._trainer.validate() train_stats.update(validation_stats) - # output {"mean_loss": test_loss, "mean_accuracy": accuracy} return train_stats diff --git a/python/ray/util/sgd/pytorch/training_operator.py b/python/ray/util/sgd/pytorch/training_operator.py new file mode 100644 index 000000000..9b4c5f090 --- /dev/null +++ b/python/ray/util/sgd/pytorch/training_operator.py @@ -0,0 +1,343 @@ +import collections +import torch + +from ray.util.sgd.utils import TimerStat, AverageMeter +from ray.util.sgd.pytorch.constants import ( + SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_BATCH, SCHEDULER_STEP, BATCH_COUNT) + +amp = None + +try: + from apex import amp +except ImportError: + # Apex library is not installed, so we cannot enable mixed precision. + # We don't log here because logging happens in the pytorch_runner, + # where amp is initialized. + pass + + +def _is_multiple(component): + """Checks if a component (optimizer, model, etc) is not singular.""" + return isinstance(component, collections.Iterable) and len(component) > 1 + + +class TrainingOperator: + """Abstract class for custom training or validation loops. + + The scheduler will only be called at a batch or epoch frequency, depending + on the user parameter. Be sure to set ``scheduler_step_freq`` in + ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler + correctly during training. If using a learning rate scheduler + that depends on validation loss, you can use ``trainer.update_scheduler``. + + For both training and validation, there are two granularities that + you can provide customization: per epoch or per batch. + You do not need to override both. + + .. image:: raysgd-custom.jpg + :scale: 80% + :align: center + + Raises: + ValueError if multiple models/optimizers/schedulers are provided. + You are expected to subclass this class if you wish + to train over multiple models/optimizers/schedulers. + """ + + def __init__(self, + config, + models, + optimizers, + criterion, + schedulers=None, + use_fp16=False): + # You are not expected to override this method. + self.timers = { + k: TimerStat() + for k in ["fwd", "grad", "apply", "epoch_time"] + } + self._validated_customization = False + self._models = models # List of models + assert isinstance(models, collections.Iterable), ( + "Components need to be iterable. Got: {}".format(type(models))) + self._optimizers = optimizers # List of optimizers + assert isinstance(optimizers, collections.Iterable), ( + "Components need to be iterable. Got: {}".format(type(optimizers))) + self._criterion = criterion + self._schedulers = schedulers + if schedulers: + assert isinstance(schedulers, collections.Iterable), ( + "Components need to be iterable. Got: {}".format( + type(schedulers))) + self._config = config + self._use_fp16 = use_fp16 + self.global_step = 0 + + if type(self) is TrainingOperator: + for component in (models, schedulers, optimizers): + if _is_multiple(component): + raise ValueError( + "Need to provide a custom operator subclassing " + "TrainingOperator if using multi-scheduler, " + "multi-model or multi-optimizer training/validation.") + + self.setup(config) + + def setup(self, config): + """Override this method to implement custom operator setup. + + Args: + config (dict): Custom configuration value to be passed to + all creator and operator constructors. Same as ``self.config``. + """ + pass + + def train_epoch(self, iterator, info): + """Runs one standard training pass over the train_iterator. + + By default, this method will iterate over the given iterator and + call ``self.train_batch`` over each batch. + + If ``scheduler_step_freq`` is set, this class will also step the + scheduler accordingly. + + You do not need to call ``train_batch`` in this method if you plan + to implement a custom optimization/training routine here. + + Args: + iterator (iter): Iterator over the training data for the entire + epoch. This iterator is expected to be entirely consumed. + info (dict): Dictionary for information to be used for custom + training operations. + + Returns: + A dict of metrics from training. + """ + self._losses = AverageMeter() + + self.model.train() + with self.timers["epoch_time"]: + for batch_idx, batch in enumerate(iterator): + batch_info = { + "batch_idx": batch_idx, + "global_step": self.global_step + } + batch_info.update(info) + metrics = self.train_batch(batch, batch_info=batch_info) + + if self.scheduler and batch_info.get( + SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: + self.scheduler.step() + + if "loss" in metrics: + self._losses.update( + metrics["loss"], n=metrics.get("num_samples", 1)) + self.global_step += 1 + + if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: + self.scheduler.step() + + stats = { + BATCH_COUNT: batch_idx + 1, + "mean_train_loss": self._losses.avg, + "last_train_loss": self._losses.val, + "epoch_time": self.timers["epoch_time"].last + } + stats.update({ + timer_tag: timer.mean + for timer_tag, timer in self.timers.items() + }) + return stats + + def train_batch(self, batch, batch_info): + """Computes loss and updates the model over one batch. + + This method is responsible for computing the loss and gradient and + updating the model. + + By default, this method implementation assumes that batches + are in (features, labels) format. If using amp/fp16 + training, it will also scale the loss automatically. + + You can provide custom loss metrics and training operations if you + override this method. If overriding this method, you can access model, + optimizer, criterion via ``self.model``, ``self.optimizer``, + and ``self.criterion``. + + You do not need to override this method if you plan to + override ``train_epoch``. + + Args: + batch: One item of the validation iterator. + batch_info (dict): Information dict passed in from ``train_epoch``. + + Returns: + A dictionary of metrics. + By default, this dictionary contains "loss" and "num_samples". + "num_samples" corresponds to number of datapoints in the batch. + However, you can provide any number of other values. + + """ + features, target = batch + # Create non_blocking tensors for distributed training + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # Compute output. + with self.timers["fwd"]: + output = self.model(features) + loss = self.criterion(output, target) + + # Compute gradients in a backward pass. + with self.timers["grad"]: + self.optimizer.zero_grad() + if self.use_fp16: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + # Call step of optimizer to update model params. + with self.timers["apply"]: + self.optimizer.step() + return {"loss": loss.item(), "num_samples": features.size(0)} + + def validate(self, val_iterator, info): + """Runs one standard validation pass over the val_iterator. + + This will call ``model.eval()`` and ``torch.no_grad`` when iterating + over the validation dataset. + + If overriding this method, you can access model, criterion via + ``self.model`` and ``self.criterion``. You also do not need to call + ``validate_batch`` if overriding this method. + + Args: + val_iterator (iter): Iterable constructed over the + validation dataset. + info: (dict): Dictionary for information to be used for custom + validation operations. + + Returns: + A dict of metrics from the evaluation. + By default, returns "mean_accuracy" and "mean_validation_loss" + which is computed by aggregating "loss" and "correct" values + from ``validate_batch`` and dividing it by the sum of + ``num_samples`` from all calls to ``self.validate_batch``. + """ + losses = AverageMeter() + total_correct = 0 + + # switch to evaluate mode + self.model.eval() + with torch.no_grad(): + for batch_idx, batch in enumerate(val_iterator): + batch_info = {"batch_idx": batch_idx} + batch_info.update(info) + metrics = self.validate_batch(batch, batch_info) + if "loss" in metrics: + losses.update( + metrics["loss"], n=metrics.get("num_samples", 1)) + + if "num_correct" in metrics: + total_correct += metrics["num_correct"] + + stats = { + "batch_count": batch_idx + 1, + "mean_validation_loss": losses.avg, + "mean_accuracy": total_correct / losses.count + } + return stats + + def validate_batch(self, batch, batch_info): + """Calcuates the loss and accuracy over a given batch. + + You can override this method to provide arbitrary metrics. + + Args: + batch: One item of the validation iterator. + batch_info (dict): Contains information per batch from + ``validate()``. + + Returns: + A dict of metrics. + By default, returns "loss", "num_correct", and "num_samples". + """ + features, target = batch + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = self.model(features) + loss = self.criterion(output, target) + _, predicted = torch.max(output.data, 1) + + return { + "loss": loss.item(), + "num_correct": (predicted == target).sum().item(), + "num_samples": target.size(0) + } + + def state_dict(self): + """Returns a serializable representation of the operator state.""" + pass + + def load_state_dict(self, state_dict): + """Loads a serializable representation of the operator state.""" + pass + + @property + def config(self): + """Dictionary as provided into PyTorchTrainer.""" + return self._config + + @property + def model(self): + """First or only model created by the provided ``model_creator``.""" + return self._models[0] + + @property + def models(self): + """List of models created by the provided ``model_creator``.""" + return self._models + + @property + def optimizer(self): + """First or only optimizer(s) created by the ``optimizer_creator``.""" + return self._optimizers[0] + + @property + def optimizers(self): + """List of optimizers created by the ``optimizer_creator``.""" + return self._optimizers + + @property + def criterion(self): + """Criterion created by the provided ``loss_creator``.""" + return self._criterion + + @property + def scheduler(self): + """First or only scheduler(s) created by the ``scheduler_creator``.""" + if self._schedulers: + return self._schedulers[0] + + @property + def schedulers(self): + """List of schedulers created by the ``scheduler_creator``.""" + return self._schedulers + + @property + def use_fp16(self): + """Whether the model and optimizer have been FP16 enabled.""" + return self._use_fp16 + + +class _TestingOperator(TrainingOperator): + def train_epoch(self, iterator, info): + func = self.config.get("custom_func") + if callable(func): + return func(self, iterator, info) + return {"done": 1} diff --git a/python/ray/util/sgd/pytorch/utils.py b/python/ray/util/sgd/pytorch/utils.py deleted file mode 100644 index a28407983..000000000 --- a/python/ray/util/sgd/pytorch/utils.py +++ /dev/null @@ -1,229 +0,0 @@ -import collections -import time -import torch - -from ray.util.sgd.utils import TimerStat - -amp = None - -try: - from apex import amp -except ImportError: - # Apex library is not installed, so we cannot enable mixed precision. - # We don't log here because logging happens in the pytorch_runner, - # where amp is initialized. - pass - -USE_FP16 = "__use_fp16__" -TEST_MODE = "__test_mode__" -BATCH_COUNT = "batch_processed" -SCHEDULER_STEP = "scheduler_step" -SCHEDULER_STEP_BATCH = "batch" -SCHEDULER_STEP_EPOCH = "epoch" - -VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH} - - -def train(config, model, train_iterator, criterion, optimizer, scheduler=None): - """Runs one standard training pass over the train_iterator. - - This function automatically measures timing for various operations such - as host to device transfer, gradient calculation, and gradient application. - - It also automatically detects and places the data on the given GPU device - if available. - - The scheduler will only be called at a batch or epoch frequency, depending - on the user parameter. Be sure to set ``scheduler_step_freq`` in - ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler - correctly during training. If using a learning rate scheduler - that depends on validation loss, you can use ``trainer.update_scheduler``. - - Raises: - ValueError if multiple models/optimizers/schedulers are provided. You - are expected to have a custom training function if you wish - to use multiple models/optimizers/schedulers. - - Args: - config: (dict): A user configuration provided into the Trainer - constructor. - model: The model as created by the model_creator. - train_iterator: An iterator created from the DataLoader which - wraps the provided Dataset. - criterion: The loss object created by the loss_creator. - optimizer: The torch.optim.Optimizer object as created by the - optimizer_creator. - scheduler (optional): The torch.optim.lr_scheduler object - as created by the scheduler_creator. Be sure to set - ``scheduler_step_freq`` in ``PyTorchTrainer`` - to increment the scheduler correctly. - - Returns: - A dict of metrics from training. - """ - if isinstance(model, collections.Iterable) or isinstance( - optimizer, collections.Iterable) or isinstance( - scheduler, collections.Iterable): - raise ValueError( - "Need to provide custom training function if using multi-model " - "or multi-scheduler or multi-optimizer training.") - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - - timers = {k: TimerStat() for k in ["h2d", "fwd", "grad", "apply"]} - - # switch to train mode - model.train() - - end = time.time() - - for batch_idx, (features, target) in enumerate(train_iterator): - # measure data loading time - data_time.update(time.time() - end) - - # Create non_blocking tensors for distributed training - with timers["h2d"]: - if torch.cuda.is_available(): - features = features.cuda(non_blocking=True) - target = target.cuda(non_blocking=True) - - # compute output - with timers["fwd"]: - output = model(features) - loss = criterion(output, target) - - # measure accuracy and record loss - losses.update(loss.item(), features.size(0)) - - with timers["grad"]: - # compute gradients in a backward pass - optimizer.zero_grad() - - if config.get(USE_FP16): - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() - - with timers["apply"]: - # Call step of optimizer to update model params - optimizer.step() - - if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: - scheduler.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if config.get(TEST_MODE) and batch_idx == 0: - break - - if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: - scheduler.step() - - stats = { - "batch_time": batch_time.avg, - BATCH_COUNT: batch_idx + 1, - "train_loss": losses.avg, - "data_time": data_time.avg, - } - stats.update({k: t.mean for k, t in timers.items()}) - return stats - - -def validate(config, model, val_iterator, criterion, scheduler=None): - """Runs one standard validation pass over the val_iterator. - - This function automatically measures timing for various operations such - as host to device transfer and processing time for the batch. - - It also automatically detects and places the data on the given GPU device - if available. - - Raises: - ValueError if multiple models/schedulers are provided. You - are expected to have a custom validation function if you wish - to use multiple models/schedulers. - - Args: - config: (dict): A user configuration provided into the Trainer - constructor. - model: The model as created by the model_creator. - train_iterator: An iterator created from the DataLoader which - wraps the provided Dataset. - criterion: The loss object created by the loss_creator. - scheduler (optional): The torch.optim.lr_scheduler object - as created by the scheduler_creator. By default, - this is not used in this function. - - Returns: - A dict of metrics from the evaluation. - """ - - if isinstance(model, collections.Iterable) or isinstance( - scheduler, collections.Iterable): - raise ValueError( - "Need to provide custom validation function if using multi-model " - "or multi-scheduler training.") - batch_time = AverageMeter() - losses = AverageMeter() - - # switch to evaluate mode - model.eval() - correct = 0 - total = 0 - batch_idx = 0 - with torch.no_grad(): - end = time.time() - for batch_idx, (features, target) in enumerate(val_iterator): - if torch.cuda.is_available(): - features = features.cuda(non_blocking=True) - target = target.cuda(non_blocking=True) - - # compute output - output = model(features) - loss = criterion(output, target) - _, predicted = torch.max(output.data, 1) - total += target.size(0) - correct += (predicted == target).sum().item() - - # measure accuracy and record loss - losses.update(loss.item(), features.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if config.get(TEST_MODE) and batch_idx == 0: - break - - stats = { - BATCH_COUNT: batch_idx + 1, - "batch_time": batch_time.avg, - "validation_loss": losses.avg, - "mean_accuracy": correct / total, - "mean_loss": losses.sum / total, - } - return stats - - -class AverageMeter: - """Computes and stores the average and current value.""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count diff --git a/python/ray/util/sgd/tests/test_pytorch.py b/python/ray/util/sgd/tests/test_pytorch.py index 08eb28172..d70f2e8c8 100644 --- a/python/ray/util/sgd/tests/test_pytorch.py +++ b/python/ray/util/sgd/tests/test_pytorch.py @@ -10,28 +10,34 @@ import torch.distributed as dist import ray from ray import tune -from ray.tests.conftest import ray_start_2_cpus # noqa: F401 from ray.util.sgd.pytorch import PyTorchTrainer, PyTorchTrainable -from ray.util.sgd.pytorch.utils import (train, BATCH_COUNT, TEST_MODE, - SCHEDULER_STEP) +from ray.util.sgd.pytorch.training_operator import _TestingOperator +from ray.util.sgd.pytorch.constants import BATCH_COUNT, SCHEDULER_STEP from ray.util.sgd.utils import check_for_failure from ray.util.sgd.pytorch.examples.train_example import ( model_creator, optimizer_creator, data_creator, LinearDataset) -def test_test_mode(ray_start_2_cpus): # noqa: F811 +@pytest.fixture +def ray_start_2_cpus(): + address_info = ray.init(num_cpus=2) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +def test_single_step(ray_start_2_cpus): # noqa: F811 trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - config={TEST_MODE: True}, num_replicas=1) - metrics = trainer.train() + metrics = trainer.train(num_steps=1) assert metrics[BATCH_COUNT] == 1 - val_metrics = trainer.validate() + val_metrics = trainer.validate(num_steps=1) assert val_metrics[BATCH_COUNT] == 1 @@ -45,29 +51,51 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811 loss_creator=lambda config: nn.MSELoss(), num_replicas=num_replicas) for i in range(3): - train_loss1 = trainer.train()["train_loss"] - validation_loss1 = trainer.validate()["validation_loss"] + train_loss1 = trainer.train()["mean_train_loss"] + validation_loss1 = trainer.validate()["mean_validation_loss"] for i in range(3): - train_loss2 = trainer.train()["train_loss"] - validation_loss2 = trainer.validate()["validation_loss"] + train_loss2 = trainer.train()["mean_train_loss"] + validation_loss2 = trainer.validate()["mean_validation_loss"] - print(train_loss1, train_loss2) - print(validation_loss1, validation_loss2) - - assert train_loss2 <= train_loss1 - assert validation_loss2 <= validation_loss1 + assert train_loss2 <= train_loss1, (train_loss2, train_loss1) + assert validation_loss2 <= validation_loss1, (validation_loss2, + validation_loss1) @pytest.mark.parametrize("num_replicas", [1, 2] if dist.is_available() else [1]) -def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 - def custom_train(config, models, dataloader, criterion, optimizers, - **kwargs): +def test_multi_model(ray_start_2_cpus, num_replicas): + def train(*, model=None, criterion=None, optimizer=None, dataloader=None): + model.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(dataloader): + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + return { + "accuracy": correct / total, + "train_loss": train_loss / (batch_idx + 1) + } + + def train_epoch(self, iterator, info): result = {} - for i, (model, optimizer) in enumerate(zip(models, optimizers)): - result["model_{}".format(i)] = train(config, model, dataloader, - criterion, optimizer) + for i, (model, optimizer) in enumerate( + zip(self.models, self.optimizers)): + result["model_{}".format(i)] = train( + model=model, + criterion=self.criterion, + optimizer=optimizer, + dataloader=iterator) return result def multi_model_creator(config): @@ -84,7 +112,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 data_creator, multi_optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - train_function=custom_train, + config={"custom_func": train_epoch}, + training_operator_cls=_TestingOperator, num_replicas=num_replicas) trainer1.train() @@ -100,6 +129,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 data_creator, multi_optimizer_creator, loss_creator=lambda config: nn.MSELoss(), + config={"custom_func": train_epoch}, + training_operator_cls=_TestingOperator, num_replicas=num_replicas) trainer2.restore(filename) @@ -123,16 +154,17 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811 @pytest.mark.parametrize("num_replicas", [1, 2] if dist.is_available() else [1]) def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811 - def custom_train(config, model, dataloader, criterion, optimizer, - scheduler): - if config.get("models", 1) > 1: - assert len(model) == config["models"], config + def train_epoch(self, iterator, info): + if self.config.get("models", 1) > 1: + assert len(self.models) == self.config["models"], self.config - if config.get("optimizers", 1) > 1: - assert len(optimizer) == config["optimizers"], config + if self.config.get("optimizers", 1) > 1: + assert len( + self.optimizers) == self.config["optimizers"], self.config - if config.get("schedulers", 1) > 1: - assert len(scheduler) == config["schedulers"], config + if self.config.get("schedulers", 1) > 1: + assert len( + self.schedulers) == self.config["schedulers"], self.config return {"done": 1} def multi_model_creator(config): @@ -167,12 +199,13 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811 multi_optimizer_creator, loss_creator=nn.MSELoss, scheduler_creator=multi_scheduler_creator, - train_function=custom_train, + training_operator_cls=_TestingOperator, num_replicas=num_replicas, config={ "models": model_count, "optimizers": optimizer_count, - "schedulers": scheduler_count + "schedulers": scheduler_count, + "custom_func": train_epoch }) trainer.train() trainer.shutdown() @@ -180,9 +213,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811 @pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"]) def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 - def custom_train(config, model, dataloader, criterion, optimizer, - scheduler): - assert config[SCHEDULER_STEP] == scheduler_freq + def train_epoch(self, iterator, info): + assert info[SCHEDULER_STEP] == scheduler_freq return {"done": 1} def scheduler_creator(optimizer, config): @@ -194,18 +226,17 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 data_creator, optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - scheduler_creator=scheduler_creator) + config={"custom_func": train_epoch}, + training_operator_cls=_TestingOperator, + scheduler_creator=scheduler_creator, + scheduler_step_freq=scheduler_freq) for i in range(3): - trainer.train()["train_loss"] + trainer.train() trainer.shutdown() def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 - def custom_train(config, model, dataloader, criterion, optimizer, - scheduler): - return {"done": 1} - from torch.optim.lr_scheduler import ReduceLROnPlateau trainer = PyTorchTrainer( @@ -213,11 +244,13 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 data_creator, optimizer_creator, loss_creator=lambda config: nn.MSELoss(), - scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer)) + scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer), + training_operator_cls=_TestingOperator) trainer.update_scheduler(0.5) trainer.update_scheduler(0.5) assert all( - trainer.apply_all_workers(lambda r: r.schedulers[0].last_epoch == 2)) + trainer.apply_all_operators( + lambda op: op.schedulers[0].last_epoch == 2)) trainer.shutdown() @@ -248,13 +281,13 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811 # checks loss decreasing for every trials for path, df in analysis.trial_dataframes.items(): - train_loss1 = df.loc[0, "train_loss"] - train_loss2 = df.loc[1, "train_loss"] - validation_loss1 = df.loc[0, "validation_loss"] - validation_loss2 = df.loc[1, "validation_loss"] + mean_train_loss1 = df.loc[0, "mean_train_loss"] + mean_train_loss2 = df.loc[1, "mean_train_loss"] + mean_validation_loss1 = df.loc[0, "mean_validation_loss"] + mean_validation_loss2 = df.loc[1, "mean_validation_loss"] - assert train_loss2 <= train_loss1 - assert validation_loss2 <= validation_loss1 + assert mean_train_loss2 <= mean_train_loss1 + assert mean_validation_loss2 <= mean_validation_loss1 @pytest.mark.parametrize("num_replicas", [1, 2] @@ -303,15 +336,17 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 def single_loader(config): return LinearDataset(2, 5, size=1000000) - def step_with_fail(self): - worker_stats = [w.step.remote() for w in self.workers] + def step_with_fail(self, *args, **kwargs): + worker_stats = [ + w.train_epoch.remote(*args, **kwargs) for w in self.workers + ] if self._num_failures < 3: time.sleep(1) # Make the batch will fail correctly. self.workers[0].__ray_kill__() success = check_for_failure(worker_stats) return success, worker_stats - with patch.object(PyTorchTrainer, "_train_step", step_with_fail): + with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail): trainer1 = PyTorchTrainer( model_creator, single_loader, @@ -331,15 +366,17 @@ def test_resize(ray_start_2_cpus): # noqa: F811 def single_loader(config): return LinearDataset(2, 5, size=1000000) - def step_with_fail(self): - worker_stats = [w.step.remote() for w in self.workers] + def step_with_fail(self, *args, **kwargs): + worker_stats = [ + w.train_epoch.remote(*args, **kwargs) for w in self.workers + ] if self._num_failures < 1: time.sleep(1) # Make the batch will fail correctly. self.workers[0].__ray_kill__() success = check_for_failure(worker_stats) return success, worker_stats - with patch.object(PyTorchTrainer, "_train_step", step_with_fail): + with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail): trainer1 = PyTorchTrainer( model_creator, single_loader, @@ -365,15 +402,17 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 def single_loader(config): return LinearDataset(2, 5, size=1000000) - def step_with_fail(self): - worker_stats = [w.step.remote() for w in self.workers] + def step_with_fail(self, *args, **kwargs): + worker_stats = [ + w.train_epoch.remote(*args, **kwargs) for w in self.workers + ] if self._num_failures < 2: time.sleep(1) self.workers[0].__ray_kill__() success = check_for_failure(worker_stats) return success, worker_stats - with patch.object(PyTorchTrainer, "_train_step", step_with_fail): + with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail): trainer1 = PyTorchTrainer( model_creator, single_loader, @@ -383,3 +422,9 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 num_replicas=2) trainer1.train(max_retries=2) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/sgd/tests/test_pytorch_runner.py b/python/ray/util/sgd/tests/test_pytorch_runner.py index e341194a6..cd7a262c1 100644 --- a/python/ray/util/sgd/tests/test_pytorch_runner.py +++ b/python/ray/util/sgd/tests/test_pytorch_runner.py @@ -4,6 +4,7 @@ import torch.nn as nn import unittest from unittest.mock import MagicMock +from ray.util.sgd.pytorch.training_operator import TrainingOperator from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner @@ -46,39 +47,55 @@ def create_dataloaders(config): class TestPyTorchRunner(unittest.TestCase): def testValidate(self): - mock_function = MagicMock(returns=dict(mean_accuracy=10)) + class MockOperator(TrainingOperator): + def setup(self, config): + self.train_epoch = MagicMock(returns=dict(mean_accuracy=10)) + self.validate = MagicMock(returns=dict(mean_accuracy=10)) + runner = PyTorchRunner( model_creator, create_dataloaders, optimizer_creator, loss_creator, - validation_function=mock_function) + training_operator_cls=MockOperator) runner.setup() - runner.step() - runner.step() - runner.step() - self.assertEqual(mock_function.call_count, 0) + runner.train_epoch() + runner.train_epoch() + runner.train_epoch() + self.assertEqual(runner.training_operator.validate.call_count, 0) runner.validate() - self.assertTrue(mock_function.called) + self.assertTrue(runner.training_operator.validate.called) self.assertEqual(runner.stats()["epoch"], 3) - def testStep(self): - mock_function = MagicMock(return_value=dict(mean_accuracy=10)) + def testtrain_epoch(self): + class MockOperator(TrainingOperator): + def setup(self, config): + self.count = 0 + + def train_epoch(self, *args, **kwargs): + self.count += 1 + return {"count": self.count} + runner = PyTorchRunner( model_creator, create_dataloaders, optimizer_creator, loss_creator, - train_function=mock_function) + training_operator_cls=MockOperator) runner.setup() - runner.step() - runner.step() - result = runner.step() - self.assertEqual(mock_function.call_count, 3) - self.assertEqual(result["epoch"], 3) + runner.train_epoch(num_steps=1) + runner.train_epoch(num_steps=1) + result = runner.train_epoch() + self.assertEqual(runner.training_operator.count, 3) + self.assertEqual(result["count"], 3) self.assertEqual(runner.stats()["epoch"], 3) def testGivens(self): + class MockOperator(TrainingOperator): + def setup(self, config): + self.train_epoch = MagicMock(returns=dict(mean_accuracy=10)) + self.validate = MagicMock(returns=dict(mean_accuracy=10)) + def three_model_creator(config): return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1) @@ -88,8 +105,12 @@ class TestPyTorchRunner(unittest.TestCase): ] return opts[0], opts[1], opts[2] - runner = PyTorchRunner(three_model_creator, single_loader, - three_optimizer_creator, loss_creator) + runner = PyTorchRunner( + three_model_creator, + single_loader, + three_optimizer_creator, + loss_creator, + training_operator_cls=MockOperator) runner.setup() self.assertEqual(len(runner.given_models), 3) @@ -121,7 +142,7 @@ class TestPyTorchRunner(unittest.TestCase): runner = PyTorchRunner(model_creator, single_loader, optimizer_creator, loss_creator) runner.setup() - runner.step() + runner.train_epoch() with self.assertRaises(ValueError): runner.validate() @@ -132,7 +153,7 @@ class TestPyTorchRunner(unittest.TestCase): optimizer_creator, loss_creator=nn.MSELoss) runner.setup() - runner.step() + runner.train_epoch() def testMultiModel(self): def multi_model_creator(config): @@ -146,6 +167,6 @@ class TestPyTorchRunner(unittest.TestCase): runner = PyTorchRunner(multi_model_creator, single_loader, multi_optimizer_creator, loss_creator) - runner.setup() + with self.assertRaises(ValueError): - runner.step() + runner.setup() diff --git a/python/ray/util/sgd/utils.py b/python/ray/util/sgd/utils.py index fc5d3ce8f..5e2f14087 100644 --- a/python/ray/util/sgd/utils.py +++ b/python/ray/util/sgd/utils.py @@ -149,3 +149,11 @@ def check_for_failure(remote_values): except RayActorError as exc: logger.exception(str(exc)) return False + + +def override(interface_class): + def overrider(method): + assert (method.__name__ in dir(interface_class)) + return method + + return overrider