mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] distributed torch wrapper (#9550)
* changes * add-working * checkpoint * ccleanu * fix * ok * formatting * ok * tests * some-good-stuff * fix-torch * ddp-torch * torch-test * sessions * add-small-test * fix * remove * gpu-working * update-tests * ok * try-test * formgat * ok * ok
This commit is contained in:
@@ -287,6 +287,15 @@ py_test(
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ddp_mnist_torch",
|
||||
size = "small",
|
||||
srcs = ["examples/ddp_mnist_torch.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example", "pytorch"],
|
||||
args = ["--num-workers=2"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dragonfly_example",
|
||||
size = "medium",
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
# Original Code here:
|
||||
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
|
||||
ConvNet)
|
||||
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
|
||||
distributed_checkpoint)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_mnist(config, checkpoint=False):
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
train_loader, test_loader = get_data_loaders()
|
||||
model = ConvNet().to(device)
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.1)
|
||||
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
model_state, optimizer_state = torch.load(f)
|
||||
|
||||
model.load_state_dict(model_state)
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
model = DistributedDataParallel(model)
|
||||
|
||||
for epoch in range(40):
|
||||
train(model, optimizer, train_loader, device)
|
||||
acc = test(model, test_loader, device)
|
||||
|
||||
if epoch % 3 == 0:
|
||||
with distributed_checkpoint(label=epoch) as path:
|
||||
torch.save((model.state_dict(), optimizer.state_dict()), path)
|
||||
tune.report(mean_accuracy=acc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
"-n",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Sets number of workers for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="enables CUDA training")
|
||||
parser.add_argument(
|
||||
"--cluster",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="enables multi-node tuning")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cluster:
|
||||
options = dict(address="auto")
|
||||
else:
|
||||
options = dict(num_cpus=2)
|
||||
ray.init(**options)
|
||||
trainable_cls = DistributedTrainableCreator(
|
||||
train_mnist, num_workers=args.num_workers, use_gpu=args.use_gpu)
|
||||
tune.run(trainable_cls, num_samples=4, stop={"training_iteration": 10})
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import inspect
|
||||
import shutil
|
||||
@@ -87,6 +86,7 @@ class StatusReporter:
|
||||
def make_checkpoint_dir(self, step=None):
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index=step)
|
||||
logger.debug("Making checkpoint dir at %s", checkpoint_dir)
|
||||
return checkpoint_dir
|
||||
|
||||
def save_checkpoint(self, checkpoint):
|
||||
@@ -279,6 +279,9 @@ class FunctionRunner(Trainable):
|
||||
result[SHOULD_CHECKPOINT] = True
|
||||
return result
|
||||
|
||||
def execute(self, fn):
|
||||
return fn(self)
|
||||
|
||||
def create_default_checkpoint_dir(self):
|
||||
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index="default")
|
||||
@@ -306,12 +309,8 @@ class FunctionRunner(Trainable):
|
||||
|
||||
def save_to_object(self):
|
||||
checkpoint_path = self.save()
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
out = io.BytesIO()
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
return out.getvalue()
|
||||
obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
|
||||
return obj
|
||||
|
||||
def load_checkpoint(self, checkpoint):
|
||||
# This should be removed once Trainables are refactored.
|
||||
|
||||
@@ -670,9 +670,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
# This provides FT backwards compatibility in the
|
||||
# case where a DurableTrainable is not provided.
|
||||
logger.debug("Trial %s: Reading checkpoint into memory", trial)
|
||||
data_dict = TrainableUtil.pickle_checkpoint(value)
|
||||
obj = TrainableUtil.checkpoint_to_object(value)
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.restore_from_object.remote(data_dict)
|
||||
remote = trial.runner.restore_from_object.remote(obj)
|
||||
else:
|
||||
raise AbortTrialExecution(
|
||||
"Pass in `sync_on_checkpoint=True` for driver-based trial"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -7,8 +8,8 @@ _session = None
|
||||
|
||||
def get_session():
|
||||
global _session
|
||||
if _session is None:
|
||||
raise ValueError(
|
||||
if not _session:
|
||||
logger.warning(
|
||||
"Session not detected. You should not be calling this function "
|
||||
"outside `tune.run` or while using the class API. ")
|
||||
return _session
|
||||
@@ -67,7 +68,8 @@ def report(**kwargs):
|
||||
metrics can be used for early stopping or optimization.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session(**kwargs)
|
||||
if _session:
|
||||
return _session(**kwargs)
|
||||
|
||||
|
||||
def make_checkpoint_dir(step=None):
|
||||
@@ -106,7 +108,10 @@ def make_checkpoint_dir(step=None):
|
||||
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.make_checkpoint_dir(step=step)
|
||||
if _session:
|
||||
return _session.make_checkpoint_dir(step=step)
|
||||
else:
|
||||
return os.path.abspath("./")
|
||||
|
||||
|
||||
def save_checkpoint(checkpoint):
|
||||
@@ -149,7 +154,8 @@ def save_checkpoint(checkpoint):
|
||||
.. versionadded:: 0.8.6
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.save_checkpoint(checkpoint)
|
||||
if _session:
|
||||
return _session.save_checkpoint(checkpoint)
|
||||
|
||||
|
||||
def get_trial_dir():
|
||||
@@ -158,7 +164,8 @@ def get_trial_dir():
|
||||
For function API use only.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.logdir
|
||||
if _session:
|
||||
return _session.logdir
|
||||
|
||||
|
||||
def get_trial_name():
|
||||
@@ -167,7 +174,8 @@ def get_trial_name():
|
||||
For function API use only.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_name
|
||||
if _session:
|
||||
return _session.trial_name
|
||||
|
||||
|
||||
def get_trial_id():
|
||||
@@ -176,7 +184,8 @@ def get_trial_id():
|
||||
For function API use only.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_id
|
||||
if _session:
|
||||
return _session.trial_id
|
||||
|
||||
|
||||
__all__ = ["report", "get_trial_dir", "get_trial_name", "get_trial_id"]
|
||||
|
||||
@@ -77,6 +77,15 @@ class TrainableUtil:
|
||||
})
|
||||
return data_dict
|
||||
|
||||
@staticmethod
|
||||
def checkpoint_to_object(checkpoint_path):
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
out = io.BytesIO()
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
return out.getvalue()
|
||||
|
||||
@staticmethod
|
||||
def find_checkpoint_dir(checkpoint_path):
|
||||
"""Returns the directory containing the checkpoint path.
|
||||
@@ -424,14 +433,10 @@ class Trainable:
|
||||
"""
|
||||
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
|
||||
checkpoint_path = self.save(tmpdir)
|
||||
# Save all files in subtree.
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
out = io.BytesIO()
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
# Save all files in subtree and delete the tmpdir.
|
||||
obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
return out.getvalue()
|
||||
return obj
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
@@ -26,6 +26,14 @@ py_test(
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_torch_trainable",
|
||||
size = "small",
|
||||
srcs = ["tests/test_torch_trainable.py"],
|
||||
tags = ["exclusive", "pytorch"],
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Tests from the python/ray/util/sgd/tf/examples directory.
|
||||
# Please keep these sorted alphabetically.
|
||||
@@ -181,6 +189,7 @@ py_test(
|
||||
args = ["--num-workers=2"]
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Tests from the python/ray/util/sgd/torch/examples/* directories.
|
||||
# Only covers subdirectories.
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch.func_trainable import (
|
||||
DistributedTrainableCreator, distributed_checkpoint, _train_simple)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_2_cpus():
|
||||
address_info = ray.init(num_cpus=2)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
# Ensure that tests don't ALL fail
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_4_cpus():
|
||||
address_info = ray.init(num_cpus=4)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
# Ensure that tests don't ALL fail
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
|
||||
trainer = trainable_cls()
|
||||
trainer.train()
|
||||
trainer.stop()
|
||||
|
||||
|
||||
def test_step_after_completion(ray_start_2_cpus): # noqa: F811
|
||||
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
|
||||
trainer = trainable_cls(config={"epochs": 1})
|
||||
with pytest.raises(RuntimeError):
|
||||
for i in range(10):
|
||||
trainer.train()
|
||||
|
||||
|
||||
def test_save_checkpoint(ray_start_2_cpus): # noqa: F811
|
||||
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
|
||||
trainer = trainable_cls(config={"epochs": 1})
|
||||
trainer.train()
|
||||
path = trainer.save()
|
||||
model_state_dict, opt_state_dict = torch.load(path)
|
||||
trainer.stop()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enabled_checkpoint", [True, False])
|
||||
def test_simple_tune(ray_start_4_cpus, enabled_checkpoint):
|
||||
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
|
||||
analysis = tune.run(
|
||||
trainable_cls,
|
||||
config={"enable_checkpoint": enabled_checkpoint},
|
||||
num_samples=2,
|
||||
stop={"training_iteration": 2})
|
||||
assert analysis.trials[0].last_result["training_iteration"] == 2
|
||||
assert analysis.trials[0].has_checkpoint() == enabled_checkpoint
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rank", [0, 1])
|
||||
def test_checkpoint(ray_start_2_cpus, rank): # noqa: F811
|
||||
with patch("torch.distributed.get_rank") as rank_method:
|
||||
rank_method.return_value = rank
|
||||
with distributed_checkpoint(label="test") as path:
|
||||
if rank == 0:
|
||||
assert path
|
||||
else:
|
||||
assert path == os.devnull
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -12,7 +12,13 @@ try:
|
||||
BaseTorchTrainable)
|
||||
|
||||
from ray.util.sgd.torch.training_operator import TrainingOperator
|
||||
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
|
||||
distributed_checkpoint)
|
||||
|
||||
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
|
||||
except ImportError:
|
||||
__all__ = [
|
||||
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
|
||||
"distributed_checkpoint", "DistributedTrainableCreator"
|
||||
]
|
||||
except ImportError as e:
|
||||
logger.warning(e)
|
||||
logger.warning("PyTorch not found. TorchTrainer will not be available")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import io
|
||||
import os
|
||||
@@ -8,7 +7,7 @@ import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
from ray.util.sgd.torch.utils import setup_process_group
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
@@ -47,31 +46,19 @@ class DistributedTorchRunner(TorchRunner):
|
||||
def setup(self):
|
||||
raise RuntimeError("Need to call setup commands separately.")
|
||||
|
||||
def setup_process_group(self, url, world_rank, world_size):
|
||||
def setup_process_group(self, url, world_rank, world_size, timeout):
|
||||
"""Connects the distributed PyTorch backend.
|
||||
|
||||
Args:
|
||||
url (str): the URL used to connect to distributed PyTorch.
|
||||
world_rank (int): the index of the runner.
|
||||
world_size (int): the total number of runners.
|
||||
timeout (timedelta): Seconds for process group
|
||||
operations to timeout.
|
||||
"""
|
||||
self.world_rank = world_rank
|
||||
logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
|
||||
url, world_rank, world_size))
|
||||
logger.debug("using {}".format(self.backend))
|
||||
if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
|
||||
logger.debug(
|
||||
"Setting NCCL_BLOCKING_WAIT for detecting node failure. "
|
||||
"To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
|
||||
os.environ["NCCL_BLOCKING_WAIT"] = "1"
|
||||
|
||||
timeout = timedelta(seconds=NCCL_TIMEOUT_S)
|
||||
dist.init_process_group(
|
||||
backend=self.backend,
|
||||
init_method=url,
|
||||
rank=world_rank,
|
||||
world_size=world_size,
|
||||
timeout=timeout)
|
||||
setup_process_group(
|
||||
url, world_rank, world_size, timeout, backend=self.backend)
|
||||
|
||||
def setup_ddp_and_operator(self):
|
||||
"""Runs distributed coordination components.
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
# Original Code here:
|
||||
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
from datetime import timedelta
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.result import RESULT_DUPLICATE
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.util.sgd.torch.utils import setup_process_group
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
from ray.util.sgd.torch.utils import setup_address
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def logger_creator(log_config, logdir, rank):
|
||||
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
|
||||
os.makedirs(worker_dir, exist_ok=True)
|
||||
return NoopLogger(log_config, worker_dir)
|
||||
|
||||
|
||||
class _TorchTrainable(tune.Trainable):
|
||||
"""Base class for distributed training on Tune.
|
||||
|
||||
A wrapper class is needed to actually create a working
|
||||
version of this trainable.
|
||||
"""
|
||||
_function = None
|
||||
_num_workers = None
|
||||
_use_gpu = None
|
||||
_num_cpus_per_worker = None
|
||||
|
||||
__slots__ = ["workers", "_finished"]
|
||||
|
||||
@classmethod
|
||||
def default_process_group_parameters(self):
|
||||
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
|
||||
|
||||
@classmethod
|
||||
def get_remote_worker_options(self):
|
||||
num_gpus = 1 if self._use_gpu else 0
|
||||
num_cpus = int(self._num_cpus_per_worker or 1)
|
||||
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
|
||||
|
||||
def setup(self, config):
|
||||
self._finished = False
|
||||
num_workers = self._num_workers
|
||||
logdir = self.logdir
|
||||
assert self._function
|
||||
|
||||
func_trainable = wrap_function(self.__class__._function)
|
||||
|
||||
remote_trainable = ray.remote(func_trainable)
|
||||
remote_trainable = remote_trainable.options(
|
||||
**self.get_remote_worker_options())
|
||||
|
||||
address = setup_address()
|
||||
self.workers = [
|
||||
remote_trainable.remote(
|
||||
config=config,
|
||||
logger_creator=lambda cfg: logger_creator(cfg, logdir, rank))
|
||||
for rank in range(num_workers)
|
||||
]
|
||||
|
||||
pgroup_params = self.default_process_group_parameters()
|
||||
from functools import partial
|
||||
setup_on_worker = partial(
|
||||
setup_process_group,
|
||||
url=address,
|
||||
world_size=num_workers,
|
||||
**pgroup_params)
|
||||
ray.get([
|
||||
w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
|
||||
for rank, w in enumerate(self.workers)
|
||||
])
|
||||
|
||||
def step(self):
|
||||
if self._finished:
|
||||
raise RuntimeError("Training has already finished.")
|
||||
result = ray.get([w.step.remote() for w in self.workers])[0]
|
||||
if RESULT_DUPLICATE in result:
|
||||
self._finished = True
|
||||
return result
|
||||
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
# TODO: optimize if colocated
|
||||
save_obj = ray.get(self.workers[0].save_to_object.remote())
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(
|
||||
save_obj, checkpoint_dir)
|
||||
return checkpoint_path
|
||||
|
||||
def load_checkpoint(self, checkpoint_dir):
|
||||
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
|
||||
return ray.get(
|
||||
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
|
||||
|
||||
def stop(self):
|
||||
ray.get([worker.stop.remote() for worker in self.workers])
|
||||
|
||||
|
||||
def DistributedTrainableCreator(func,
|
||||
use_gpu=False,
|
||||
num_workers=1,
|
||||
num_cpus_per_worker=1,
|
||||
backend="gloo",
|
||||
timeout_s=NCCL_TIMEOUT_S):
|
||||
"""Creates a class that executes distributed training.
|
||||
|
||||
Note that you typically should not instantiate the object
|
||||
created.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
trainable_cls = DistributedTrainableCreator(
|
||||
train_func, num_workers=2)
|
||||
analysis = tune.run(trainable_cls)
|
||||
"""
|
||||
|
||||
class WrappedDistributedTorchTrainable(_TorchTrainable):
|
||||
_function = func
|
||||
_num_workers = num_workers
|
||||
_use_gpu = use_gpu
|
||||
_num_cpus_per_worker = num_cpus_per_worker
|
||||
|
||||
@classmethod
|
||||
def default_process_group_parameters(self):
|
||||
return dict(timeout=timedelta(timeout_s), backend=backend)
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
num_workers_ = int(config.get("num_workers", num_workers))
|
||||
num_cpus = int(
|
||||
config.get("num_cpus_per_worker", num_cpus_per_worker))
|
||||
use_gpu_ = config.get("use_gpu", use_gpu)
|
||||
|
||||
return Resources(
|
||||
cpu=0,
|
||||
gpu=0,
|
||||
extra_cpu=num_cpus * num_workers_,
|
||||
extra_gpu=num_workers_ if use_gpu_ else 0)
|
||||
|
||||
return WrappedDistributedTorchTrainable
|
||||
|
||||
|
||||
class distributed_checkpoint:
|
||||
"""ContextManager for creating a distributed checkpoint.
|
||||
|
||||
Only checkpoints a file on the "main" training actor, avoiding
|
||||
redundant work.
|
||||
|
||||
Args:
|
||||
label (int | str): Used to label the checkpoint
|
||||
disable (bool): Disable for prototyping.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
if epoch % 3 == 0:
|
||||
with distributed_checkpoint(label=epoch) as path:
|
||||
torch.save(model.state_dict(), path)
|
||||
"""
|
||||
|
||||
def __init__(self, label, disable=False):
|
||||
self.label = label
|
||||
self.file = None
|
||||
self.disable = disable
|
||||
|
||||
def __enter__(self):
|
||||
if torch.distributed.get_rank() == 0 and not self.disable:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=self.label)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
else:
|
||||
path = os.devnull
|
||||
self.file = path
|
||||
return path
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if torch.distributed.get_rank() == 0 and not self.disable:
|
||||
tune.save_checkpoint(self.file)
|
||||
|
||||
|
||||
def _train_simple(config, checkpoint=False):
|
||||
"""For testing only. Putting this here because Ray has problems
|
||||
serializing within the test file."""
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
import torch.optim as optim
|
||||
# N is batch size; D_in is input dimension;
|
||||
# H is hidden dimension; D_out is output dimension.
|
||||
N, D_in, H, D_out = 8, 5, 5, 5
|
||||
|
||||
# Create random Tensors to hold inputs and outputs
|
||||
x = torch.randn(N, D_in)
|
||||
y = torch.randn(N, D_out)
|
||||
loss_fn = nn.MSELoss()
|
||||
|
||||
# Use the nn package to define our model and loss function.
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(D_in, H),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(H, D_out),
|
||||
)
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.1)
|
||||
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
model_state, optimizer_state = torch.load(f)
|
||||
|
||||
model.load_state_dict(model_state)
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
model = DistributedDataParallel(model)
|
||||
|
||||
for epoch in range(config.get("epochs", 10)):
|
||||
optimizer.zero_grad()
|
||||
output = model(x)
|
||||
loss = loss_fn(output, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if epoch % 3 == 0:
|
||||
if config.get("enable_checkpoint", True):
|
||||
with distributed_checkpoint(label=epoch) as path:
|
||||
torch.save((model.state_dict(), optimizer.state_dict()),
|
||||
path)
|
||||
tune.report(mean_loss=loss.item())
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
@@ -16,7 +17,8 @@ from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
|
||||
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP, NCCL_TIMEOUT_S
|
||||
from ray.util.sgd.torch.utils import setup_address
|
||||
from ray.util.sgd.data import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -180,6 +182,7 @@ class TorchTrainer:
|
||||
use_gpu="auto",
|
||||
backend="auto",
|
||||
wrap_ddp=True,
|
||||
timeout_s=NCCL_TIMEOUT_S,
|
||||
serialize_data_creation=True,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
@@ -248,6 +251,7 @@ class TorchTrainer:
|
||||
|
||||
self.serialize_data_creation = serialize_data_creation
|
||||
self.wrap_ddp = wrap_ddp
|
||||
self.timeout_s = timeout_s
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
@@ -347,10 +351,7 @@ class TorchTrainer:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
|
||||
# Compute URL for initializing distributed PyTorch
|
||||
ip = ray.services.get_node_ip_address()
|
||||
port = self.local_worker.find_free_port()
|
||||
|
||||
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
|
||||
address = setup_address()
|
||||
|
||||
# Runs the creator functions.
|
||||
remote_component_setup = [
|
||||
@@ -363,10 +364,12 @@ class TorchTrainer:
|
||||
|
||||
# Setup the process group among all workers.
|
||||
remote_pgroup_setups = [
|
||||
worker.setup_process_group.remote(address, i + 1, num_workers)
|
||||
worker.setup_process_group.remote(address, i + 1, num_workers,
|
||||
timedelta(self.timeout_s))
|
||||
for i, worker in enumerate(self.remote_workers)
|
||||
]
|
||||
self.local_worker.setup_process_group(address, 0, num_workers)
|
||||
self.local_worker.setup_process_group(address, 0, num_workers,
|
||||
timedelta(self.timeout_s))
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(remote_pgroup_setups)
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
import logging
|
||||
import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.utils import find_free_port
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_address():
|
||||
ip = ray.services.get_node_ip_address()
|
||||
port = find_free_port()
|
||||
return "tcp://{ip}:{port}".format(ip=ip, port=port)
|
||||
|
||||
|
||||
def setup_process_group(url, world_rank, world_size, timeout, backend="gloo"):
|
||||
"""Connects the distributed PyTorch backend.
|
||||
|
||||
Args:
|
||||
url (str): the URL used to connect to distributed PyTorch.
|
||||
world_rank (int): the index of the runner.
|
||||
world_size (int): the total number of runners.
|
||||
timeout (timedelta): Timeout for operations executed against
|
||||
the process group.
|
||||
backend (str): One of gloo or nccl. Depending on
|
||||
build-time configuration
|
||||
"""
|
||||
logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
|
||||
url, world_rank, world_size))
|
||||
logger.debug("using {}".format(backend))
|
||||
if backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
|
||||
logger.debug(
|
||||
"Setting NCCL_BLOCKING_WAIT for detecting node failure. "
|
||||
"To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
|
||||
os.environ["NCCL_BLOCKING_WAIT"] = "1"
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
init_method=url,
|
||||
rank=world_rank,
|
||||
world_size=world_size,
|
||||
timeout=timeout)
|
||||
Reference in New Issue
Block a user