diff --git a/ci/jenkins_tests/run_multi_node_tests.sh b/ci/jenkins_tests/run_multi_node_tests.sh index 1ab086792..a07e36f87 100755 --- a/ci/jenkins_tests/run_multi_node_tests.sh +++ b/ci/jenkins_tests/run_multi_node_tests.sh @@ -31,25 +31,4 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=60G --memory=60G $DOCKER_SHA \ ######################## SGD TESTS ################################# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \ - --batch-size=1 --strategy=simple - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \ - --batch-size=1 --strategy=ps - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \ - --batch-size=1 --strategy=simple - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \ - --batch-size=1 --strategy=ps - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \ - --num-workers=1 --devices-per-worker=1 --strategy=ps - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \ - --num-workers=1 --devices-per-worker=1 --strategy=ps --tune + python -m pytest /ray/python/ray/experimental/sgd/tests diff --git a/doc/source/conf.py b/doc/source/conf.py index e0bd2c6da..b0ae3416d 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -53,6 +53,10 @@ MOCK_MODULES = [ "tensorflow.python", "tensorflow.python.client", "tensorflow.python.util", + "torch", + "torch.distributed", + "torch.nn", + "torch.utils.data", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/doc/source/distributed_training.rst b/doc/source/distributed_training.rst new file mode 100644 index 000000000..61f3442eb --- /dev/null +++ b/doc/source/distributed_training.rst @@ -0,0 +1,48 @@ +Distributed Training (Experimental) +=================================== + + +Ray includes abstractions for distributed model training that integrate with +deep learning frameworks, such as PyTorch. + +Ray Train is built on top of the Ray task and actor abstractions to provide +seamless integration into existing Ray applications. + +PyTorch Interface +----------------- + +To use Ray Train with PyTorch, pass model and data creator functions to the +``ray.experimental.sgd.pytorch.PyTorchTrainer`` class. +To drive the distributed training, ``trainer.train()`` can be called +repeatedly. + +.. code-block:: python + + model_creator = lambda config: YourPyTorchModel() + data_creator = lambda config: YourTrainingSet(), YourValidationSet() + + trainer = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator=utils.sgd_mse_optimizer, + config={"lr": 1e-4}, + num_replicas=2, + resources_per_replica=Resources(num_gpus=1), + batch_size=16, + backend="auto") + + for i in range(NUM_EPOCHS): + trainer.train() + +Under the hood, Ray Train will create *replicas* of your model +(controlled by ``num_replicas``) which are each managed by a worker. +Multiple devices (e.g. GPUs) can be managed by each replica (controlled by ``resources_per_replica``), +which allows training of lage models across multiple GPUs. +The ``PyTorchTrainer`` class coordinates the distributed computation and training to improve the model. + +The full documentation for ``PyTorchTrainer`` is as follows: + +.. autoclass:: ray.experimental.sgd.pytorch.PyTorchTrainer + :members: + + .. automethod:: __init__ diff --git a/doc/source/index.rst b/doc/source/index.rst index a90e0224b..a8efb7a53 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -42,7 +42,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin - `Tune`_: Scalable Hyperparameter Search - `RLlib`_: Scalable Reinforcement Learning -- `Distributed Training `__ +- `Distributed Training `__ .. _`Tune`: tune.html .. _`RLlib`: rllib.html @@ -107,6 +107,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin :maxdepth: 1 :caption: Other Libraries + distributed_training.rst distributed_sgd.rst pandas_on_ray.rst diff --git a/python/ray/experimental/sgd/pytorch/__init__.py b/python/ray/experimental/sgd/pytorch/__init__.py new file mode 100644 index 000000000..74a33016d --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer +from ray.experimental.sgd.pytorch.utils import Resources + +__all__ = ["PyTorchTrainer", "Resources"] diff --git a/python/ray/experimental/sgd/pytorch/pytorch_runner.py b/python/ray/experimental/sgd/pytorch/pytorch_runner.py new file mode 100644 index 000000000..5fe4ba100 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/pytorch_runner.py @@ -0,0 +1,182 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import torch +import torch.distributed as dist +import torch.utils.data + +import ray +from ray.experimental.sgd.pytorch import utils + +logger = logging.getLogger(__name__) + + +class PyTorchRunner(object): + """Manages a distributed PyTorch model replica""" + + def __init__(self, + model_creator, + data_creator, + optimizer_creator, + config=None, + batch_size=16, + backend="gloo"): + """Initializes the runner. + + Args: + model_creator (dict -> torch.nn.Module): creates the model using + the config. + data_creator (dict -> Dataset, Dataset): creates the training and + validation data sets using the config. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + creates the loss and optimizer using the model and the config. + config (dict): configuration passed to 'model_creator', + 'data_creator', and 'optimizer_creator'. + batch_size (int): batch size used in an update. + backend (string): backend used by distributed PyTorch. + """ + + self.model_creator = model_creator + self.data_creator = data_creator + self.optimizer_creator = optimizer_creator + self.config = {} if config is None else config + self.batch_size = batch_size + self.backend = backend + self.verbose = True + + self.epoch = 0 + self._timers = { + k: utils.TimerStat(window_size=1) + for k in [ + "setup_proc", "setup_model", "get_state", "set_state", + "validation", "training" + ] + } + + def setup(self, url, world_rank, world_size): + """Connects to the distributed PyTorch backend and initializes the model. + + Args: + url (str): the URL used to connect to distributed PyTorch. + world_rank (int): the index of the runner. + world_size (int): the total number of runners. + """ + self._setup_distributed_pytorch(url, world_rank, world_size) + self._setup_training() + + def _setup_distributed_pytorch(self, url, world_rank, world_size): + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + with self._timers["setup_proc"]: + self.world_rank = world_rank + logger.debug( + "Connecting to {} world_rank: {} world_size: {}".format( + url, world_rank, world_size)) + logger.debug("using {}".format(self.backend)) + dist.init_process_group( + backend=self.backend, + init_method=url, + rank=world_rank, + world_size=world_size) + + def _setup_training(self): + logger.debug("Creating model") + self.model = self.model_creator(self.config) + if torch.cuda.is_available(): + self.model = torch.nn.parallel.DistributedDataParallel( + self.model.cuda()) + else: + self.model = torch.nn.parallel.DistributedDataParallelCPU( + self.model) + + logger.debug("Creating optimizer") + self.criterion, self.optimizer = self.optimizer_creator( + self.model, self.config) + + if torch.cuda.is_available(): + self.criterion = self.criterion.cuda() + + logger.debug("Creating dataset") + self.training_set, self.validation_set = self.data_creator(self.config) + + # TODO: make num_workers configurable + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_set) + self.train_loader = torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=(self.train_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.train_sampler) + + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.validation_set)) + self.validation_loader = torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=(self.validation_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.validation_sampler) + + def get_node_ip(self): + """Returns the IP address of the current node""" + return ray.services.get_node_ip_address() + + def step(self): + """Runs a training epoch and updates the model parameters""" + logger.debug("Starting step") + self.train_sampler.set_epoch(self.epoch) + + logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) + with self._timers["training"]: + train_stats = utils.train(self.train_loader, self.model, + self.criterion, self.optimizer) + train_stats["epoch"] = self.epoch + + self.epoch += 1 + + train_stats.update(self.stats()) + return train_stats + + def validate(self): + """Evaluates the model on the validation data set""" + with self._timers["validation"]: + validation_stats = utils.validate(self.validation_loader, + self.model, self.criterion) + + validation_stats.update(self.stats()) + return validation_stats + + def stats(self): + """Returns a dictionary of statistics collected""" + stats = {"epoch": self.epoch} + for k, t in self._timers.items(): + stats[k + "_time_mean"] = t.mean + stats[k + "_time_total"] = t.sum + t.reset() + return stats + + def get_state(self): + """Returns the state of the runner""" + return { + "epoch": self.epoch, + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "stats": self.stats() + } + + def set_state(self, state): + """Sets the state of the model""" + # TODO: restore timer stats + self.model.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.epoch = state["stats"]["epoch"] + + def shutdown(self): + """Attempts to shut down the worker""" + dist.destroy_process_group() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py new file mode 100644 index 000000000..073ad3d34 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py @@ -0,0 +1,150 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import sys +import torch +import logging + +import ray + +from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner +from ray.experimental.sgd.pytorch import utils + +logger = logging.getLogger(__name__) + + +class PyTorchTrainer(object): + """Train a PyTorch model using distributed PyTorch. + + Launches a set of actors which connect via distributed PyTorch and + coordinate gradient updates to train the provided model. + """ + + def __init__(self, + model_creator, + data_creator, + optimizer_creator=utils.sgd_mse_optimizer, + config=None, + num_replicas=1, + resources_per_replica=None, + batch_size=16, + backend="auto"): + """Sets up the PyTorch trainer. + + Args: + model_creator (dict -> torch.nn.Module): creates the model + using the config. + data_creator (dict -> Dataset, Dataset): creates the training + and validation data sets using the config. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + creates the loss and optimizer using the model and the config. + config (dict): configuration passed to 'model_creator', + 'data_creator', and 'optimizer_creator'. + num_replicas (int): the number of workers used in distributed + training. + resources_per_replica (Resources): resources used by each worker. + Defaults to Resources(num_cpus=1). + batch_size (int): batch size for an update. + backend (string): backend used by distributed PyTorch. + """ + # TODO: add support for mixed precision + # TODO: add support for callbacks + if sys.platform == "darwin": + raise Exception( + ("Distributed PyTorch is not supported on macOS. For more " + "information, see " + "https://github.com/pytorch/examples/issues/467.")) + + self.model_creator = model_creator + self.config = {} if config is None else config + self.optimizer_timer = utils.TimerStat(window_size=1) + + if resources_per_replica is None: + resources_per_replica = utils.Resources( + num_cpus=1, num_gpus=0, resources={}) + + if backend == "auto": + backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo" + + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)(PyTorchRunner) + + batch_size_per_replica = batch_size // num_replicas + if batch_size % num_replicas > 0: + new_batch_size = batch_size_per_replica * num_replicas + logger.warn( + ("Changing batch size from {old_batch_size} to " + "{new_batch_size} to evenly distribute batches across " + "{num_replicas} replicas.").format( + old_batch_size=batch_size, + new_batch_size=new_batch_size, + num_replicas=num_replicas)) + + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size_per_replica, backend) + for i in range(num_replicas) + ] + + ip = ray.get(self.workers[0].get_node_ip.remote()) + port = utils.find_free_port() + address = "tcp://{ip}:{port}".format(ip=ip, port=port) + + # Get setup tasks in order to throw errors on failure + ray.get([ + worker.setup.remote(address, i, len(self.workers)) + for i, worker in enumerate(self.workers) + ]) + + def train(self): + """Runs a training epoch""" + with self.optimizer_timer: + worker_stats = ray.get([w.step.remote() for w in self.workers]) + + train_stats = worker_stats[0].copy() + train_stats["train_loss"] = np.mean( + [s["train_loss"] for s in worker_stats]) + return train_stats + + def validate(self): + """Evaluates the model on the validation data set""" + worker_stats = ray.get([w.validate.remote() for w in self.workers]) + validation_stats = worker_stats[0].copy() + validation_stats["validation_loss"] = np.mean( + [s["validation_loss"] for s in worker_stats]) + return validation_stats + + def get_model(self): + """Returns the learned model""" + model = self.model_creator(self.config) + state = ray.get(self.workers[0].get_state.remote()) + + # Remove module. prefix added by distrbuted pytorch + state_dict = { + k.replace("module.", ""): v + for k, v in state["model"].items() + } + + model.load_state_dict(state_dict) + return model + + def save(self, ckpt): + """Saves the model at the provided checkpoint""" + state = ray.get(self.workers[0].get_state.remote()) + torch.save(state, ckpt) + + def restore(self, ckpt): + """Restores the model from the provided checkpoint""" + state = torch.load(ckpt) + state_id = ray.put(state) + ray.get([worker.set_state.remote(state_id) for worker in self.workers]) + + def shutdown(self): + """Shuts down workers and releases resources""" + for worker in self.workers: + worker.shutdown.remote() + worker.__ray_terminate__.remote() diff --git a/python/ray/experimental/sgd/pytorch/utils.py b/python/ray/experimental/sgd/pytorch/utils.py new file mode 100644 index 000000000..f7c6e4aba --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/utils.py @@ -0,0 +1,240 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +from contextlib import closing +import numpy as np +import socket +import time +import torch +import torch.nn as nn + + +def train(train_iterator, model, criterion, optimizer): + """Runs 1 training epoch""" + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + + timers = {k: TimerStat() for k in ["d2h", "fwd", "grad", "apply"]} + + # switch to train mode + model.train() + + end = time.time() + + for i, (features, target) in enumerate(train_iterator): + # measure data loading time + data_time.update(time.time() - end) + + # Create non_blocking tensors for distributed training + with timers["d2h"]: + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + with timers["fwd"]: + output = model(features) + loss = criterion(output, target) + + # measure accuracy and record loss + losses.update(loss.item(), features.size(0)) + + with timers["grad"]: + # compute gradients in a backward pass + optimizer.zero_grad() + loss.backward() + + with timers["apply"]: + # Call step of optimizer to update model params + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + stats = { + "batch_time": batch_time.avg, + "batch_processed": losses.count, + "train_loss": losses.avg, + "data_time": data_time.avg, + } + stats.update({k: t.mean for k, t in timers.items()}) + return stats + + +def validate(val_loader, model, criterion): + batch_time = AverageMeter() + losses = AverageMeter() + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (features, target) in enumerate(val_loader): + + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = model(features) + loss = criterion(output, target) + + # measure accuracy and record loss + losses.update(loss.item(), features.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + stats = {"batch_time": batch_time.avg, "validation_loss": losses.avg} + return stats + + +class TimerStat(object): + """A running stat for conveniently logging the duration of a code block. + + Note that this class is *not* thread-safe. + + Examples: + Time a call to 'time.sleep'. + + >>> import time + >>> sleep_timer = TimerStat() + >>> with sleep_timer: + ... time.sleep(1) + >>> round(sleep_timer.mean) + 1 + """ + + def __init__(self, window_size=10): + self._window_size = window_size + self._samples = [] + self._units_processed = [] + self._start_time = None + self._total_time = 0.0 + self.count = 0 + + def __enter__(self): + assert self._start_time is None, "concurrent updates not supported" + self._start_time = time.time() + + def __exit__(self, type, value, tb): + assert self._start_time is not None + time_delta = time.time() - self._start_time + self.push(time_delta) + self._start_time = None + + def push(self, time_delta): + self._samples.append(time_delta) + if len(self._samples) > self._window_size: + self._samples.pop(0) + self.count += 1 + self._total_time += time_delta + + def push_units_processed(self, n): + self._units_processed.append(n) + if len(self._units_processed) > self._window_size: + self._units_processed.pop(0) + + @property + def mean(self): + return np.mean(self._samples) + + @property + def median(self): + return np.median(self._samples) + + @property + def sum(self): + return np.sum(self._samples) + + @property + def max(self): + return np.max(self._samples) + + @property + def first(self): + return self._samples[0] if self._samples else None + + @property + def last(self): + return self._samples[-1] if self._samples else None + + @property + def size(self): + return len(self._samples) + + @property + def mean_units_processed(self): + return float(np.mean(self._units_processed)) + + @property + def mean_throughput(self): + time_total = sum(self._samples) + if not time_total: + return 0.0 + return sum(self._units_processed) / time_total + + def reset(self): + self._samples = [] + self._units_processed = [] + self._start_time = None + self._total_time = 0.0 + self.count = 0 + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class Resources( + namedtuple("Resources", ["num_cpus", "num_gpus", "resources"])): + __slots__ = () + + def __new__(cls, num_cpus=1, num_gpus=0, resources=None): + if resources is None: + resources = {} + + return super(Resources, cls).__new__(cls, num_cpus, num_gpus, + resources) + + +def sgd_mse_optimizer(model, config): + """Returns the mean squared error criterion and SGD optimizer. + + Args: + model (torch.nn.Module): the model to optimize. + config (dict): configuration for the optimizer. + lr (float): the learning rate. defaults to 0.01. + """ + learning_rate = config.get("lr", 0.01) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + return criterion, optimizer diff --git a/python/ray/experimental/sgd/tests/__init__.py b/python/ray/experimental/sgd/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/experimental/sgd/tests/pytorch_utils.py b/python/ray/experimental/sgd/tests/pytorch_utils.py new file mode 100644 index 000000000..6299fff1c --- /dev/null +++ b/python/ray/experimental/sgd/tests/pytorch_utils.py @@ -0,0 +1,40 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.data + + +class LinearDataset(torch.utils.data.Dataset): + """y = a * x + b""" + + def __init__(self, a, b, size=1000): + x = np.random.random(size).astype(np.float32) * 10 + x = np.arange(0, 10, 10 / size, dtype=np.float32) + self.x = torch.from_numpy(x) + self.y = torch.from_numpy(a * x + b) + + def __getitem__(self, index): + return self.x[index, None], self.y[index, None] + + def __len__(self): + return len(self.x) + + +def model_creator(config): + return nn.Linear(1, 1) + + +def optimizer_creator(model, config): + """Returns criterion, optimizer""" + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + return criterion, optimizer + + +def data_creator(config): + """Returns training set, validation set""" + return LinearDataset(2, 5), LinearDataset(2, 5, size=400) diff --git a/python/ray/experimental/sgd/tests/test_pytorch.py b/python/ray/experimental/sgd/tests/test_pytorch.py new file mode 100644 index 000000000..faff23f8a --- /dev/null +++ b/python/ray/experimental/sgd/tests/test_pytorch.py @@ -0,0 +1,76 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest +import sys +import tempfile +import torch + +from ray.tests.conftest import ray_start_2_cpus # noqa: F401 +from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources + +from ray.experimental.sgd.tests.pytorch_utils import ( + model_creator, optimizer_creator, data_creator) + + +@pytest.mark.skipif( # noqa: F811 + sys.platform == "darwin", reason="Doesn't work on macOS.") +def test_train(ray_start_2_cpus): # noqa: F811 + trainer = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + num_replicas=2, + resources_per_replica=Resources(num_cpus=1)) + train_loss1 = trainer.train()["train_loss"] + validation_loss1 = trainer.validate()["validation_loss"] + + train_loss2 = trainer.train()["train_loss"] + validation_loss2 = trainer.validate()["validation_loss"] + + print(train_loss1, train_loss2) + print(validation_loss1, validation_loss2) + + assert train_loss2 <= train_loss1 + assert validation_loss2 <= validation_loss1 + + +@pytest.mark.skipif( # noqa: F811 + sys.platform == "darwin", reason="Doesn't work on macOS.") +def test_save_and_restore(ray_start_2_cpus): # noqa: F811 + trainer1 = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + num_replicas=2, + resources_per_replica=Resources(num_cpus=1)) + trainer1.train() + + filename = os.path.join(tempfile.mkdtemp(), "checkpoint") + trainer1.save(filename) + + model1 = trainer1.get_model() + + trainer1.shutdown() + + trainer2 = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + num_replicas=2, + resources_per_replica=Resources(num_cpus=1)) + trainer2.restore(filename) + + os.remove(filename) + + model2 = trainer2.get_model() + + model1_state_dict = model1.state_dict() + model2_state_dict = model2.state_dict() + + assert set(model1_state_dict.keys()) == set(model2_state_dict.keys()) + + for k in model1_state_dict: + assert torch.equal(model1_state_dict[k], model2_state_dict[k])