[sgd] Distributed Training via PyTorch (#4797)

Implements distributed SGD using distributed PyTorch.
This commit is contained in:
Peter Schafhalter
2019-06-01 21:39:22 -07:00
committed by Richard Liaw
parent 88bab5d3c4
commit c2ade075a3
11 changed files with 751 additions and 23 deletions
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
from ray.experimental.sgd.pytorch.utils import Resources
__all__ = ["PyTorchTrainer", "Resources"]
@@ -0,0 +1,182 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import torch
import torch.distributed as dist
import torch.utils.data
import ray
from ray.experimental.sgd.pytorch import utils
logger = logging.getLogger(__name__)
class PyTorchRunner(object):
"""Manages a distributed PyTorch model replica"""
def __init__(self,
model_creator,
data_creator,
optimizer_creator,
config=None,
batch_size=16,
backend="gloo"):
"""Initializes the runner.
Args:
model_creator (dict -> torch.nn.Module): creates the model using
the config.
data_creator (dict -> Dataset, Dataset): creates the training and
validation data sets using the config.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
creates the loss and optimizer using the model and the config.
config (dict): configuration passed to 'model_creator',
'data_creator', and 'optimizer_creator'.
batch_size (int): batch size used in an update.
backend (string): backend used by distributed PyTorch.
"""
self.model_creator = model_creator
self.data_creator = data_creator
self.optimizer_creator = optimizer_creator
self.config = {} if config is None else config
self.batch_size = batch_size
self.backend = backend
self.verbose = True
self.epoch = 0
self._timers = {
k: utils.TimerStat(window_size=1)
for k in [
"setup_proc", "setup_model", "get_state", "set_state",
"validation", "training"
]
}
def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.
Args:
url (str): the URL used to connect to distributed PyTorch.
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()
def _setup_distributed_pytorch(self, url, world_rank, world_size):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
with self._timers["setup_proc"]:
self.world_rank = world_rank
logger.debug(
"Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
dist.init_process_group(
backend=self.backend,
init_method=url,
rank=world_rank,
world_size=world_size)
def _setup_training(self):
logger.debug("Creating model")
self.model = self.model_creator(self.config)
if torch.cuda.is_available():
self.model = torch.nn.parallel.DistributedDataParallel(
self.model.cuda())
else:
self.model = torch.nn.parallel.DistributedDataParallelCPU(
self.model)
logger.debug("Creating optimizer")
self.criterion, self.optimizer = self.optimizer_creator(
self.model, self.config)
if torch.cuda.is_available():
self.criterion = self.criterion.cuda()
logger.debug("Creating dataset")
self.training_set, self.validation_set = self.data_creator(self.config)
# TODO: make num_workers configurable
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
self.training_set)
self.train_loader = torch.utils.data.DataLoader(
self.training_set,
batch_size=self.batch_size,
shuffle=(self.train_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.train_sampler)
self.validation_sampler = (
torch.utils.data.distributed.DistributedSampler(
self.validation_set))
self.validation_loader = torch.utils.data.DataLoader(
self.validation_set,
batch_size=self.batch_size,
shuffle=(self.validation_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.validation_sampler)
def get_node_ip(self):
"""Returns the IP address of the current node"""
return ray.services.get_node_ip_address()
def step(self):
"""Runs a training epoch and updates the model parameters"""
logger.debug("Starting step")
self.train_sampler.set_epoch(self.epoch)
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
with self._timers["training"]:
train_stats = utils.train(self.train_loader, self.model,
self.criterion, self.optimizer)
train_stats["epoch"] = self.epoch
self.epoch += 1
train_stats.update(self.stats())
return train_stats
def validate(self):
"""Evaluates the model on the validation data set"""
with self._timers["validation"]:
validation_stats = utils.validate(self.validation_loader,
self.model, self.criterion)
validation_stats.update(self.stats())
return validation_stats
def stats(self):
"""Returns a dictionary of statistics collected"""
stats = {"epoch": self.epoch}
for k, t in self._timers.items():
stats[k + "_time_mean"] = t.mean
stats[k + "_time_total"] = t.sum
t.reset()
return stats
def get_state(self):
"""Returns the state of the runner"""
return {
"epoch": self.epoch,
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"stats": self.stats()
}
def set_state(self, state):
"""Sets the state of the model"""
# TODO: restore timer stats
self.model.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
self.epoch = state["stats"]["epoch"]
def shutdown(self):
"""Attempts to shut down the worker"""
dist.destroy_process_group()
@@ -0,0 +1,150 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sys
import torch
import logging
import ray
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
from ray.experimental.sgd.pytorch import utils
logger = logging.getLogger(__name__)
class PyTorchTrainer(object):
"""Train a PyTorch model using distributed PyTorch.
Launches a set of actors which connect via distributed PyTorch and
coordinate gradient updates to train the provided model.
"""
def __init__(self,
model_creator,
data_creator,
optimizer_creator=utils.sgd_mse_optimizer,
config=None,
num_replicas=1,
resources_per_replica=None,
batch_size=16,
backend="auto"):
"""Sets up the PyTorch trainer.
Args:
model_creator (dict -> torch.nn.Module): creates the model
using the config.
data_creator (dict -> Dataset, Dataset): creates the training
and validation data sets using the config.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
creates the loss and optimizer using the model and the config.
config (dict): configuration passed to 'model_creator',
'data_creator', and 'optimizer_creator'.
num_replicas (int): the number of workers used in distributed
training.
resources_per_replica (Resources): resources used by each worker.
Defaults to Resources(num_cpus=1).
batch_size (int): batch size for an update.
backend (string): backend used by distributed PyTorch.
"""
# TODO: add support for mixed precision
# TODO: add support for callbacks
if sys.platform == "darwin":
raise Exception(
("Distributed PyTorch is not supported on macOS. For more "
"information, see "
"https://github.com/pytorch/examples/issues/467."))
self.model_creator = model_creator
self.config = {} if config is None else config
self.optimizer_timer = utils.TimerStat(window_size=1)
if resources_per_replica is None:
resources_per_replica = utils.Resources(
num_cpus=1, num_gpus=0, resources={})
if backend == "auto":
backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo"
Runner = ray.remote(
num_cpus=resources_per_replica.num_cpus,
num_gpus=resources_per_replica.num_gpus,
resources=resources_per_replica.resources)(PyTorchRunner)
batch_size_per_replica = batch_size // num_replicas
if batch_size % num_replicas > 0:
new_batch_size = batch_size_per_replica * num_replicas
logger.warn(
("Changing batch size from {old_batch_size} to "
"{new_batch_size} to evenly distribute batches across "
"{num_replicas} replicas.").format(
old_batch_size=batch_size,
new_batch_size=new_batch_size,
num_replicas=num_replicas))
self.workers = [
Runner.remote(model_creator, data_creator, optimizer_creator,
self.config, batch_size_per_replica, backend)
for i in range(num_replicas)
]
ip = ray.get(self.workers[0].get_node_ip.remote())
port = utils.find_free_port()
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
# Get setup tasks in order to throw errors on failure
ray.get([
worker.setup.remote(address, i, len(self.workers))
for i, worker in enumerate(self.workers)
])
def train(self):
"""Runs a training epoch"""
with self.optimizer_timer:
worker_stats = ray.get([w.step.remote() for w in self.workers])
train_stats = worker_stats[0].copy()
train_stats["train_loss"] = np.mean(
[s["train_loss"] for s in worker_stats])
return train_stats
def validate(self):
"""Evaluates the model on the validation data set"""
worker_stats = ray.get([w.validate.remote() for w in self.workers])
validation_stats = worker_stats[0].copy()
validation_stats["validation_loss"] = np.mean(
[s["validation_loss"] for s in worker_stats])
return validation_stats
def get_model(self):
"""Returns the learned model"""
model = self.model_creator(self.config)
state = ray.get(self.workers[0].get_state.remote())
# Remove module. prefix added by distrbuted pytorch
state_dict = {
k.replace("module.", ""): v
for k, v in state["model"].items()
}
model.load_state_dict(state_dict)
return model
def save(self, ckpt):
"""Saves the model at the provided checkpoint"""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, ckpt)
def restore(self, ckpt):
"""Restores the model from the provided checkpoint"""
state = torch.load(ckpt)
state_id = ray.put(state)
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
def shutdown(self):
"""Shuts down workers and releases resources"""
for worker in self.workers:
worker.shutdown.remote()
worker.__ray_terminate__.remote()
@@ -0,0 +1,240 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from contextlib import closing
import numpy as np
import socket
import time
import torch
import torch.nn as nn
def train(train_iterator, model, criterion, optimizer):
"""Runs 1 training epoch"""
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
timers = {k: TimerStat() for k in ["d2h", "fwd", "grad", "apply"]}
# switch to train mode
model.train()
end = time.time()
for i, (features, target) in enumerate(train_iterator):
# measure data loading time
data_time.update(time.time() - end)
# Create non_blocking tensors for distributed training
with timers["d2h"]:
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
with timers["fwd"]:
output = model(features)
loss = criterion(output, target)
# measure accuracy and record loss
losses.update(loss.item(), features.size(0))
with timers["grad"]:
# compute gradients in a backward pass
optimizer.zero_grad()
loss.backward()
with timers["apply"]:
# Call step of optimizer to update model params
optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
stats = {
"batch_time": batch_time.avg,
"batch_processed": losses.count,
"train_loss": losses.avg,
"data_time": data_time.avg,
}
stats.update({k: t.mean for k, t in timers.items()})
return stats
def validate(val_loader, model, criterion):
batch_time = AverageMeter()
losses = AverageMeter()
# switch to evaluate mode
model.eval()
with torch.no_grad():
end = time.time()
for i, (features, target) in enumerate(val_loader):
if torch.cuda.is_available():
features = features.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = model(features)
loss = criterion(output, target)
# measure accuracy and record loss
losses.update(loss.item(), features.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
stats = {"batch_time": batch_time.avg, "validation_loss": losses.avg}
return stats
class TimerStat(object):
"""A running stat for conveniently logging the duration of a code block.
Note that this class is *not* thread-safe.
Examples:
Time a call to 'time.sleep'.
>>> import time
>>> sleep_timer = TimerStat()
>>> with sleep_timer:
... time.sleep(1)
>>> round(sleep_timer.mean)
1
"""
def __init__(self, window_size=10):
self._window_size = window_size
self._samples = []
self._units_processed = []
self._start_time = None
self._total_time = 0.0
self.count = 0
def __enter__(self):
assert self._start_time is None, "concurrent updates not supported"
self._start_time = time.time()
def __exit__(self, type, value, tb):
assert self._start_time is not None
time_delta = time.time() - self._start_time
self.push(time_delta)
self._start_time = None
def push(self, time_delta):
self._samples.append(time_delta)
if len(self._samples) > self._window_size:
self._samples.pop(0)
self.count += 1
self._total_time += time_delta
def push_units_processed(self, n):
self._units_processed.append(n)
if len(self._units_processed) > self._window_size:
self._units_processed.pop(0)
@property
def mean(self):
return np.mean(self._samples)
@property
def median(self):
return np.median(self._samples)
@property
def sum(self):
return np.sum(self._samples)
@property
def max(self):
return np.max(self._samples)
@property
def first(self):
return self._samples[0] if self._samples else None
@property
def last(self):
return self._samples[-1] if self._samples else None
@property
def size(self):
return len(self._samples)
@property
def mean_units_processed(self):
return float(np.mean(self._units_processed))
@property
def mean_throughput(self):
time_total = sum(self._samples)
if not time_total:
return 0.0
return sum(self._units_processed) / time_total
def reset(self):
self._samples = []
self._units_processed = []
self._start_time = None
self._total_time = 0.0
self.count = 0
def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class Resources(
namedtuple("Resources", ["num_cpus", "num_gpus", "resources"])):
__slots__ = ()
def __new__(cls, num_cpus=1, num_gpus=0, resources=None):
if resources is None:
resources = {}
return super(Resources, cls).__new__(cls, num_cpus, num_gpus,
resources)
def sgd_mse_optimizer(model, config):
"""Returns the mean squared error criterion and SGD optimizer.
Args:
model (torch.nn.Module): the model to optimize.
config (dict): configuration for the optimizer.
lr (float): the learning rate. defaults to 0.01.
"""
learning_rate = config.get("lr", 0.01)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
return criterion, optimizer
@@ -0,0 +1,40 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
class LinearDataset(torch.utils.data.Dataset):
"""y = a * x + b"""
def __init__(self, a, b, size=1000):
x = np.random.random(size).astype(np.float32) * 10
x = np.arange(0, 10, 10 / size, dtype=np.float32)
self.x = torch.from_numpy(x)
self.y = torch.from_numpy(a * x + b)
def __getitem__(self, index):
return self.x[index, None], self.y[index, None]
def __len__(self):
return len(self.x)
def model_creator(config):
return nn.Linear(1, 1)
def optimizer_creator(model, config):
"""Returns criterion, optimizer"""
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
return criterion, optimizer
def data_creator(config):
"""Returns training set, validation set"""
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
@@ -0,0 +1,76 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pytest
import sys
import tempfile
import torch
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
from ray.experimental.sgd.tests.pytorch_utils import (
model_creator, optimizer_creator, data_creator)
@pytest.mark.skipif( # noqa: F811
sys.platform == "darwin", reason="Doesn't work on macOS.")
def test_train(ray_start_2_cpus): # noqa: F811
trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
num_replicas=2,
resources_per_replica=Resources(num_cpus=1))
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["validation_loss"]
train_loss2 = trainer.train()["train_loss"]
validation_loss2 = trainer.validate()["validation_loss"]
print(train_loss1, train_loss2)
print(validation_loss1, validation_loss2)
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
@pytest.mark.skipif( # noqa: F811
sys.platform == "darwin", reason="Doesn't work on macOS.")
def test_save_and_restore(ray_start_2_cpus): # noqa: F811
trainer1 = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
num_replicas=2,
resources_per_replica=Resources(num_cpus=1))
trainer1.train()
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
trainer1.save(filename)
model1 = trainer1.get_model()
trainer1.shutdown()
trainer2 = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator,
num_replicas=2,
resources_per_replica=Resources(num_cpus=1))
trainer2.restore(filename)
os.remove(filename)
model2 = trainer2.get_model()
model1_state_dict = model1.state_dict()
model2_state_dict = model2.state_dict()
assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())
for k in model1_state_dict:
assert torch.equal(model1_state_dict[k], model2_state_dict[k])