[sgd] Distributed Training via PyTorch (#4797)

Implements distributed SGD using distributed PyTorch.
This commit is contained in:
Peter Schafhalter
2019-06-01 21:39:22 -07:00
committed by Richard Liaw
parent 88bab5d3c4
commit c2ade075a3
11 changed files with 751 additions and 23 deletions
+1 -22
View File
@@ -31,25 +31,4 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=60G --memory=60G $DOCKER_SHA \
######################## SGD TESTS #################################
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=simple
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=ps
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=simple
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=ps
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps --tune
python -m pytest /ray/python/ray/experimental/sgd/tests
+4
View File
@@ -53,6 +53,10 @@ MOCK_MODULES = [
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"torch",
"torch.distributed",
"torch.nn",
"torch.utils.data",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
+48
View File
@@ -0,0 +1,48 @@
Distributed Training (Experimental)
===================================
Ray includes abstractions for distributed model training that integrate with
deep learning frameworks, such as PyTorch.
Ray Train is built on top of the Ray task and actor abstractions to provide
seamless integration into existing Ray applications.
PyTorch Interface
-----------------
To use Ray Train with PyTorch, pass model and data creator functions to the
``ray.experimental.sgd.pytorch.PyTorchTrainer`` class.
To drive the distributed training, ``trainer.train()`` can be called
repeatedly.
.. code-block:: python
model_creator = lambda config: YourPyTorchModel()
data_creator = lambda config: YourTrainingSet(), YourValidationSet()
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator=utils.sgd_mse_optimizer,
config={"lr": 1e-4},
num_replicas=2,
resources_per_replica=Resources(num_gpus=1),
batch_size=16,
backend="auto")
for i in range(NUM_EPOCHS):
trainer.train()
Under the hood, Ray Train will create *replicas* of your model
(controlled by ``num_replicas``) which are each managed by a worker.
Multiple devices (e.g. GPUs) can be managed by each replica (controlled by ``resources_per_replica``),
which allows training of lage models across multiple GPUs.
The ``PyTorchTrainer`` class coordinates the distributed computation and training to improve the model.
The full documentation for ``PyTorchTrainer`` is as follows:
.. autoclass:: ray.experimental.sgd.pytorch.PyTorchTrainer
:members:
.. automethod:: __init__
+2 -1
View File
@@ -42,7 +42,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin
- `Tune`_: Scalable Hyperparameter Search
- `RLlib`_: Scalable Reinforcement Learning
- `Distributed Training <distributed_sgd.html>`__
- `Distributed Training <distributed_training.html>`__
.. _`Tune`: tune.html
.. _`RLlib`: rllib.html
@@ -107,6 +107,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin
:maxdepth: 1
:caption: Other Libraries
distributed_training.rst
distributed_sgd.rst
pandas_on_ray.rst
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
from ray.experimental.sgd.pytorch.utils import Resources
__all__ = ["PyTorchTrainer", "Resources"]
@@ -0,0 +1,182 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import torch
import torch.distributed as dist
import torch.utils.data
import ray
from ray.experimental.sgd.pytorch import utils
logger = logging.getLogger(__name__)
class PyTorchRunner(object):
"""Manages a distributed PyTorch model replica"""
def __init__(self,
model_creator,
data_creator,
optimizer_creator,
config=None,
batch_size=16,
backend="gloo"):
"""Initializes the runner.
Args:
model_creator (dict -> torch.nn.Module): creates the model using
the config.
data_creator (dict -> Dataset, Dataset): creates the training and
validation data sets using the config.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
creates the loss and optimizer using the model and the config.
config (dict): configuration passed to 'model_creator',
'data_creator', and 'optimizer_creator'.
batch_size (int): batch size used in an update.
backend (string): backend used by distributed PyTorch.
"""
self.model_creator = model_creator
self.data_creator = data_creator
self.optimizer_creator = optimizer_creator
self.config = {} if config is None else config
self.batch_size = batch_size
self.backend = backend
self.verbose = True
self.epoch = 0
self._timers = {
k: utils.TimerStat(window_size=1)
for k in [
"setup_proc", "setup_model", "get_state", "set_state",
"validation", "training"
]
}
def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.
Args:
url (str): the URL used to connect to distributed PyTorch.
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()
def _setup_distributed_pytorch(self, url, world_rank, world_size):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
with self._timers["setup_proc"]:
self.world_rank = world_rank
logger.debug(
"Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
dist.init_process_group(
backend=self.backend,
init_method=url,
rank=world_rank,
world_size=world_size)
def _setup_training(self):
logger.debug("Creating model")
self.model = self.model_creator(self.config)
if torch.cuda.is_available():
self.model = torch.nn.parallel.DistributedDataParallel(
self.model.cuda())
else:
self.model = torch.nn.parallel.DistributedDataParallelCPU(
self.model)
logger.debug("Creating optimizer")
self.criterion, self.optimizer = self.optimizer_creator(
self.model, self.config)
if torch.cuda.is_available():
self.criterion = self.criterion.cuda()
logger.debug("Creating dataset")
self.training_set, self.validation_set = self.data_creator(self.config)
# TODO: make num_workers configurable
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
self.training_set)
self.train_loader = torch.utils.data.DataLoader(
self.training_set,
batch_size=self.batch_size,
shuffle=(self.train_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.train_sampler)
self.validation_sampler = (
torch.utils.data.distributed.DistributedSampler(
self.validation_set))
self.validation_loader = torch.utils.data.DataLoader(
self.validation_set,
batch_size=self.batch_size,
shuffle=(self.validation_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.validation_sampler)
def get_node_ip(self):
"""Returns the IP address of the current node"""
return ray.services.get_node_ip_address()
def step(self):
"""Runs a training epoch and updates the model parameters"""
logger.debug("Starting step")
self.train_sampler.set_epoch(self.epoch)
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
with self._timers["training"]:
train_stats = utils.train(self.train_loader, self.model,
self.criterion, self.optimizer)
train_stats["epoch"] = self.epoch
self.epoch += 1
train_stats.update(self.stats())
return train_stats
def validate(self):
"""Evaluates the model on the validation data set"""
with self._timers["validation"]:
validation_stats = utils.validate(self.validation_loader,
self.model, self.criterion)
validation_stats.update(self.stats())
return validation_stats
def stats(self):
"""Returns a dictionary of statistics collected"""
stats = {"epoch": self.epoch}
for k, t in self._timers.items():
stats[k + "_time_mean"] = t.mean
stats[k + "_time_total"] = t.sum
t.reset()
return stats
def get_state(self):
"""Returns the state of the runner"""
return {
"epoch": self.epoch,
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"stats": self.stats()
}
def set_state(self, state):
"""Sets the state of the model"""
# TODO: restore timer stats
self.model.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
self.epoch = state["stats"]["epoch"]
def shutdown(self):
"""Attempts to shut down the worker"""
dist.destroy_process_group()
@@ -0,0 +1,150 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sys
import torch
import logging
import ray
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
from ray.experimental.sgd.pytorch import utils
logger = logging.getLogger(__name__)
class PyTorchTrainer(object):
"""Train a PyTorch model using distributed PyTorch.
Launches a set of actors which connect via distributed PyTorch and
coordinate gradient updates to train the provided model.
"""
def __init__(self,
model_creator,
data_creator,
optimizer_creator=utils.sgd_mse_optimizer,
config=None,
num_replicas=1,
resources_per_replica=None,
batch_size=16,
backend="auto"):
"""Sets up the PyTorch trainer.
Args:
model_creator (dict -> torch.nn.Module): creates the model
using the config.
data_creator (dict -> Dataset, Dataset): creates the training
and validation data sets using the config.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
creates the loss and optimizer using the model and the config.
config (dict): configuration passed to 'model_creator',
'data_creator', and 'optimizer_creator'.
num_replicas (int): the number of workers used in distributed
training.
resources_per_replica (Resources): resources used by each worker.
Defaults to Resources(num_cpus=1).
batch_size (int): batch size for an update.
backend (string): backend used by distributed PyTorch.
"""
# TODO: add support for mixed precision
# TODO: add support for callbacks
if sys.platform == "darwin":
raise Exception(
("Distributed PyTorch is not supported on macOS. For more "
"information, see "
"https://github.com/pytorch/examples/issues/467."))
self.model_creator = model_creator
self.config = {} if config is None else config
self.optimizer_timer = utils.TimerStat(window_size=1)
if resources_per_replica is None:
resources_per_replica = utils.Resources(
num_cpus=1, num_gpus=0, resources={})
if backend == "auto":
backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo"
Runner = ray.remote(
num_cpus=resources_per_replica.num_cpus,
num_gpus=resources_per_replica.num_gpus,
resources=resources_per_replica.resources)(PyTorchRunner)
batch_size_per_replica = batch_size // num_replicas
if batch_size % num_replicas > 0:
new_batch_size = batch_size_per_replica * num_replicas
logger.warn(
("Changing batch size from {old_batch_size} to "
"{new_batch_size} to evenly distribute batches across "
"{num_replicas} replicas.").format(
old_batch_size=batch_size,
new_batch_size=new_batch_size,
num_replicas=num_replicas))
self.workers = [
Runner.remote(model_creator, data_creator, optimizer_creator,
self.config, batch_size_per_replica, backend)
for i in range(num_replicas)
]
ip = ray.get(self.workers[0].get_node_ip.remote())
port = utils.find_free_port()
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
# Get setup tasks in order to throw errors on failure
ray.get([
worker.setup.remote(address, i, len(self.workers))
for i, worker in enumerate(self.workers)
])
def train(self):
"""Runs a training epoch"""
with self.optimizer_timer:
worker_stats = ray.get([w.step.remote() for w in self.workers])
train_stats = worker_stats[0].copy()
train_stats["train_loss"] = np.mean(
[s["train_loss"] for s in worker_stats])
return train_stats
def validate(self):
"""Evaluates the model on the validation data set"""
worker_stats = ray.get([w.validate.remote() for w in self.workers])
validation_stats = worker_stats[0].copy()
validation_stats["validation_loss"] = np.mean(
[s["validation_loss"] for s in worker_stats])
return validation_stats
def get_model(self):
"""Returns the learned model"""
model = self.model_creator(self.config)
state = ray.get(self.workers[0].get_state.remote())
# Remove module. prefix added by distrbuted pytorch
state_dict = {
k.replace("module.", ""): v
for k, v in state["model"].items()
}
model.load_state_dict(state_dict)
return model
def save(self, ckpt):
"""Saves the model at the provided checkpoint"""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, ckpt)
def restore(self, ckpt):
"""Restores the model from the provided checkpoint"""
state = torch.load(ckpt)
state_id = ray.put(state)
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
def shutdown(self):
"""Shuts down workers and releases resources"""
for worker in self.workers:
worker.shutdown.remote()
worker.__ray_terminate__.remote()
@@ -0,0 +1,240 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from contextlib import closing
import numpy as np
import socket
import time
import torch
import torch.nn as nn
def train(train_iterator, model, criterion, optimizer):
"""Runs 1 training epoch"""
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
timers = {k: TimerStat() for k in ["d2h", "fwd", "grad", "apply"]}
# switch to train mode
model.train()
end = time.time()
for i, (features, target) in enumerate(train_iterator):
# measure data loading time
data_time.update(time.time() - end)
# Create non_blocking tensors for distributed training
with timers["d2h"]:
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
with timers["fwd"]:
output = model(features)
loss = criterion(output, target)
# measure accuracy and record loss
losses.update(loss.item(), features.size(0))
with timers["grad"]:
# compute gradients in a backward pass
optimizer.zero_grad()
loss.backward()
with timers["apply"]:
# Call step of optimizer to update model params
optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
stats = {
"batch_time": batch_time.avg,
"batch_processed": losses.count,
"train_loss": losses.avg,
"data_time": data_time.avg,
}
stats.update({k: t.mean for k, t in timers.items()})
return stats
def validate(val_loader, model, criterion):
batch_time = AverageMeter()
losses = AverageMeter()
# switch to evaluate mode
model.eval()
with torch.no_grad():
end = time.time()
for i, (features, target) in enumerate(val_loader):
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = model(features)
loss = criterion(output, target)
# measure accuracy and record loss
losses.update(loss.item(), features.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
stats = {"batch_time": batch_time.avg, "validation_loss": losses.avg}
return stats
class TimerStat(object):
"""A running stat for conveniently logging the duration of a code block.
Note that this class is *not* thread-safe.
Examples:
Time a call to 'time.sleep'.
>>> import time
>>> sleep_timer = TimerStat()
>>> with sleep_timer:
... time.sleep(1)
>>> round(sleep_timer.mean)
1
"""
def __init__(self, window_size=10):
self._window_size = window_size
self._samples = []
self._units_processed = []
self._start_time = None
self._total_time = 0.0
self.count = 0
def __enter__(self):
assert self._start_time is None, "concurrent updates not supported"
self._start_time = time.time()
def __exit__(self, type, value, tb):
assert self._start_time is not None
time_delta = time.time() - self._start_time
self.push(time_delta)
self._start_time = None
def push(self, time_delta):
self._samples.append(time_delta)
if len(self._samples) > self._window_size:
self._samples.pop(0)
self.count += 1
self._total_time += time_delta
def push_units_processed(self, n):
self._units_processed.append(n)
if len(self._units_processed) > self._window_size:
self._units_processed.pop(0)
@property
def mean(self):
return np.mean(self._samples)
@property
def median(self):
return np.median(self._samples)
@property
def sum(self):
return np.sum(self._samples)
@property
def max(self):
return np.max(self._samples)
@property
def first(self):
return self._samples[0] if self._samples else None
@property
def last(self):
return self._samples[-1] if self._samples else None
@property
def size(self):
return len(self._samples)
@property
def mean_units_processed(self):
return float(np.mean(self._units_processed))
@property
def mean_throughput(self):
time_total = sum(self._samples)
if not time_total:
return 0.0
return sum(self._units_processed) / time_total
def reset(self):
self._samples = []
self._units_processed = []
self._start_time = None
self._total_time = 0.0
self.count = 0
def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class Resources(
namedtuple("Resources", ["num_cpus", "num_gpus", "resources"])):
__slots__ = ()
def __new__(cls, num_cpus=1, num_gpus=0, resources=None):
if resources is None:
resources = {}
return super(Resources, cls).__new__(cls, num_cpus, num_gpus,
resources)
def sgd_mse_optimizer(model, config):
"""Returns the mean squared error criterion and SGD optimizer.
Args:
model (torch.nn.Module): the model to optimize.
config (dict): configuration for the optimizer.
lr (float): the learning rate. defaults to 0.01.
"""
learning_rate = config.get("lr", 0.01)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
return criterion, optimizer
@@ -0,0 +1,40 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
class LinearDataset(torch.utils.data.Dataset):
"""y = a * x + b"""
def __init__(self, a, b, size=1000):
x = np.random.random(size).astype(np.float32) * 10
x = np.arange(0, 10, 10 / size, dtype=np.float32)
self.x = torch.from_numpy(x)
self.y = torch.from_numpy(a * x + b)
def __getitem__(self, index):
return self.x[index, None], self.y[index, None]
def __len__(self):
return len(self.x)
def model_creator(config):
return nn.Linear(1, 1)
def optimizer_creator(model, config):
"""Returns criterion, optimizer"""
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
return criterion, optimizer
def data_creator(config):
"""Returns training set, validation set"""
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
@@ -0,0 +1,76 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pytest
import sys
import tempfile
import torch
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
from ray.experimental.sgd.tests.pytorch_utils import (
model_creator, optimizer_creator, data_creator)
@pytest.mark.skipif( # noqa: F811
sys.platform == "darwin", reason="Doesn't work on macOS.")
def test_train(ray_start_2_cpus): # noqa: F811
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
num_replicas=2,
resources_per_replica=Resources(num_cpus=1))
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["validation_loss"]
train_loss2 = trainer.train()["train_loss"]
validation_loss2 = trainer.validate()["validation_loss"]
print(train_loss1, train_loss2)
print(validation_loss1, validation_loss2)
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
@pytest.mark.skipif( # noqa: F811
sys.platform == "darwin", reason="Doesn't work on macOS.")
def test_save_and_restore(ray_start_2_cpus): # noqa: F811
trainer1 = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
num_replicas=2,
resources_per_replica=Resources(num_cpus=1))
trainer1.train()
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
trainer1.save(filename)
model1 = trainer1.get_model()
trainer1.shutdown()
trainer2 = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
num_replicas=2,
resources_per_replica=Resources(num_cpus=1))
trainer2.restore(filename)
os.remove(filename)
model2 = trainer2.get_model()
model1_state_dict = model1.state_dict()
model2_state_dict = model2.state_dict()
assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())
for k in model1_state_dict:
assert torch.equal(model1_state_dict[k], model2_state_dict[k])