[tune] distributed torch wrapper (#9550)

* changes

* add-working

* checkpoint

* cleanup

* fix

* ok

* formatting

* ok

* tests

* some-good-stuff

* fix-torch

* ddp-torch

* torch-test

* sessions

* add-small-test

* fix

* remove

* gpu-working

* update-tests

* ok

* try-test

* format

* ok

* ok
This commit is contained in:
Richard Liaw
2020-07-26 09:37:22 -07:00
committed by GitHub
parent c6a7b3ac68
commit f3fdb5c5db
13 changed files with 515 additions and 52 deletions
+9
View File
@@ -26,6 +26,14 @@ py_test(
deps = [":sgd_lib"],
)
# Test target for the distributed Torch trainable wrapper; runs
# tests/test_torch_trainable.py against the :sgd_lib library.
# NOTE(review): the "exclusive" tag presumably keeps this test from sharing a
# machine with other tests -- confirm against the CI configuration.
py_test(
name = "test_torch_trainable",
size = "small",
srcs = ["tests/test_torch_trainable.py"],
tags = ["exclusive", "pytorch"],
deps = [":sgd_lib"],
)
# --------------------------------------------------------------------
# Tests from the python/ray/util/sgd/tf/examples directory.
# Please keep these sorted alphabetically.
@@ -181,6 +189,7 @@ py_test(
args = ["--num-workers=2"]
)
# --------------------------------------------------------------------
# Tests from the python/ray/util/sgd/torch/examples/* directories.
# Only covers subdirectories.
@@ -0,0 +1,85 @@
import os
import pytest
from unittest.mock import patch
import torch
import torch.distributed as dist
import ray
from ray import tune
from ray.util.sgd.torch.func_trainable import (
DistributedTrainableCreator, distributed_checkpoint, _train_simple)
@pytest.fixture
def ray_start_2_cpus():
    """Start a 2-CPU Ray instance for a single test, then tear it down.

    Yields:
        Whatever ``ray.init`` returned.
    """
    yield ray.init(num_cpus=2)
    # Everything past the yield is executed by pytest as teardown.
    ray.shutdown()
    # Drop any process group this test left behind so that later tests
    # do not ALL fail because of it.
    group_still_alive = dist.is_initialized()
    if group_still_alive:
        dist.destroy_process_group()
@pytest.fixture
def ray_start_4_cpus():
    """Start a 4-CPU Ray instance for a single test, then tear it down.

    Yields:
        Whatever ``ray.init`` returned.
    """
    yield ray.init(num_cpus=4)
    # Everything past the yield is executed by pytest as teardown.
    ray.shutdown()
    # Drop any process group this test left behind so that later tests
    # do not ALL fail because of it.
    group_still_alive = dist.is_initialized()
    if group_still_alive:
        dist.destroy_process_group()
def test_single_step(ray_start_2_cpus):  # noqa: F811
    """A two-worker trainable can run one training iteration and stop."""
    trainer = DistributedTrainableCreator(_train_simple, num_workers=2)()
    trainer.train()
    trainer.stop()
def test_step_after_completion(ray_start_2_cpus):  # noqa: F811
    """Calling ``train`` repeatedly on a one-epoch run must raise RuntimeError."""
    wrapped_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
    trainer = wrapped_cls(config={"epochs": 1})
    with pytest.raises(RuntimeError):
        # Up to ten attempts; the error is expected somewhere along the way.
        for _attempt in range(10):
            trainer.train()
def test_save_checkpoint(ray_start_2_cpus):  # noqa: F811
    """A saved checkpoint can be loaded back as a (model, optimizer) pair."""
    wrapped_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
    trainer = wrapped_cls(config={"epochs": 1})
    trainer.train()
    saved_path = trainer.save()
    # The unpacking itself is the check: it fails unless exactly two
    # objects were stored in the checkpoint.
    model_state_dict, opt_state_dict = torch.load(saved_path)
    trainer.stop()
@pytest.mark.parametrize("enabled_checkpoint", [True, False])
def test_simple_tune(ray_start_4_cpus, enabled_checkpoint):
    """Run the wrapped trainable through ``tune.run`` with/without checkpoints."""
    run_options = dict(
        config={"enable_checkpoint": enabled_checkpoint},
        num_samples=2,
        stop={"training_iteration": 2},
    )
    analysis = tune.run(
        DistributedTrainableCreator(_train_simple, num_workers=2),
        **run_options)
    first_trial = analysis.trials[0]
    assert first_trial.last_result["training_iteration"] == 2
    assert first_trial.has_checkpoint() == enabled_checkpoint
@pytest.mark.parametrize("rank", [0, 1])
def test_checkpoint(ray_start_2_cpus, rank):  # noqa: F811
    """Only rank 0 receives a real checkpoint path; others get ``os.devnull``."""
    # Pretend to be the given rank without creating a process group.
    with patch("torch.distributed.get_rank", return_value=rank):
        with distributed_checkpoint(label="test") as path:
            if rank != 0:
                assert path == os.devnull
            else:
                assert path
if __name__ == "__main__":
    # Allow running this test file directly; the pytest exit status
    # becomes the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
+8 -2
View File
@@ -12,7 +12,13 @@ try:
BaseTorchTrainable)
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
distributed_checkpoint)
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
except ImportError:
__all__ = [
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
"distributed_checkpoint", "DistributedTrainableCreator"
]
except ImportError as e:
logger.warning(e)
logger.warning("PyTorch not found. TorchTrainer will not be available")
@@ -1,4 +1,3 @@
from datetime import timedelta
import logging
import io
import os
@@ -8,7 +7,7 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
from ray.util.sgd.torch.utils import setup_process_group
import ray
from ray.util.sgd.torch.torch_runner import TorchRunner
@@ -47,31 +46,19 @@ class DistributedTorchRunner(TorchRunner):
def setup(self):
    """Always raise ``RuntimeError``.

    The error message tells the caller that the setup commands need to be
    called separately.

    Raises:
        RuntimeError: unconditionally.
    """
    raise RuntimeError("Need to call setup commands separately.")
def setup_process_group(self, url, world_rank, world_size, timeout):
    """Connects the distributed PyTorch backend.

    Records this runner's rank and then delegates the connection itself to
    the module-level ``setup_process_group`` helper, passing along this
    runner's backend.

    Args:
        url (str): the URL used to connect to distributed PyTorch.
        world_rank (int): the index of the runner.
        world_size (int): the total number of runners.
        timeout (timedelta): Timeout for process group operations.
    """
    self.world_rank = world_rank
    # The helper does the logging, NCCL_BLOCKING_WAIT handling and
    # ``init_process_group`` call; repeating them here would overwrite the
    # caller's timeout and initialize the process group a second time.
    setup_process_group(
        url, world_rank, world_size, timeout, backend=self.backend)
def setup_ddp_and_operator(self):
"""Runs distributed coordination components.
+235
View File
@@ -0,0 +1,235 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import logging
import torch
from datetime import timedelta
import ray
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import wrap_function
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.util.sgd.torch.utils import setup_process_group
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
from ray.util.sgd.torch.utils import setup_address
logger = logging.getLogger(__name__)
def logger_creator(log_config, logdir, rank):
    """Create a ``NoopLogger`` rooted in a per-worker subdirectory.

    Args:
        log_config: Configuration handed through to ``NoopLogger``.
        logdir (str): Parent directory for the worker directories.
        rank: Worker index, used in the ``worker_<rank>`` directory name.

    Returns:
        A ``NoopLogger`` for ``<logdir>/worker_<rank>``; the directory is
        created if it does not exist yet.
    """
    per_worker_dir = os.path.join(logdir, "worker_{}".format(rank))
    os.makedirs(per_worker_dir, exist_ok=True)
    return NoopLogger(log_config, per_worker_dir)
class _TorchTrainable(tune.Trainable):
    """Base class for distributed training on Tune.

    A wrapper class is needed to actually create a working
    version of this trainable: the class attributes below are all ``None``
    here and are expected to be filled in by a subclass.
    """
    # Training function that gets wrapped and run on every worker.
    _function = None
    # Number of remote worker actors to launch.
    _num_workers = None
    # Whether each worker requests one GPU.
    _use_gpu = None
    # CPUs requested per worker (treated as 1 when unset/falsy).
    _num_cpus_per_worker = None

    __slots__ = ["workers", "_finished"]

    @classmethod
    def default_process_group_parameters(cls):
        """Return the keyword arguments used to set up the process group.

        Returns:
            dict: ``timeout`` (a ``timedelta`` of ``NCCL_TIMEOUT_S`` seconds)
            and ``backend`` (``"gloo"``).
        """
        # timedelta's first positional argument is *days*; the constant is
        # in seconds, so it must be passed by keyword.
        return dict(
            timeout=timedelta(seconds=NCCL_TIMEOUT_S), backend="gloo")

    @classmethod
    def get_remote_worker_options(cls):
        """Return the Ray actor options (CPU/GPU counts) for one worker."""
        num_gpus = 1 if cls._use_gpu else 0
        num_cpus = int(cls._num_cpus_per_worker or 1)
        return dict(num_cpus=num_cpus, num_gpus=num_gpus)

    def setup(self, config):
        """Launch the worker actors and connect them in one process group.

        Args:
            config (dict): Configuration forwarded to every worker.
        """
        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        remote_trainable = remote_trainable.options(
            **self.get_remote_worker_options())

        address = setup_address()

        # ``rank=rank`` binds the current loop value into each lambda so the
        # closure does not depend on when it is serialized or called.
        self.workers = [
            remote_trainable.remote(
                config=config,
                logger_creator=lambda cfg, rank=rank: logger_creator(
                    cfg, logdir, rank))
            for rank in range(num_workers)
        ]

        pgroup_params = self.default_process_group_parameters()
        from functools import partial
        setup_on_worker = partial(
            setup_process_group,
            url=address,
            world_size=num_workers,
            **pgroup_params)
        # Block until every worker has joined the process group so that
        # failures surface here.
        ray.get([
            w.execute.remote(
                lambda _, rank=rank: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ])

    def step(self):
        """Run one step on every worker and return the first worker's result.

        Raises:
            RuntimeError: if a previous step already reported
                ``RESULT_DUPLICATE``, i.e. training has finished.
        """
        if self._finished:
            raise RuntimeError("Training has already finished.")
        result = ray.get([w.step.remote() for w in self.workers])[0]
        if RESULT_DUPLICATE in result:
            self._finished = True
        return result

    def save_checkpoint(self, checkpoint_dir):
        """Write the first worker's checkpoint into ``checkpoint_dir``.

        Returns:
            The value returned by ``TrainableUtil.create_from_pickle``.
        """
        # TODO: optimize if colocated
        save_obj = ray.get(self.workers[0].save_to_object.remote())
        checkpoint_path = TrainableUtil.create_from_pickle(
            save_obj, checkpoint_dir)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir):
        """Restore every worker from the checkpoint at ``checkpoint_dir``."""
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
        # ray.get takes a single object ID or a *list* of them; a bare
        # generator expression is rejected.
        return ray.get([
            w.restore_from_object.remote(checkpoint_obj)
            for w in self.workers
        ])

    def stop(self):
        """Stop all workers and wait for them to finish stopping."""
        ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(func,
                                use_gpu=False,
                                num_workers=1,
                                num_cpus_per_worker=1,
                                backend="gloo",
                                timeout_s=NCCL_TIMEOUT_S):
    """Creates a class that executes distributed training.

    Note that you typically should not instantiate the object
    created.

    Args:
        func (callable): Training function run on every worker.
        use_gpu (bool): Whether each worker requests one GPU.
        num_workers (int): Number of workers to launch.
        num_cpus_per_worker (int): CPUs requested per worker.
        backend (str): Process group backend passed to the workers.
        timeout_s (float): Process group timeout, in seconds.

    Returns:
        A ``_TorchTrainable`` subclass bound to the arguments above.

    Example:

    .. code-block::

        trainable_cls = DistributedTrainableCreator(
            train_func, num_workers=2)
        analysis = tune.run(trainable_cls)
    """

    class WrappedDistributedTorchTrainable(_TorchTrainable):
        _function = func
        _num_workers = num_workers
        _use_gpu = use_gpu
        _num_cpus_per_worker = num_cpus_per_worker

        @classmethod
        def default_process_group_parameters(cls):
            # timedelta's first positional argument is days, so the
            # seconds value has to be passed by keyword.
            return dict(
                timeout=timedelta(seconds=timeout_s), backend=backend)

        @classmethod
        def default_resource_request(cls, config):
            # Values in ``config`` take precedence over the creator's
            # arguments when sizing the resource request.
            num_workers_ = int(config.get("num_workers", num_workers))
            num_cpus = int(
                config.get("num_cpus_per_worker", num_cpus_per_worker))
            use_gpu_ = config.get("use_gpu", use_gpu)

            # The trainable itself requests nothing; everything is
            # requested as "extra" resources for the workers.
            return Resources(
                cpu=0,
                gpu=0,
                extra_cpu=num_cpus * num_workers_,
                extra_gpu=num_workers_ if use_gpu_ else 0)

    return WrappedDistributedTorchTrainable
class distributed_checkpoint:
    """ContextManager for creating a distributed checkpoint.

    Only checkpoints a file on the "main" training actor, avoiding
    redundant work. Every other rank (and every rank when ``disable`` is
    set) receives ``os.devnull`` as the path, so the body of the ``with``
    block can be written identically everywhere.

    Args:
        label (int | str): Used to label the checkpoint
        disable (bool): Disable for prototyping. When set, the process
            group is never queried, so no process group is required.

    Example:

    .. code-block::

        if epoch % 3 == 0:
            with distributed_checkpoint(label=epoch) as path:
                torch.save(model.state_dict(), path)
    """

    def __init__(self, label, disable=False):
        self.label = label
        self.file = None
        self.disable = disable

    def _is_checkpoint_writer(self):
        """Return True when this process should write the checkpoint."""
        # ``disable`` is tested first so that a disabled context never
        # calls ``get_rank``, which fails without a process group.
        return not self.disable and torch.distributed.get_rank() == 0

    def __enter__(self):
        if self._is_checkpoint_writer():
            checkpoint_dir = tune.make_checkpoint_dir(step=self.label)
            path = os.path.join(checkpoint_dir, "checkpoint")
        else:
            path = os.devnull
        self.file = path
        return path

    def __exit__(self, type, value, traceback):
        if type is not None:
            # The body raised, so the checkpoint file may be missing or
            # incomplete: do not register it. Returning None lets the
            # exception propagate.
            return
        if self._is_checkpoint_writer():
            tune.save_checkpoint(self.file)
def _train_simple(config, checkpoint=False):
    """For testing only. Putting this here because Ray has problems
    serializing within the test file.

    Trains a tiny two-layer network on random data under
    ``DistributedDataParallel`` and reports the loss after every epoch.

    Args:
        config (dict): Reads ``epochs`` (default 10) and
            ``enable_checkpoint`` (default True).
        checkpoint: Path of a checkpoint to resume from; falsy to start
            from scratch.
    """
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel
    import torch.optim as optim
    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 8, 5, 5, 5

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    loss_fn = nn.MSELoss()

    # Use the nn package to define our model and loss function.
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    if checkpoint:
        # Hand torch.load the path itself: checkpoints are binary, so a
        # file object opened in text mode cannot be read.
        model_state, optimizer_state = torch.load(checkpoint)
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    model = DistributedDataParallel(model)

    for epoch in range(config.get("epochs", 10)):
        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()

        if epoch % 3 == 0 and config.get("enable_checkpoint", True):
            with distributed_checkpoint(label=epoch) as path:
                # Save the wrapped module's state: the DDP wrapper
                # prefixes every key with "module.", which would not load
                # into the bare model that is restored above.
                torch.save(
                    (model.module.state_dict(), optimizer.state_dict()),
                    path)
        tune.report(mean_loss=loss.item())
+10 -7
View File
@@ -1,3 +1,4 @@
from datetime import timedelta
import numpy as np
import logging
import os
@@ -16,7 +17,8 @@ from ray.util.sgd.torch.distributed_torch_runner import (
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP, NCCL_TIMEOUT_S
from ray.util.sgd.torch.utils import setup_address
from ray.util.sgd.data import Dataset
logger = logging.getLogger(__name__)
@@ -180,6 +182,7 @@ class TorchTrainer:
use_gpu="auto",
backend="auto",
wrap_ddp=True,
timeout_s=NCCL_TIMEOUT_S,
serialize_data_creation=True,
use_fp16=False,
use_tqdm=False,
@@ -248,6 +251,7 @@ class TorchTrainer:
self.serialize_data_creation = serialize_data_creation
self.wrap_ddp = wrap_ddp
self.timeout_s = timeout_s
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.add_dist_sampler = add_dist_sampler
@@ -347,10 +351,7 @@ class TorchTrainer:
self.apply_all_workers(self.initialization_hook)
# Compute URL for initializing distributed PyTorch
ip = ray.services.get_node_ip_address()
port = self.local_worker.find_free_port()
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
address = setup_address()
# Runs the creator functions.
remote_component_setup = [
@@ -363,10 +364,12 @@ class TorchTrainer:
# Setup the process group among all workers.
remote_pgroup_setups = [
worker.setup_process_group.remote(address, i + 1, num_workers)
worker.setup_process_group.remote(address, i + 1, num_workers,
timedelta(self.timeout_s))
for i, worker in enumerate(self.remote_workers)
]
self.local_worker.setup_process_group(address, 0, num_workers)
self.local_worker.setup_process_group(address, 0, num_workers,
timedelta(self.timeout_s))
# Get setup tasks in order to throw errors on failure
ray.get(remote_pgroup_setups)
+43
View File
@@ -0,0 +1,43 @@
import os
import logging
import torch.distributed as dist
import ray
from ray.util.sgd.utils import find_free_port
logger = logging.getLogger(__name__)
def setup_address():
    """Build a ``tcp://<ip>:<port>`` URL from this node's IP and a free port.

    Returns:
        str: the URL string.
    """
    return "tcp://{ip}:{port}".format(
        ip=ray.services.get_node_ip_address(), port=find_free_port())
def setup_process_group(url, world_rank, world_size, timeout, backend="gloo"):
    """Connects the distributed PyTorch backend.

    Args:
        url (str): the URL used to connect to distributed PyTorch.
        world_rank (int): the index of the runner.
        world_size (int): the total number of runners.
        timeout (timedelta): Timeout for operations executed against
            the process group.
        backend (str): One of gloo or nccl. Depending on
            build-time configuration
    """
    connect_msg = "Connecting to {} world_rank: {} world_size: {}".format(
        url, world_rank, world_size)
    logger.debug(connect_msg)
    logger.debug("using {}".format(backend))

    # Only NCCL gets the blocking-wait default, and only when the user has
    # not already chosen a value in the environment.
    blocking_wait_unset = "NCCL_BLOCKING_WAIT" not in os.environ
    if backend == "nccl" and blocking_wait_unset:
        logger.debug(
            "Setting NCCL_BLOCKING_WAIT for detecting node failure. "
            "To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

    init_kwargs = dict(
        backend=backend,
        init_method=url,
        rank=world_rank,
        world_size=world_size,
        timeout=timeout,
    )
    dist.init_process_group(**init_kwargs)