diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index fa40e3c70..d5334546a 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -287,6 +287,15 @@ py_test( args = ["--smoke-test"] ) +py_test( + name = "ddp_mnist_torch", + size = "small", + srcs = ["examples/ddp_mnist_torch.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example", "pytorch"], + args = ["--num-workers=2"] +) + py_test( name = "dragonfly_example", size = "medium", diff --git a/python/ray/tune/examples/ddp_mnist_torch.py b/python/ray/tune/examples/ddp_mnist_torch.py new file mode 100644 index 000000000..2641bd662 --- /dev/null +++ b/python/ray/tune/examples/ddp_mnist_torch.py @@ -0,0 +1,73 @@ +# Original Code here: +# https://github.com/pytorch/examples/blob/master/mnist/main.py +import argparse +import logging +import torch +import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel + +import ray +from ray import tune +from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders, + ConvNet) +from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator, + distributed_checkpoint) + +logger = logging.getLogger(__name__) + + +def train_mnist(config, checkpoint=False): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + train_loader, test_loader = get_data_loaders() + model = ConvNet().to(device) + optimizer = optim.SGD(model.parameters(), lr=0.1) + + if checkpoint: + with open(checkpoint) as f: + model_state, optimizer_state = torch.load(f) + + model.load_state_dict(model_state) + optimizer.load_state_dict(optimizer_state) + + model = DistributedDataParallel(model) + + for epoch in range(40): + train(model, optimizer, train_loader, device) + acc = test(model, test_loader, device) + + if epoch % 3 == 0: + with distributed_checkpoint(label=epoch) as path: + torch.save((model.state_dict(), optimizer.state_dict()), path) + tune.report(mean_accuracy=acc) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-workers", + "-n", + type=int, + default=2, + help="Sets number of workers for training.") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="enables CUDA training") + parser.add_argument( + "--cluster", + action="store_true", + default=False, + help="enables multi-node tuning") + + args = parser.parse_args() + + if args.cluster: + options = dict(address="auto") + else: + options = dict(num_cpus=2) + ray.init(**options) + trainable_cls = DistributedTrainableCreator( + train_mnist, num_workers=args.num_workers, use_gpu=args.use_gpu) + tune.run(trainable_cls, num_samples=4, stop={"training_iteration": 10}) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 9849d9434..dede70789 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -1,6 +1,5 @@ import logging import os -import io import time import inspect import shutil @@ -87,6 +86,7 @@ class StatusReporter: def make_checkpoint_dir(self, step=None): checkpoint_dir = TrainableUtil.make_checkpoint_dir( self.logdir, index=step) + logger.debug("Making checkpoint dir at %s", checkpoint_dir) return checkpoint_dir def save_checkpoint(self, checkpoint): @@ -279,6 +279,9 @@ class FunctionRunner(Trainable): result[SHOULD_CHECKPOINT] = True return result + def execute(self, fn): + return fn(self) + def create_default_checkpoint_dir(self): self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir( self.logdir, index="default") @@ -306,12 +309,8 @@ class FunctionRunner(Trainable): def save_to_object(self): checkpoint_path = self.save() - data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path) - out = io.BytesIO() - if len(data_dict) > 10e6: # getting pretty large - logger.info("Checkpoint size is {} bytes".format(len(data_dict))) - out.write(data_dict) - return out.getvalue() + obj = TrainableUtil.checkpoint_to_object(checkpoint_path) + return obj def load_checkpoint(self, checkpoint): # This should be removed once Trainables are refactored. diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index bfecdb19a..165e55f1a 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -670,9 +670,9 @@ class RayTrialExecutor(TrialExecutor): # This provides FT backwards compatibility in the # case where a DurableTrainable is not provided. logger.debug("Trial %s: Reading checkpoint into memory", trial) - data_dict = TrainableUtil.pickle_checkpoint(value) + obj = TrainableUtil.checkpoint_to_object(value) with self._change_working_directory(trial): - remote = trial.runner.restore_from_object.remote(data_dict) + remote = trial.runner.restore_from_object.remote(obj) else: raise AbortTrialExecution( "Pass in `sync_on_checkpoint=True` for driver-based trial" diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index 349f8fda8..09baf7c08 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -1,3 +1,4 @@ +import os import logging logger = logging.getLogger(__name__) @@ -7,8 +8,8 @@ _session = None def get_session(): global _session - if _session is None: - raise ValueError( + if not _session: + logger.warning( "Session not detected. You should not be calling this function " "outside `tune.run` or while using the class API. ") return _session @@ -67,7 +68,8 @@ def report(**kwargs): metrics can be used for early stopping or optimization. """ _session = get_session() - return _session(**kwargs) + if _session: + return _session(**kwargs) def make_checkpoint_dir(step=None): @@ -106,7 +108,10 @@ def make_checkpoint_dir(step=None): """ _session = get_session() - return _session.make_checkpoint_dir(step=step) + if _session: + return _session.make_checkpoint_dir(step=step) + else: + return os.path.abspath("./") def save_checkpoint(checkpoint): @@ -149,7 +154,8 @@ def save_checkpoint(checkpoint): .. versionadded:: 0.8.6 """ _session = get_session() - return _session.save_checkpoint(checkpoint) + if _session: + return _session.save_checkpoint(checkpoint) def get_trial_dir(): @@ -158,7 +164,8 @@ def get_trial_dir(): For function API use only. """ _session = get_session() - return _session.logdir + if _session: + return _session.logdir def get_trial_name(): @@ -167,7 +174,8 @@ def get_trial_name(): For function API use only. """ _session = get_session() - return _session.trial_name + if _session: + return _session.trial_name def get_trial_id(): @@ -176,7 +184,8 @@ def get_trial_id(): For function API use only. """ _session = get_session() - return _session.trial_id + if _session: + return _session.trial_id __all__ = ["report", "get_trial_dir", "get_trial_name", "get_trial_id"] diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index f6bfa8e2c..cb8c2a35e 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -77,6 +77,15 @@ class TrainableUtil: }) return data_dict + @staticmethod + def checkpoint_to_object(checkpoint_path): + data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path) + out = io.BytesIO() + if len(data_dict) > 10e6: # getting pretty large + logger.info("Checkpoint size is {} bytes".format(len(data_dict))) + out.write(data_dict) + return out.getvalue() + @staticmethod def find_checkpoint_dir(checkpoint_path): """Returns the directory containing the checkpoint path. @@ -424,14 +433,10 @@ class Trainable: """ tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) checkpoint_path = self.save(tmpdir) - # Save all files in subtree. - data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path) - out = io.BytesIO() - if len(data_dict) > 10e6: # getting pretty large - logger.info("Checkpoint size is {} bytes".format(len(data_dict))) - out.write(data_dict) + # Save all files in subtree and delete the tmpdir. + obj = TrainableUtil.checkpoint_to_object(checkpoint_path) shutil.rmtree(tmpdir) - return out.getvalue() + return obj def restore(self, checkpoint_path): """Restores training state from a given model checkpoint. diff --git a/python/ray/util/sgd/BUILD b/python/ray/util/sgd/BUILD index 284897c22..15ece2006 100644 --- a/python/ray/util/sgd/BUILD +++ b/python/ray/util/sgd/BUILD @@ -26,6 +26,14 @@ py_test( deps = [":sgd_lib"], ) +py_test( + name = "test_torch_trainable", + size = "small", + srcs = ["tests/test_torch_trainable.py"], + tags = ["exclusive", "pytorch"], + deps = [":sgd_lib"], +) + # -------------------------------------------------------------------- # Tests from the python/ray/util/sgd/tf/examples directory. # Please keep these sorted alphabetically. @@ -181,6 +189,7 @@ py_test( args = ["--num-workers=2"] ) + # -------------------------------------------------------------------- # Tests from the python/ray/util/sgd/torch/examples/* directories. # Only covers subdirectories. diff --git a/python/ray/util/sgd/tests/test_torch_trainable.py b/python/ray/util/sgd/tests/test_torch_trainable.py new file mode 100644 index 000000000..570827284 --- /dev/null +++ b/python/ray/util/sgd/tests/test_torch_trainable.py @@ -0,0 +1,85 @@ +import os +import pytest +from unittest.mock import patch +import torch +import torch.distributed as dist + +import ray +from ray import tune +from ray.util.sgd.torch.func_trainable import ( + DistributedTrainableCreator, distributed_checkpoint, _train_simple) + + +@pytest.fixture +def ray_start_2_cpus(): + address_info = ray.init(num_cpus=2) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + # Ensure that tests don't ALL fail + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.fixture +def ray_start_4_cpus(): + address_info = ray.init(num_cpus=4) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + # Ensure that tests don't ALL fail + if dist.is_initialized(): + dist.destroy_process_group() + + +def test_single_step(ray_start_2_cpus): # noqa: F811 + trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2) + trainer = trainable_cls() + trainer.train() + trainer.stop() + + +def test_step_after_completion(ray_start_2_cpus): # noqa: F811 + trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2) + trainer = trainable_cls(config={"epochs": 1}) + with pytest.raises(RuntimeError): + for i in range(10): + trainer.train() + + +def test_save_checkpoint(ray_start_2_cpus): # noqa: F811 + trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2) + trainer = trainable_cls(config={"epochs": 1}) + trainer.train() + path = trainer.save() + model_state_dict, opt_state_dict = torch.load(path) + trainer.stop() + + +@pytest.mark.parametrize("enabled_checkpoint", [True, False]) +def test_simple_tune(ray_start_4_cpus, enabled_checkpoint): + trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2) + analysis = tune.run( + trainable_cls, + config={"enable_checkpoint": enabled_checkpoint}, + num_samples=2, + stop={"training_iteration": 2}) + assert analysis.trials[0].last_result["training_iteration"] == 2 + assert analysis.trials[0].has_checkpoint() == enabled_checkpoint + + +@pytest.mark.parametrize("rank", [0, 1]) +def test_checkpoint(ray_start_2_cpus, rank): # noqa: F811 + with patch("torch.distributed.get_rank") as rank_method: + rank_method.return_value = rank + with distributed_checkpoint(label="test") as path: + if rank == 0: + assert path + else: + assert path == os.devnull + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/sgd/torch/__init__.py b/python/ray/util/sgd/torch/__init__.py index c3772de18..d37c27f6b 100644 --- a/python/ray/util/sgd/torch/__init__.py +++ b/python/ray/util/sgd/torch/__init__.py @@ -12,7 +12,13 @@ try: BaseTorchTrainable) from ray.util.sgd.torch.training_operator import TrainingOperator + from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator, + distributed_checkpoint) - __all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"] -except ImportError: + __all__ = [ + "TorchTrainer", "BaseTorchTrainable", "TrainingOperator", + "distributed_checkpoint", "DistributedTrainableCreator" + ] +except ImportError as e: + logger.warning(e) logger.warning("PyTorch not found. TorchTrainer will not be available") diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index d97b04e5f..0ccae62c5 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -1,4 +1,3 @@ -from datetime import timedelta import logging import io import os @@ -8,7 +7,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, IterableDataset from torch.utils.data.distributed import DistributedSampler -from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S +from ray.util.sgd.torch.utils import setup_process_group import ray from ray.util.sgd.torch.torch_runner import TorchRunner @@ -47,31 +46,19 @@ class DistributedTorchRunner(TorchRunner): def setup(self): raise RuntimeError("Need to call setup commands separately.") - def setup_process_group(self, url, world_rank, world_size): + def setup_process_group(self, url, world_rank, world_size, timeout): """Connects the distributed PyTorch backend. Args: url (str): the URL used to connect to distributed PyTorch. world_rank (int): the index of the runner. world_size (int): the total number of runners. + timeout (timedelta): Seconds for process group + operations to timeout. """ self.world_rank = world_rank - logger.debug("Connecting to {} world_rank: {} world_size: {}".format( - url, world_rank, world_size)) - logger.debug("using {}".format(self.backend)) - if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ: - logger.debug( - "Setting NCCL_BLOCKING_WAIT for detecting node failure. " - "To override this behavior, you can set NCCL_BLOCKING_WAIT=0.") - os.environ["NCCL_BLOCKING_WAIT"] = "1" - - timeout = timedelta(seconds=NCCL_TIMEOUT_S) - dist.init_process_group( - backend=self.backend, - init_method=url, - rank=world_rank, - world_size=world_size, - timeout=timeout) + setup_process_group( + url, world_rank, world_size, timeout, backend=self.backend) def setup_ddp_and_operator(self): """Runs distributed coordination components. diff --git a/python/ray/util/sgd/torch/func_trainable.py b/python/ray/util/sgd/torch/func_trainable.py new file mode 100644 index 000000000..3dc4ac443 --- /dev/null +++ b/python/ray/util/sgd/torch/func_trainable.py @@ -0,0 +1,235 @@ +# Original Code here: +# https://github.com/pytorch/examples/blob/master/mnist/main.py +import os +import logging +import torch +from datetime import timedelta + +import ray +from ray import tune +from ray.tune.result import RESULT_DUPLICATE +from ray.tune.logger import NoopLogger +from ray.tune.function_runner import wrap_function +from ray.tune.resources import Resources +from ray.tune.trainable import TrainableUtil +from ray.util.sgd.torch.utils import setup_process_group +from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S +from ray.util.sgd.torch.utils import setup_address + +logger = logging.getLogger(__name__) + + +def logger_creator(log_config, logdir, rank): + worker_dir = os.path.join(logdir, "worker_{}".format(rank)) + os.makedirs(worker_dir, exist_ok=True) + return NoopLogger(log_config, worker_dir) + + +class _TorchTrainable(tune.Trainable): + """Base class for distributed training on Tune. + + A wrapper class is needed to actually create a working + version of this trainable. + """ + _function = None + _num_workers = None + _use_gpu = None + _num_cpus_per_worker = None + + __slots__ = ["workers", "_finished"] + + @classmethod + def default_process_group_parameters(self): + return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo") + + @classmethod + def get_remote_worker_options(self): + num_gpus = 1 if self._use_gpu else 0 + num_cpus = int(self._num_cpus_per_worker or 1) + return dict(num_cpus=num_cpus, num_gpus=num_gpus) + + def setup(self, config): + self._finished = False + num_workers = self._num_workers + logdir = self.logdir + assert self._function + + func_trainable = wrap_function(self.__class__._function) + + remote_trainable = ray.remote(func_trainable) + remote_trainable = remote_trainable.options( + **self.get_remote_worker_options()) + + address = setup_address() + self.workers = [ + remote_trainable.remote( + config=config, + logger_creator=lambda cfg: logger_creator(cfg, logdir, rank)) + for rank in range(num_workers) + ] + + pgroup_params = self.default_process_group_parameters() + from functools import partial + setup_on_worker = partial( + setup_process_group, + url=address, + world_size=num_workers, + **pgroup_params) + ray.get([ + w.execute.remote(lambda _: setup_on_worker(world_rank=rank)) + for rank, w in enumerate(self.workers) + ]) + + def step(self): + if self._finished: + raise RuntimeError("Training has already finished.") + result = ray.get([w.step.remote() for w in self.workers])[0] + if RESULT_DUPLICATE in result: + self._finished = True + return result + + def save_checkpoint(self, checkpoint_dir): + # TODO: optimize if colocated + save_obj = ray.get(self.workers[0].save_to_object.remote()) + checkpoint_path = TrainableUtil.create_from_pickle( + save_obj, checkpoint_dir) + return checkpoint_path + + def load_checkpoint(self, checkpoint_dir): + checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir) + return ray.get( + w.restore_from_object.remote(checkpoint_obj) for w in self.workers) + + def stop(self): + ray.get([worker.stop.remote() for worker in self.workers]) + + +def DistributedTrainableCreator(func, + use_gpu=False, + num_workers=1, + num_cpus_per_worker=1, + backend="gloo", + timeout_s=NCCL_TIMEOUT_S): + """Creates a class that executes distributed training. + + Note that you typically should not instantiate the object + created. + + Example: + + .. code-block:: + + trainable_cls = DistributedTrainableCreator( + train_func, num_workers=2) + analysis = tune.run(trainable_cls) + """ + + class WrappedDistributedTorchTrainable(_TorchTrainable): + _function = func + _num_workers = num_workers + _use_gpu = use_gpu + _num_cpus_per_worker = num_cpus_per_worker + + @classmethod + def default_process_group_parameters(self): + return dict(timeout=timedelta(timeout_s), backend=backend) + + @classmethod + def default_resource_request(cls, config): + num_workers_ = int(config.get("num_workers", num_workers)) + num_cpus = int( + config.get("num_cpus_per_worker", num_cpus_per_worker)) + use_gpu_ = config.get("use_gpu", use_gpu) + + return Resources( + cpu=0, + gpu=0, + extra_cpu=num_cpus * num_workers_, + extra_gpu=num_workers_ if use_gpu_ else 0) + + return WrappedDistributedTorchTrainable + + +class distributed_checkpoint: + """ContextManager for creating a distributed checkpoint. + + Only checkpoints a file on the "main" training actor, avoiding + redundant work. + + Args: + label (int | str): Used to label the checkpoint + disable (bool): Disable for prototyping. + + Example: + + .. code-block:: + + if epoch % 3 == 0: + with distributed_checkpoint(label=epoch) as path: + torch.save(model.state_dict(), path) + """ + + def __init__(self, label, disable=False): + self.label = label + self.file = None + self.disable = disable + + def __enter__(self): + if torch.distributed.get_rank() == 0 and not self.disable: + checkpoint_dir = tune.make_checkpoint_dir(step=self.label) + path = os.path.join(checkpoint_dir, "checkpoint") + else: + path = os.devnull + self.file = path + return path + + def __exit__(self, type, value, traceback): + if torch.distributed.get_rank() == 0 and not self.disable: + tune.save_checkpoint(self.file) + + +def _train_simple(config, checkpoint=False): + """For testing only. Putting this here because Ray has problems + serializing within the test file.""" + import torch.nn as nn + from torch.nn.parallel import DistributedDataParallel + import torch.optim as optim + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 8, 5, 5, 5 + + # Create random Tensors to hold inputs and outputs + x = torch.randn(N, D_in) + y = torch.randn(N, D_out) + loss_fn = nn.MSELoss() + + # Use the nn package to define our model and loss function. + model = torch.nn.Sequential( + torch.nn.Linear(D_in, H), + torch.nn.ReLU(), + torch.nn.Linear(H, D_out), + ) + optimizer = optim.SGD(model.parameters(), lr=0.1) + + if checkpoint: + with open(checkpoint) as f: + model_state, optimizer_state = torch.load(f) + + model.load_state_dict(model_state) + optimizer.load_state_dict(optimizer_state) + + model = DistributedDataParallel(model) + + for epoch in range(config.get("epochs", 10)): + optimizer.zero_grad() + output = model(x) + loss = loss_fn(output, y) + loss.backward() + optimizer.step() + + if epoch % 3 == 0: + if config.get("enable_checkpoint", True): + with distributed_checkpoint(label=epoch) as path: + torch.save((model.state_dict(), optimizer.state_dict()), + path) + tune.report(mean_loss=loss.item()) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index b7d7a2bd7..019e104cd 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -1,3 +1,4 @@ +from datetime import timedelta import numpy as np import logging import os @@ -16,7 +17,8 @@ from ray.util.sgd.torch.distributed_torch_runner import ( DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner) from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE from ray.util.sgd.torch.torch_runner import TorchRunner -from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP +from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP, NCCL_TIMEOUT_S +from ray.util.sgd.torch.utils import setup_address from ray.util.sgd.data import Dataset logger = logging.getLogger(__name__) @@ -180,6 +182,7 @@ class TorchTrainer: use_gpu="auto", backend="auto", wrap_ddp=True, + timeout_s=NCCL_TIMEOUT_S, serialize_data_creation=True, use_fp16=False, use_tqdm=False, @@ -248,6 +251,7 @@ class TorchTrainer: self.serialize_data_creation = serialize_data_creation self.wrap_ddp = wrap_ddp + self.timeout_s = timeout_s self.use_fp16 = use_fp16 self.use_tqdm = use_tqdm self.add_dist_sampler = add_dist_sampler @@ -347,10 +351,7 @@ class TorchTrainer: self.apply_all_workers(self.initialization_hook) # Compute URL for initializing distributed PyTorch - ip = ray.services.get_node_ip_address() - port = self.local_worker.find_free_port() - - address = "tcp://{ip}:{port}".format(ip=ip, port=port) + address = setup_address() # Runs the creator functions. remote_component_setup = [ @@ -363,10 +364,12 @@ class TorchTrainer: # Setup the process group among all workers. remote_pgroup_setups = [ - worker.setup_process_group.remote(address, i + 1, num_workers) + worker.setup_process_group.remote(address, i + 1, num_workers, + timedelta(self.timeout_s)) for i, worker in enumerate(self.remote_workers) ] - self.local_worker.setup_process_group(address, 0, num_workers) + self.local_worker.setup_process_group(address, 0, num_workers, + timedelta(self.timeout_s)) # Get setup tasks in order to throw errors on failure ray.get(remote_pgroup_setups) diff --git a/python/ray/util/sgd/torch/utils.py b/python/ray/util/sgd/torch/utils.py new file mode 100644 index 000000000..c57b2d7b6 --- /dev/null +++ b/python/ray/util/sgd/torch/utils.py @@ -0,0 +1,43 @@ +import os +import logging +import torch.distributed as dist + +import ray +from ray.util.sgd.utils import find_free_port + +logger = logging.getLogger(__name__) + + +def setup_address(): + ip = ray.services.get_node_ip_address() + port = find_free_port() + return "tcp://{ip}:{port}".format(ip=ip, port=port) + + +def setup_process_group(url, world_rank, world_size, timeout, backend="gloo"): + """Connects the distributed PyTorch backend. + + Args: + url (str): the URL used to connect to distributed PyTorch. + world_rank (int): the index of the runner. + world_size (int): the total number of runners. + timeout (timedelta): Timeout for operations executed against + the process group. + backend (str): One of gloo or nccl. Depending on + build-time configuration + """ + logger.debug("Connecting to {} world_rank: {} world_size: {}".format( + url, world_rank, world_size)) + logger.debug("using {}".format(backend)) + if backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ: + logger.debug( + "Setting NCCL_BLOCKING_WAIT for detecting node failure. " + "To override this behavior, you can set NCCL_BLOCKING_WAIT=0.") + os.environ["NCCL_BLOCKING_WAIT"] = "1" + + dist.init_process_group( + backend=backend, + init_method=url, + rank=world_rank, + world_size=world_size, + timeout=timeout)