From d5a7c53908018c72a663859990db016d86a4ac5e Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 15 Sep 2020 11:58:57 -0700 Subject: [PATCH] [Ray SGD] use_local flag + Worker group abstraction (#10539) Co-authored-by: Richard Liaw --- doc/source/raysgd/raysgd_pytorch.rst | 6 +- python/ray/tune/integration/torch.py | 5 +- python/ray/util/sgd/BUILD | 8 + python/ray/util/sgd/data/dataset.py | 4 +- python/ray/util/sgd/tests/test_torch.py | 240 +++----- .../ray/util/sgd/tests/test_torch_failure.py | 175 ++++++ .../sgd/torch/distributed_torch_runner.py | 31 +- .../torch/examples/cifar_pytorch_example.py | 2 +- .../torch/examples/raysgd_torch_signatures.py | 20 +- .../util/sgd/torch/examples/tune_example.py | 2 +- python/ray/util/sgd/torch/torch_runner.py | 37 +- python/ray/util/sgd/torch/torch_trainer.py | 339 ++++------- .../ray/util/sgd/torch/training_operator.py | 6 +- python/ray/util/sgd/torch/utils.py | 2 +- python/ray/util/sgd/torch/worker_group.py | 574 ++++++++++++++++++ 15 files changed, 1025 insertions(+), 426 deletions(-) create mode 100644 python/ray/util/sgd/tests/test_torch_failure.py create mode 100644 python/ray/util/sgd/torch/worker_group.py diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index f6ae551e1..0f09f031e 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -32,7 +32,7 @@ The :ref:`ref-torch-trainer` can be constructed from a custom :ref:`ref-torch-o :start-after: __torch_operator_start__ :end-before: __torch_operator_end__ -Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor. One of the replicas will be on the main process, which can simplify the debugging and logging experience. +Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor. Before instantiating the trainer, first start or connect to a Ray cluster: @@ -288,8 +288,8 @@ However, if you have these creator functions already and do not want to change y .. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py :language: python - :start-after: __backwards_compat__start - :end-before: __backwards_compat_end + :start-after: __backwards_compat_start__ + :end-before: __backwards_compat_end__ Initialization Functions ------------------------ diff --git a/python/ray/tune/integration/torch.py b/python/ray/tune/integration/torch.py index f7612a82a..ba582d233 100644 --- a/python/ray/tune/integration/torch.py +++ b/python/ray/tune/integration/torch.py @@ -75,7 +75,6 @@ class _TorchTrainable(tune.Trainable): remote_trainable = remote_trainable.options( **self.get_remote_worker_options()) - address = setup_address() self.workers = [ remote_trainable.remote( config=config, @@ -83,6 +82,10 @@ class _TorchTrainable(tune.Trainable): for rank in range(num_workers) ] + # Address has to be IP of rank 0 worker's node. + address = ray.get( + self.workers[0].execute.remote(lambda _: setup_address())) + pgroup_params = self.default_process_group_parameters() from functools import partial setup_on_worker = partial( diff --git a/python/ray/util/sgd/BUILD b/python/ray/util/sgd/BUILD index 664ac40a3..081bc97d6 100644 --- a/python/ray/util/sgd/BUILD +++ b/python/ray/util/sgd/BUILD @@ -18,6 +18,14 @@ py_test( deps = [":sgd_lib"], ) +py_test( + name = "test_torch_failure", + size = "large", + srcs = ["tests/test_torch_failure.py"], + tags = ["exclusive", "pytorch"], + deps = [":sgd_lib"], +) + py_test( name = "test_torch_runner", size = "small", diff --git a/python/ray/util/sgd/data/dataset.py b/python/ray/util/sgd/data/dataset.py index 44ced77ef..544db1319 100644 --- a/python/ray/util/sgd/data/dataset.py +++ b/python/ray/util/sgd/data/dataset.py @@ -84,8 +84,8 @@ class Dataset(): Returns a single, iterable shard. """ assert i < self.iter.num_shards(), \ - "Trying to get shard {} but there are only {} shards." + \ - "Are you sure you called set_num_shards already?".format( + "Trying to get shard {} but there are only {} shards. Are you " \ + "sure you called set_num_shards already?".format( i, self.iter.num_shards() ) diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index f7493bc91..48197d7b5 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -1,8 +1,6 @@ -from unittest.mock import patch import numpy as np import os import pytest -import time import torch import torch.nn as nn import torch.distributed as dist @@ -14,8 +12,7 @@ from ray.util.sgd.torch import TorchTrainer from ray.util.sgd.torch.training_operator import ( get_test_operator, get_test_metrics_operator, TrainingOperator) from ray.util.sgd.torch.constants import SCHEDULER_STEP -from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT, - BATCH_SIZE) +from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE) from ray.util.sgd.data.examples import mlp_identity from ray.util.sgd.torch.examples.train_example import ( @@ -48,8 +45,13 @@ Operator = TrainingOperator.from_creators( model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss) -def test_single_step(ray_start_2_cpus): # noqa: F811 - trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1) +@pytest.mark.parametrize("use_local", [True, False]) +def test_single_step(ray_start_2_cpus, use_local): # noqa: F811 + trainer = TorchTrainer( + training_operator_cls=Operator, + num_workers=1, + use_local=use_local, + use_gpu=False) metrics = trainer.train(num_steps=1) assert metrics[BATCH_COUNT] == 1 @@ -58,9 +60,14 @@ def test_single_step(ray_start_2_cpus): # noqa: F811 trainer.shutdown() -def test_dead_trainer(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize("use_local", [True, False]) +def test_dead_trainer(ray_start_2_cpus, use_local): # noqa: F811 TestOperator = get_test_operator(Operator) - trainer = TorchTrainer(training_operator_cls=TestOperator, num_workers=2) + trainer = TorchTrainer( + training_operator_cls=TestOperator, + num_workers=2, + use_local=use_local, + use_gpu=False) trainer.train(num_steps=1) trainer.shutdown() with pytest.raises(RuntimeError): @@ -68,9 +75,13 @@ def test_dead_trainer(ray_start_2_cpus): # noqa: F811 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_train(ray_start_2_cpus, num_workers): # noqa: F811 +@pytest.mark.parametrize("use_local", [True, False]) +def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811 trainer = TorchTrainer( - training_operator_cls=Operator, num_workers=num_workers) + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local, + use_gpu=False) for i in range(3): train_loss1 = trainer.train()["train_loss"] validation_loss1 = trainer.validate()["val_loss"] @@ -86,7 +97,8 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_multi_model(ray_start_2_cpus, num_workers): +@pytest.mark.parametrize("use_local", [True, False]) +def test_multi_model(ray_start_2_cpus, num_workers, use_local): def train(*, model=None, criterion=None, optimizer=None, iterator=None): model.train() train_loss = 0 @@ -140,7 +152,10 @@ def test_multi_model(ray_start_2_cpus, num_workers): trainer1 = TorchTrainer( config={"custom_func": train_epoch}, training_operator_cls=TestOperator, - num_workers=num_workers) + num_workers=num_workers, + use_local=use_local, + use_gpu=False, + ) trainer1.train() state = trainer1.state_dict() @@ -151,7 +166,10 @@ def test_multi_model(ray_start_2_cpus, num_workers): trainer2 = TorchTrainer( config={"custom_func": train_epoch}, training_operator_cls=TestOperator, - num_workers=num_workers) + num_workers=num_workers, + use_local=use_local, + use_gpu=False, + ) trainer2.load_state_dict(state) models2 = trainer2.get_model() @@ -170,7 +188,9 @@ def test_multi_model(ray_start_2_cpus, num_workers): @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811 +@pytest.mark.parametrize("use_local", [True, False]) +def test_multi_model_matrix(ray_start_2_cpus, num_workers, use_local): # + # noqa: F811 def train_epoch(self, iterator, info): if self.config.get("models", 1) > 1: assert len(self.models) == self.config["models"], self.config @@ -231,6 +251,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811 scheduler_step_freq="epoch", training_operator_cls=TestOperator, num_workers=num_workers, + use_local=use_local, config={ "models": model_count, "optimizers": optimizer_count, @@ -242,7 +263,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811 @pytest.mark.parametrize("scheduler_freq", ["epoch", "batch", "manual", None]) -def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 +def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: + # F811 def train_epoch(self, iterator, info): assert info[SCHEDULER_STEP] == scheduler_freq return {"done": 1} @@ -271,20 +293,23 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 trainer = TorchTrainer( config={"custom_func": train_epoch}, training_operator_cls=TestTrainingOperator, - scheduler_step_freq=scheduler_freq) + scheduler_step_freq=scheduler_freq, + ) else: trainer = TorchTrainer( config={"custom_func": train_epoch}, training_operator_cls=TestTrainingOperator, - scheduler_step_freq=scheduler_freq) + scheduler_step_freq=scheduler_freq, + ) for i in range(3): trainer.train() trainer.shutdown() -def test_profiling(ray_start_2_cpus): # noqa: F811 - trainer = TorchTrainer(training_operator_cls=Operator) +@pytest.mark.parametrize("use_local", [True, False]) +def test_profiling(ray_start_2_cpus, use_local): # noqa: F811 + trainer = TorchTrainer(training_operator_cls=Operator, use_local=use_local) stats = trainer.train(profile=True) assert "profile" in stats @@ -293,7 +318,8 @@ def test_profiling(ray_start_2_cpus): # noqa: F811 trainer.shutdown() -def test_dataset(ray_start_4_cpus): +@pytest.mark.parametrize("use_local", [True, False]) +def test_dataset(ray_start_4_cpus, use_local): """ This test tries training the mlp_identity example. We check the accuracy of the model as an all inclusive way of ensuring that we are properly sharding @@ -312,6 +338,7 @@ def test_dataset(ray_start_4_cpus): trainer = TorchTrainer( training_operator_cls=DatasetOperator, + use_local=use_local, num_workers=2, ) @@ -319,13 +346,14 @@ def test_dataset(ray_start_4_cpus): for i in range(5): trainer.train(dataset=dataset, num_steps=100) - input = mlp_identity.to_mat(0.5) - prediction = float(trainer.get_model()(input)[0][0]) + x = mlp_identity.to_mat(0.5) + prediction = float(trainer.get_model()(x)[0][0]) assert 0.4 <= prediction <= 0.6 trainer.shutdown() -def test_split_batch(ray_start_2_cpus): +@pytest.mark.parametrize("use_local", [True, False]) +def test_split_batch(ray_start_2_cpus, use_local): if not dist.is_available(): return @@ -347,6 +375,7 @@ def test_split_batch(ray_start_2_cpus): trainer = TorchTrainer( training_operator_cls=TestOperator, num_workers=2, + use_local=use_local, config={ BATCH_SIZE: batch_size, "data_size": data_size, @@ -358,7 +387,8 @@ def test_split_batch(ray_start_2_cpus): trainer.shutdown() -def test_reduce_result(ray_start_2_cpus): +@pytest.mark.parametrize("use_local", [True, False]) +def test_reduce_result(ray_start_2_cpus, use_local): if not dist.is_available(): return @@ -380,6 +410,7 @@ def test_reduce_result(ray_start_2_cpus): trainer = TorchTrainer( training_operator_cls=TestOperator, num_workers=2, + use_local=use_local, config={"data_size": data_size}) list_stats = trainer.train(reduce_results=False, profile=True) assert len(list_stats) == 2 @@ -393,7 +424,8 @@ def test_reduce_result(ray_start_2_cpus): @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_metrics(ray_start_2_cpus, num_workers): +@pytest.mark.parametrize("use_local", [True, False]) +def test_metrics(ray_start_2_cpus, num_workers, use_local): data_size, val_size = 600, 500 batch_size = 4 @@ -407,6 +439,7 @@ def test_metrics(ray_start_2_cpus, num_workers): trainer = TorchTrainer( training_operator_cls=TestOperator, num_workers=num_workers, + use_local=use_local, config={ "scores": train_scores, "val_scores": val_scores, @@ -440,7 +473,8 @@ def test_metrics(ray_start_2_cpus, num_workers): @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_metrics_nan(ray_start_2_cpus, num_workers): +@pytest.mark.parametrize("use_local", [True, False]) +def test_metrics_nan(ray_start_2_cpus, num_workers, use_local): data_size, val_size = 100, 100 batch_size = 10 @@ -453,6 +487,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): trainer = TorchTrainer( training_operator_cls=TestOperator, num_workers=num_workers, + use_local=use_local, config={ "scores": train_scores, "val_scores": val_scores, @@ -474,7 +509,8 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): trainer.shutdown() -def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize("use_local", [True, False]) +def test_scheduler_validate(ray_start_2_cpus, use_local): # noqa: F811 from torch.optim.lr_scheduler import ReduceLROnPlateau TestOperator = TrainingOperator.from_creators( @@ -485,7 +521,9 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 loss_creator=lambda config: nn.MSELoss()) TestOperator = get_test_operator(TestOperator) trainer = TorchTrainer( - scheduler_step_freq="manual", training_operator_cls=TestOperator) + scheduler_step_freq="manual", + training_operator_cls=TestOperator, + use_local=use_local) trainer.update_scheduler(0.5) trainer.update_scheduler(0.5) assert all( @@ -494,14 +532,16 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 trainer.shutdown() -@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811 +@pytest.mark.parametrize("num_workers", [2] if dist.is_available() else [1]) +@pytest.mark.parametrize("use_local", [True, False]) +def test_tune_train(ray_start_4_cpus, num_workers, use_local): # noqa: F811 TorchTrainable = TorchTrainer.as_trainable( **{ "training_operator_cls": Operator, "num_workers": num_workers, "use_gpu": False, "backend": "gloo", + "use_local": use_local, "config": { "batch_size": 512, "lr": 0.001 @@ -526,10 +566,13 @@ def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_save_and_restore(ray_start_2_cpus, num_workers, +@pytest.mark.parametrize("use_local", [True, False]) +def test_save_and_restore(ray_start_2_cpus, num_workers, use_local, tmp_path): # noqa: F811 trainer1 = TorchTrainer( - training_operator_cls=Operator, num_workers=num_workers) + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local) trainer1.train() checkpoint_path = os.path.join(tmp_path, "checkpoint") trainer1.save(checkpoint_path) @@ -539,7 +582,9 @@ def test_save_and_restore(ray_start_2_cpus, num_workers, trainer1.shutdown() trainer2 = TorchTrainer( - training_operator_cls=Operator, num_workers=num_workers) + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local) trainer2.load(checkpoint_path) model2 = trainer2.get_model() @@ -558,14 +603,15 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811 if not dist.is_available(): return trainer1 = TorchTrainer( - training_operator_cls=Operator, wrap_ddp=False, num_workers=2) + training_operator_cls=Operator, + wrap_ddp=False, + num_workers=2, + use_local=True) trainer1.train() checkpoint_path = os.path.join(tmp_path, "checkpoint") trainer1.save(checkpoint_path) model1 = trainer1.get_model() - assert not hasattr(trainer1.local_worker.training_operator.model, "module") - assert hasattr(trainer1.local_worker.training_operator, "device_ids") trainer1.shutdown() trainer2 = TorchTrainer( @@ -584,123 +630,8 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811 trainer2.shutdown() -def gen_step_with_fail(num_fails): - def step_with_fail(self, - num_steps=None, - profile=False, - info=None, - dataset=None): - params = dict(num_steps=num_steps, profile=profile, info=info) - remote_worker_stats = [ - w.train_epoch.remote(**params) for w in self.remote_workers - ] - - if self._num_failures < num_fails: - time.sleep(1) # Make the batch will fail correctly. - ray.kill(self.remote_workers[0]) - - try: - local_worker_stats = self.local_worker.train_epoch(**params) - except RuntimeError: - return False, None - - success = check_for_failure(remote_worker_stats) - if success: - return success, [local_worker_stats] + ray.get(remote_worker_stats) - - return success, None - - return step_with_fail - - -def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 - if not dist.is_available(): - return - - def single_loader(config): - dataset = LinearDataset(2, 5, size=1000000) - return DataLoader(dataset, batch_size=config.get("batch_size", 32)) - - step_with_fail = gen_step_with_fail(3) - - TestOperator = TrainingOperator.from_creators( - model_creator, - optimizer_creator, - single_loader, - loss_creator=lambda config: nn.MSELoss()) - with patch.object(TorchTrainer, "_train_epoch", step_with_fail): - trainer1 = TorchTrainer( - training_operator_cls=TestOperator, - config={"batch_size": 100000}, - num_workers=2) - - with pytest.raises(RuntimeError): - trainer1.train(max_retries=1) - - trainer1.shutdown(force=True) - - -def test_resize(ray_start_2_cpus): # noqa: F811 - if not dist.is_available(): - return - - def single_loader(config): - dataset = LinearDataset(2, 5, size=1000000) - return DataLoader(dataset, batch_size=config.get("batch_size", 32)) - - step_with_fail = gen_step_with_fail(1) - - TestOperator = TrainingOperator.from_creators( - model_creator, - optimizer_creator, - single_loader, - loss_creator=lambda config: nn.MSELoss()) - with patch.object(TorchTrainer, "_train_epoch", step_with_fail): - trainer1 = TorchTrainer( - training_operator_cls=TestOperator, - config={"batch_size": 100000}, - num_workers=2) - - @ray.remote - def try_test(): - import time - time.sleep(100) - - try_test.remote() - trainer1.train(max_retries=1) - assert len(trainer1.remote_workers) == 1 - - trainer1.shutdown() - - -def test_fail_twice(ray_start_2_cpus): # noqa: F811 - if not dist.is_available(): - return - - def single_loader(config): - dataset = LinearDataset(2, 5, size=1000000) - return DataLoader(dataset, batch_size=config.get("batch_size", 32)) - - step_with_fail = gen_step_with_fail(2) - - TestOperator = TrainingOperator.from_creators( - model_creator, - optimizer_creator, - single_loader, - loss_creator=lambda config: nn.MSELoss()) - - with patch.object(TorchTrainer, "_train_epoch", step_with_fail): - trainer1 = TorchTrainer( - training_operator_cls=TestOperator, - config={"batch_size": 100000}, - num_workers=2) - - # MAX RETRIES SHOULD BE ON BY DEFAULT - trainer1.train() - trainer1.shutdown() - - -def test_multi_input_model(ray_start_2_cpus): +@pytest.mark.parametrize("use_local", [True, False]) +def test_multi_input_model(ray_start_2_cpus, use_local): def model_creator(config): class MultiInputModel(nn.Module): def __init__(self): @@ -742,7 +673,8 @@ def test_multi_input_model(ray_start_2_cpus): data_creator, loss_creator=lambda config: nn.MSELoss()) - trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1) + trainer = TorchTrainer( + training_operator_cls=Operator, num_workers=1, use_local=use_local) metrics = trainer.train(num_steps=1) assert metrics[BATCH_COUNT] == 1 diff --git a/python/ray/util/sgd/tests/test_torch_failure.py b/python/ray/util/sgd/tests/test_torch_failure.py new file mode 100644 index 000000000..97324e0a5 --- /dev/null +++ b/python/ray/util/sgd/tests/test_torch_failure.py @@ -0,0 +1,175 @@ +from unittest.mock import patch +import pytest +import time +import torch.nn as nn +import torch.distributed as dist +from torch.utils.data import DataLoader + +import ray +from ray.util.sgd.torch import TorchTrainer +from ray.util.sgd.torch.worker_group import RemoteWorkerGroup +from ray.util.sgd.torch.training_operator import TrainingOperator + +from ray.util.sgd.torch.examples.train_example import ( + model_creator, optimizer_creator, data_creator, LinearDataset) + +Operator = TrainingOperator.from_creators( + model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss) + + +@pytest.fixture +def ray_start_2_cpus(): + address_info = ray.init(num_cpus=2) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + # Ensure that tests don't ALL fail + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.fixture +def ray_start_4_cpus(): + address_info = ray.init(num_cpus=4) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + # Ensure that tests don't ALL fail + if dist.is_initialized(): + dist.destroy_process_group() + + +def remote_worker_train_with_fail(self, num_steps, profile, info, + dataset=None): + remote_worker_stats = [] + for i, w in enumerate(self.remote_workers): + params = dict(num_steps=num_steps, profile=profile, info=info) + if dataset: + params["iterator"] = dataset.get_shard(i) + stats = w.train_epoch.remote(**params) + remote_worker_stats.append(stats) + if i == 0 and hasattr(self, "should_fail") and self.should_fail: + time.sleep(1) + ray.kill(self.remote_workers[i]) + return remote_worker_stats + + +start_workers = TorchTrainer._start_workers + + +def gen_start_with_fail(num_fails): + def start_with_fail(self, *args, **kwargs): + start_workers(self, *args, **kwargs) + fail = self._num_failures < num_fails + if self.use_local: + self.worker_group.remote_worker_group.should_fail = fail + else: + self.worker_group.should_fail = fail + + return start_with_fail + + +@pytest.mark.parametrize("use_local", [False, True]) +@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail) +def test_resize(ray_start_2_cpus, use_local): # noqa: F811 + if not dist.is_available(): + return + + def single_loader(config): + dataset = LinearDataset(2, 5, size=1000000) + return DataLoader(dataset, batch_size=config.get("batch_size", 32)) + + start_with_fail = gen_start_with_fail(1) + + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) + with patch.object(TorchTrainer, "_start_workers", start_with_fail): + trainer1 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + use_local=use_local, + num_workers=2) + + @ray.remote + def try_test(): + import time + time.sleep(100) + + try_test.remote() + trainer1.train(max_retries=1) + assert trainer1.worker_group.num_workers == 1 + + trainer1.shutdown(force=True) + + +@pytest.mark.parametrize("use_local", [False, True]) +@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail) +def test_fail_twice(ray_start_2_cpus, use_local): # noqa: F811 + if not dist.is_available(): + return + + def single_loader(config): + dataset = LinearDataset(2, 5, size=1000000) + return DataLoader(dataset, batch_size=config.get("batch_size", 32)) + + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) + + start_with_fail = gen_start_with_fail(2) + + with patch.object(TorchTrainer, "_start_workers", start_with_fail): + trainer1 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + use_local=use_local, + num_workers=2) + + # MAX RETRIES SHOULD BE ON BY DEFAULT + trainer1.train() + trainer1.shutdown(force=True) + + +@pytest.mark.parametrize("use_local", [False, True]) +@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail) +def test_fail_with_recover(ray_start_2_cpus, use_local): # noqa: F811 + print(locals()) + if not dist.is_available(): + return + + def single_loader(config): + dataset = LinearDataset(2, 5, size=1000000) + return DataLoader(dataset, batch_size=config.get("batch_size", 32)) + + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) + + start_with_fail = gen_start_with_fail(3) + + with patch.object(TorchTrainer, "_start_workers", start_with_fail): + trainer1 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + timeout_s=5, + use_local=use_local, + num_workers=2) + + with pytest.raises(RuntimeError): + trainer1.train(max_retries=1) + + trainer1.shutdown(force=True) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index 375116427..013ce0916 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -1,5 +1,4 @@ import logging -import io import os import torch @@ -11,6 +10,8 @@ from ray.util.sgd.torch.utils import setup_process_group import ray from ray.util.sgd.torch.torch_runner import TorchRunner +from ray.util.sgd.torch.utils import setup_address + logger = logging.getLogger(__name__) @@ -42,6 +43,9 @@ class DistributedTorchRunner(TorchRunner): self.add_dist_sampler = add_dist_sampler self.world_rank = None + def setup_address(self): + return setup_address() + def setup_process_group(self, url, world_rank, world_size, timeout): """Connects the distributed PyTorch backend. @@ -52,6 +56,7 @@ class DistributedTorchRunner(TorchRunner): timeout (timedelta): Seconds for process group operations to timeout. """ + logger.info(f"Setting up process group for: {url} [rank={world_rank}]") self.world_rank = world_rank setup_process_group( url, world_rank, world_size, timeout, backend=self.backend) @@ -83,23 +88,6 @@ class DistributedTorchRunner(TorchRunner): """Needed for SyncBatchNorm, which needs 1 GPU per process.""" return [0] - def load_state_stream(self, byte_obj): - """Loads a bytes object the training state dict. - - This is needed because we don't want to deserialize the tensor - onto the same device (which is from the driver process). We want to - map it onto the actor's specific device. - - From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769 - """ - _buffer = io.BytesIO(byte_obj) - to_gpu = self.use_gpu and torch.cuda.is_available() - state_dict = torch.load( - _buffer, - map_location=("cpu" if not to_gpu else - lambda storage, loc: storage.cuda())) - return self.load_state_dict(state_dict) - def _wrap_dataloaders(self): def with_sampler(loader): # Automatically set the DistributedSampler @@ -346,10 +334,3 @@ class LocalDistributedRunner(DistributedTorchRunner): def is_actor(self): actor_id = ray.worker.global_worker.actor_id return actor_id != actor_id.nil() - - -class DeactivatedRunner: - def __getattr__(self, *args, **kwargs): - raise RuntimeError( - "This TorchTrainer is not active (it is likely shutdown already). " - "Create a new TorchTrainer.") diff --git a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py index f51549f23..9a3564613 100644 --- a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py @@ -138,7 +138,7 @@ if __name__ == "__main__": use_gpu=args.use_gpu, scheduler_step_freq="epoch", use_fp16=args.fp16, - use_tqdm=True) + use_tqdm=False) pbar = trange(args.num_epochs, unit="epoch") for i in pbar: info = {"num_steps": 1} if args.smoke_test else {} diff --git a/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py b/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py index b9f1bb721..042f4cd10 100644 --- a/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py +++ b/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py @@ -144,7 +144,14 @@ def scheduler_creator(optimizer, config): # __torch_scheduler_end__ -# __backwards_compat__start +# __torch_ray_start__ +import ray + +ray.init() +# or ray.init(address="auto") to connect to a running cluster. +# __torch_ray_end__ + +# __backwards_compat_start__ from ray.util.sgd import TorchTrainer MyTrainingOperator = TrainingOperator.from_creators( @@ -157,14 +164,9 @@ trainer = TorchTrainer( scheduler_step_freq="epoch", # if scheduler_creator is passed in config={"lr": 0.001, "batch_size": 64}) -# __backwards_compat_end +# __backwards_compat_end__ -# __torch_ray_start__ -import ray - -ray.init() -# or ray.init(address="auto") to connect to a running cluster. -# __torch_ray_end__ +trainer.shutdown() # __torch_trainer_start__ from ray.util.sgd import TorchTrainer @@ -175,3 +177,5 @@ trainer = TorchTrainer( config={"lr": 0.001, "batch_size": 64}) # __torch_trainer_end__ + +trainer.shutdown() diff --git a/python/ray/util/sgd/torch/examples/tune_example.py b/python/ray/util/sgd/torch/examples/tune_example.py index da04bc455..f5c226851 100644 --- a/python/ray/util/sgd/torch/examples/tune_example.py +++ b/python/ray/util/sgd/torch/examples/tune_example.py @@ -52,7 +52,7 @@ def tune_example(operator_cls, num_workers=1, use_gpu=False): stop={"training_iteration": 2}, verbose=1) - return analysis.get_best_config(metric="validation_loss", mode="min") + return analysis.get_best_config(metric="val_loss", mode="min") # __end_torch_tune_example__ diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 82e47ceb9..3486458cb 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -3,7 +3,6 @@ import io import itertools import torch -import ray from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS from ray.util.sgd import utils @@ -57,14 +56,6 @@ class TorchRunner: apex_args=self.apex_args, scheduler_step_freq=self.scheduler_step_freq) - def get_node_ip(self): - """Returns the IP address of the current node.""" - return ray.services.get_node_ip_address() - - def find_free_port(self): - """Finds a free port on the current node.""" - return utils.find_free_port() - def train_epoch(self, num_steps=None, profile=False, @@ -132,11 +123,15 @@ class TorchRunner: def state_dict(self): """Returns the state of the runner.""" + model_states = [model.state_dict() for model in self.models] + optimizer_states = [ + optimizer.state_dict() for optimizer in self.optimizers + ] state = { "epoch": self.epochs, "operator": self.training_operator.state_dict(), - "models": [model.state_dict() for model in self.models], - "optimizers": [opt.state_dict() for opt in self.optimizers] + "models": model_states, + "optimizers": optimizer_states } schedulers = self.schedulers if schedulers: @@ -148,6 +143,7 @@ class TorchRunner: # Check if fp16 is True and if NVIDIA Apex is imported. if self.use_fp16 and self.training_operator._amp: state.update({"amp": self.training_operator._amp.state_dict()}) + return state def load_state_dict(self, state): @@ -176,9 +172,20 @@ class TorchRunner: return _buffer.getvalue() def load_state_stream(self, byte_obj): - """Loads a bytes object the training state dict.""" + """Loads a bytes object the training state dict. + + This is needed because we don't want to deserialize the tensor + onto the same device (which is from the driver process). We want to + map it onto the actor's specific device. + + From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769 + """ _buffer = io.BytesIO(byte_obj) - state_dict = torch.load(_buffer) + to_gpu = self.use_gpu and torch.cuda.is_available() + state_dict = torch.load( + _buffer, + map_location=("cpu" if not to_gpu else + lambda storage, loc: storage.cuda())) return self.load_state_dict(state_dict) def apply(self, fn): @@ -193,6 +200,10 @@ class TorchRunner: if torch.cuda.is_available(): torch.cuda.empty_cache() + def get_models(self): + """Getter method. Needed for remote actor calls.""" + return self.models + @property def models(self): if not hasattr(self.training_operator, "_original_models"): diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index ed811d782..805ae37bd 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -1,29 +1,25 @@ -from datetime import timedelta +import time + import numpy as np import logging import os import numbers import tempfile -import time import torch import torch.distributed as dist import ray -from ray.exceptions import RayActorError from ray.tune import Trainable from ray.tune.resources import Resources from ray.tune.utils.util import merge_dicts from ray.util import log_once -from ray.util.sgd.torch.distributed_torch_runner import ( - DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner) -from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE -from ray.util.sgd.torch.torch_runner import TorchRunner +from ray.util.sgd.torch.worker_group import LocalWorkerGroup, \ + RemoteWorkerGroup, DeactivatedWorkerGroup +from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP, NCCL_TIMEOUT_S -from ray.util.sgd.torch.utils import setup_address from ray.util.sgd.data import Dataset logger = logging.getLogger(__name__) -RESIZE_COOLDOWN_S = 10 def _validate_scheduler_step_freq(scheduler_step_freq): @@ -86,11 +82,17 @@ class TorchTrainer: that subclasses the TrainingOperator class. This class will be copied onto all remote workers and used to specify training components and custom training and validation operations. + initialization_hook (function): A function to call on all training + workers when they are first initialized. This could be useful to + set environment variables for all the worker processes. config (dict): Custom configuration value to be passed to all operator constructors. num_workers (int): the number of workers used in distributed training. If 1, the worker will not be wrapped with - DistributedDataParallel. + DistributedDataParallel. TorchTrainer will scale down the number + of workers if enough resources are not available, and will scale + back up once they are. The total number of + workers will never exceed `num_workers` amount. num_cpus_per_worker (int): Sets the cpu requirement for each worker. use_gpu (bool): Sets resource allocation for workers to 1 GPU if true, and automatically moves both the model and optimizer @@ -119,9 +121,13 @@ class TorchTrainer: ``step`` will be called after every optimizer step. If "epoch", ``step`` will be called after one pass of the DataLoader. If "manual", the scheduler will not be incremented automatically - - you are expected to call ``trainer.update_schedulers`` manually. + you are expected to call ``trainer.update_scheduler`` manually. If a scheduler is passed in, this value is expected to not be None. - + use_local (bool): If True, 1 worker will be a local worker running + on the driver process, and all other workers will be remote. If + False, all workers will be remote. Set this to True for easy + debugging of worker on driver process, but could also + lead to issues with Cuda devices. Defaults to False. """ # TODO: Implement autoscaling. If num_workers=-1, the trainer will use as @@ -146,6 +152,7 @@ class TorchTrainer: apex_args=None, add_dist_sampler=True, scheduler_step_freq=None, + use_local=False, # Deprecated Args. num_replicas=None, batch_size=None, @@ -168,6 +175,13 @@ class TorchTrainer: "model_creator, ...) and pass in CustomOperator into " "TorchTrainer.") + if use_local and log_once("use_local"): + logger.warning("use_local is set to True. This could lead to " + "issues with Cuda devices. If you are seeing this " + "issue, try setting use_local to False. For more " + "information, see " + "https://github.com/ray-project/ray/issues/9202.") + if num_workers > 1 and not dist.is_available(): raise ValueError( ("Distributed PyTorch is not supported on macOS. " @@ -225,6 +239,7 @@ class TorchTrainer: self.use_fp16 = use_fp16 self.use_tqdm = use_tqdm self.add_dist_sampler = add_dist_sampler + self.use_local = use_local if apex_args and not isinstance(apex_args, dict): raise ValueError("apex_args needs to be a dict object.") @@ -234,9 +249,6 @@ class TorchTrainer: self._num_failures = 0 self._last_resize = float("-inf") - self.local_worker = DeactivatedRunner() - self.remote_workers = [] - if scheduler_step_freq: _validate_scheduler_step_freq(scheduler_step_freq) @@ -270,12 +282,10 @@ class TorchTrainer: return batch_size_per_worker def _start_workers(self, num_workers): - logger.debug(f"start_workers: Setting %d workers." % num_workers) worker_config = self.config.copy() batch_size_per_worker = self._configure_and_split_batch(num_workers) if batch_size_per_worker: worker_config[BATCH_SIZE] = batch_size_per_worker - params = dict( training_operator_cls=self.training_operator_cls, config=worker_config, @@ -286,57 +296,68 @@ class TorchTrainer: apex_args=self.apex_args, scheduler_step_freq=self.scheduler_step_freq) - if num_workers == 1: - # Start local worker - self.local_worker = TorchRunner(**params) - if self.initialization_hook: - self.apply_all_workers(self.initialization_hook) - self.local_worker.setup_operator() + dist_params = dict( + backend=self.backend, + add_dist_sampler=self.add_dist_sampler, + wrap_ddp=self.wrap_ddp) + + worker_args = { + "max_workers": num_workers, + "params": params, + "dist_params": dist_params, + "initialization_hook": self.initialization_hook, + "num_cpus_per_worker": self.num_cpus_per_worker, + "use_gpu": self.use_gpu, + "timeout_s": self.timeout_s + } + + if self.use_local: + self.worker_group = LocalWorkerGroup(**worker_args) else: - params.update( - backend=self.backend, - add_dist_sampler=self.add_dist_sampler, - wrap_ddp=self.wrap_ddp) + self.worker_group = RemoteWorkerGroup(**worker_args) - # Start local worker - self.local_worker = LocalDistributedRunner( - num_cpus=self.num_cpus_per_worker, - num_gpus=int(self.use_gpu), - **params) + # TODO(amogkam): If not enough resources are available to create + # num_workers workers, this command will hang. Instead, + # start_workers should take into account available resources when + # determining how many workers to create. + self.worker_group.start_workers(num_workers) - # Generate actor class - RemoteRunner = ray.remote( - num_cpus=self.num_cpus_per_worker, - num_gpus=int(self.use_gpu))(DistributedTorchRunner) - # Start workers - self.remote_workers = [ - RemoteRunner.remote(**params) for i in range(num_workers - 1) - ] - if self.initialization_hook: - self.apply_all_workers(self.initialization_hook) + def _resize_worker_group(self, max_retries=10): + """Resizes the number of remote workers based on available resources. + Total number of workers will never exceed `num_workers` amount. - # Compute URL for initializing distributed PyTorch - address = setup_address() + Args: + max_retries (int): How many times to attempt to resize workers + before failing. + """ + state_dict = self.state_dict() + old_workers = self.worker_group.num_workers + self.worker_group.reset() - # Setup the process group among all workers. - remote_pgroup_setups = [ - worker.setup_process_group.remote(address, i + 1, num_workers, - timedelta(self.timeout_s)) - for i, worker in enumerate(self.remote_workers) - ] - self.local_worker.setup_process_group(address, 0, num_workers, - timedelta(self.timeout_s)) - # Get setup tasks in order to throw errors on failure - ray.get(remote_pgroup_setups) - - # Runs code that requires all creator functions to have run. - remote_operator_setups = [ - worker.setup_operator.remote() - for worker in self.remote_workers - ] - self.local_worker.setup_operator() - # Get setup tasks in order to throw errors on failure - ray.get(remote_operator_setups) + time.sleep(1) + for i in range(max_retries): + new_workers = self.worker_group.new_workers_size() + if new_workers: + self._last_resize = time.time() + self._start_workers(int(new_workers)) + self.load_state_dict(state_dict, blocking=True) + if self.use_local and new_workers == 1 and old_workers > 1: + # Major hack. If we go from LocalDistributedRunner to a + # standard TorchRunner we have to manually reset the + # dummy actor handle global vars. + # TODO(amog): Refactor LocalDistributedTorchRunner to + # not use global variables for resource reservation. + ray.util.sgd.torch.distributed_torch_runner\ + ._dummy_cuda_actor = None + ray.util.sgd.torch.distributed_torch_runner\ + ._dummy_cpu_actor = None + return + else: + delay = 2**i + logger.warning( + "No new workers found. Retrying in %d sec." % delay) + time.sleep(delay) + raise RuntimeError("Exceeded max_retries for relaunching workers.") def train(self, num_steps=None, @@ -384,10 +405,10 @@ class TorchTrainer: assert isinstance(dataset, Dataset) is not None \ or self.data_creator, \ "Must specify either a data creator or a dataset" - if self._should_resize(): + if self.worker_group.should_scale_up(): logger.info("Resize opportunity detected. Attempting to scale up.") - self._resize_workers() - success, worker_stats = self._train_epoch( + self._resize_worker_group() + success, worker_stats = self.worker_group.train( num_steps=num_steps, profile=profile, info=info, dataset=dataset) # Fault handling for i in range(max_retries): @@ -395,10 +416,10 @@ class TorchTrainer: break else: self._num_failures += 1 - self._resize_workers() + self._resize_worker_group() logger.info("Retrying training step with %d workers." % - (len(self.remote_workers) + 1)) - success, worker_stats = self._train_epoch( + self.worker_group.num_workers) + success, worker_stats = self.worker_group.train( num_steps=num_steps, profile=profile, info=info, @@ -425,43 +446,6 @@ class TorchTrainer: stats[stat_key] = worker_stats[0][stat_key] return stats - def _train_epoch(self, - num_steps=None, - profile=False, - info=None, - dataset=None): - params = dict(num_steps=num_steps, profile=profile, info=info) - remote_worker_stats = [] - if dataset: - dataset.set_num_shards(self.max_replicas) - for i, w in enumerate(self.remote_workers): - params = dict(num_steps=num_steps, profile=profile, info=info) - if dataset: - params["iterator"] = dataset.get_shard(i) - stats = w.train_epoch.remote(**params) - remote_worker_stats.append(stats) - - try: - if dataset: - params["iterator"] = dataset.get_shard( - len(self.remote_workers)) - local_worker_stats = self.local_worker.train_epoch(**params) - except RuntimeError as err: - if "gloo" in err.args[0] and "Timed out" in err.args[0]: - logger.warning(err) - return False, None - if "NCCL" in err.args[0]: # there is no specific error message - logger.warning(err) - return False, None - - raise err - - success = check_for_failure(remote_worker_stats) - if success: - return success, [local_worker_stats] + ray.get(remote_worker_stats) - - return success, None - def apply_all_workers(self, fn): """Run a function on all operators on the workers. @@ -472,9 +456,7 @@ class TorchTrainer: A list of objects returned by ``fn`` on each worker. """ - remote_calls = [w.apply.remote(fn) for w in self.remote_workers] - local_call = self.local_worker.apply(fn) - return [local_call] + ray.get(remote_calls) + return self.worker_group.apply_all_workers(fn) def apply_all_operators(self, fn): """Run a function on all operators on the workers. @@ -487,11 +469,7 @@ class TorchTrainer: A list of objects returned by ``fn`` on each operator. """ - remote_calls = [ - w.apply_operator.remote(fn) for w in self.remote_workers - ] - local_call = self.local_worker.apply_operator(fn) - return [local_call] + ray.get(remote_calls) + return self.worker_group.apply_all_operators(fn) def validate(self, num_steps=None, @@ -517,13 +495,8 @@ class TorchTrainer: You can provide custom metrics by passing in a custom ``training_operator_cls``. """ - params = dict(num_steps=num_steps, profile=profile, info=info) - - remote_worker_stats = [ - w.validate.remote(**params) for w in self.remote_workers - ] - local_worker_stats = self.local_worker.validate(**params) - worker_stats = [local_worker_stats] + ray.get(remote_worker_stats) + worker_stats = self.worker_group.validate( + num_steps=num_steps, profile=profile, info=info) if reduce_results: return self._process_stats(worker_stats) @@ -535,13 +508,14 @@ class TorchTrainer: This is useful for lr_schedulers such as ``ReduceLROnPlateau``. """ - self.apply_all_operators( + self.worker_group.apply_all_operators( lambda op: [sched.step(metric) for sched in op._schedulers]) def get_model(self): """Returns the learned model(s).""" unwrapped = [] - for model in self.local_worker.models: + models = self.worker_group.get_model() + for model in models: unwrapped += [model.module if hasattr(model, "module") else model] if len(unwrapped) == 1: return unwrapped[0] @@ -556,23 +530,13 @@ class TorchTrainer: Returns: TrainingOperator: The local TrainingOperator object. """ - return self.local_worker.training_operator + return self.worker_group.get_local_operator() def state_dict(self): - return self.local_worker.state_dict() + return self.worker_group.state_dict() def load_state_dict(self, state_dict, blocking=False): - # This is not the most efficient because you have to wait for - # the local worker to save then dump to buffer. - self.local_worker.load_state_dict(state_dict) - state_id = ray.put(self.local_worker.state_stream()) - - remote_calls = [ - worker.load_state_stream.remote(state_id) - for worker in self.remote_workers - ] - if blocking: - ray.get(remote_calls) + self.worker_group.load_state_dict(state_dict, blocking=blocking) def save(self, checkpoint): """Saves the Trainer state to the provided checkpoint path. @@ -596,81 +560,16 @@ class TorchTrainer: raise DeprecationWarning("Use `TorchTrainer.load()` instead.") def shutdown(self, force=False): - """Shuts down workers and releases resources.""" - if not force: - cleanup = [ - worker.shutdown.remote() for worker in self.remote_workers - ] - self.local_worker.shutdown() - try: - ray.get(cleanup) - [ - worker.__ray_terminate__.remote() - for worker in self.remote_workers - ] - except RayActorError: - logger.warning( - "Failed to shutdown gracefully, forcing a shutdown.") + """Shuts down workers and releases resources. - for worker in self.remote_workers: - logger.warning(f"Killing worker {worker}.") - ray.kill(worker) - else: - self.local_worker.shutdown() - for worker in self.remote_workers: - logger.debug(f"Killing worker {worker}.") - ray.kill(worker) + Args: + force (bool): If True, forcefully kill all workers. If False, + attempt a graceful shutdown first, and then forcefully kill if + unsuccessful. - self.local_worker = DeactivatedRunner() - self.remote_workers = [] - - def _reset(self): - """Terminates models without giving up local resource reservation.""" - self.local_worker.shutdown(cleanup=False) - for worker in self.remote_workers: - logger.debug(f"Killing worker {worker}.") - ray.kill(worker) - self.local_worker = DeactivatedRunner() - self.remote_workers = [] - - def _check_potential_remote_workers_size(self): - # ASSUME 1 GPU + 1 CPU is already reserved for the local worker - remote_resources = ray.available_resources() - max_remote_workers = self.max_replicas - 1 - new_remote_workers = min( - remote_resources.get("CPU", 0), max_remote_workers) - if self.use_gpu: - new_remote_workers = min( - remote_resources.get("GPU", 0), new_remote_workers) - return new_remote_workers - - def _resize_workers(self, max_retries=10): - self._reset() - - time.sleep(1) - for i in range(max_retries): - new_remote_workers = self._check_potential_remote_workers_size() - if new_remote_workers: - self._last_resize = time.time() - self._start_workers(int(new_remote_workers) + 1) - self.load_state_dict(self.state_dict()) - return - else: - delay = 2**i - logger.warning( - "No new workers found. Retrying in %d sec." % delay) - time.sleep(delay) - raise RuntimeError("Exceeded max_retries for relaunching workers.") - - def _should_resize(self): - """Returns True if past cooldown and exists resources to scale up.""" - worker_gap = self.max_replicas - 1 - len(self.remote_workers) - past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S - if past_cooldown and worker_gap: - # Assume 1 resource is already reserved for local worker. - potential_remote_size = self._check_potential_remote_workers_size() - return potential_remote_size > 0 - return False + """ + self.worker_group.shutdown(force=force) + self.worker_group = DeactivatedWorkerGroup() @classmethod def as_trainable(cls, *args, **kwargs): @@ -698,16 +597,26 @@ class TorchTrainer: def default_resource_request(cls, config): num_workers = config.get("num_workers", kwargs.get("num_workers", 1)) - num_cpus = config.get("num_cpus_per_worker", - kwargs.get("num_cpus_per_worker", 1)) + num_cpus_per_worker = config.get( + "num_cpus_per_worker", kwargs.get("num_cpus_per_worker", + 1)) use_gpu = config.get("use_gpu", kwargs.get("use_gpu")) + use_local = config.get("use_local", + kwargs.get("use_local", False)) - remote_worker_count = num_workers - 1 + if use_local: + remote_worker_count = num_workers - 1 + local_cpus = 1 + local_gpus = int(use_gpu) + else: + remote_worker_count = num_workers + local_cpus = 0 + local_gpus = 0 return Resources( - cpu=num_cpus, - gpu=int(use_gpu), - extra_cpu=int(remote_worker_count), + cpu=int(local_cpus * num_cpus_per_worker), + gpu=int(local_gpus), + extra_cpu=int(remote_worker_count * num_cpus_per_worker), extra_gpu=int(int(use_gpu) * remote_worker_count)) def _create_trainer(self, tune_config): diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index bf9529f07..85ba958c0 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -255,8 +255,8 @@ class TrainingOperator: else: self._criterion = None - logger.debug("Setting up Apex.") if self.use_fp16 and amp: + logger.debug("Setting up Apex.") self._models, self._optimizers = amp.initialize( self._models, self._optimizers, **self._apex_args) self._amp = amp @@ -649,7 +649,9 @@ class TrainingOperator: """Override this to return a representation of the operator state. Any argument passed into self.register and self.register_data will automatically be saved. - Use this method to save any additional state. + Use this method to save any additional state. If your TorchTrainer + is on a CPU-only machine, make sure this method converts all state + to be CPU-compatible. Returns: dict: The state dict of the operator.""" diff --git a/python/ray/util/sgd/torch/utils.py b/python/ray/util/sgd/torch/utils.py index d7fef6ccf..c255a6e91 100644 --- a/python/ray/util/sgd/torch/utils.py +++ b/python/ray/util/sgd/torch/utils.py @@ -1,8 +1,8 @@ import os import logging -import torch.distributed as dist import ray +import torch.distributed as dist from ray.util.sgd.utils import find_free_port logger = logging.getLogger(__name__) diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py new file mode 100644 index 000000000..89bc41e63 --- /dev/null +++ b/python/ray/util/sgd/torch/worker_group.py @@ -0,0 +1,574 @@ +import io +import logging +import time +from datetime import timedelta + +import ray +import torch +from ray.exceptions import RayActorError +from ray.util.sgd.torch.distributed_torch_runner import \ + LocalDistributedRunner, DistributedTorchRunner +from ray.util.sgd.torch.torch_runner import TorchRunner +from ray.util.sgd.torch.utils import setup_address +from ray.util.sgd.utils import check_for_failure + +RESIZE_COOLDOWN_S = 10 +logger = logging.getLogger(__name__) + + +class WorkerGroupInterface: + """Manages a group of TorchRunner workers.""" + + def start_workers(self, num_workers): + """Start workers for training. + + This method has 4 steps. + 1. Creates `num_workers` TorchRunner objects, either all as remote + actors or 1 locally and all but one remote. + 2. If necessary, applies an initialization hook to all the + workers. + 3. Sets up the process group for all workers if running in + distributed setting. + 4. Instantiates the user provided TrainingOperator on all + workers to set up training state. + """ + raise NotImplementedError + + def apply_all_operators(self, fn): + """See TorchTrainer.apply_all_operators.""" + raise NotImplementedError + + def apply_all_workers(self, fn): + """See TorchTrainer.apply_all_workers.""" + raise NotImplementedError + + def get_local_operator(self): + """See TorchTrainer.get_local_operator.""" + raise NotImplementedError + + def get_model(self): + """See TorchTrainer.get_model.""" + raise NotImplementedError + + def load_state_dict(self, state_dict, blocking=False): + """See TorchTrainer.load_state_dict.""" + raise NotImplementedError + + def new_workers_size(self): + """Returns number of workers to create based on available resources. + Total number of workers will never exceed `max_workers` amount. + """ + raise NotImplementedError + + def reset(self): + """Resets worker group.""" + raise NotImplementedError + + def should_scale_up(self): + """Returns whether to scale up the number of remote workers. + + This method returns True if current number of workers is less than + max_workers provided during startup, enough resources are + available to scale up, and a sufficient cooldown period has passed. + """ + raise NotImplementedError + + def shutdown(self, force=False): + """See TorchTrainer.shutdown.""" + raise NotImplementedError + + def state_dict(self): + """See TorchTrainer.state_dict.""" + raise NotImplementedError + + def train(self, num_steps=None, profile=False, info=None, dataset=None): + """Runs one epoch of training on all workers. + Args: + See TorchTrainer.train. + Returns: + Tuple of (bool, list). First value is True if training was + successful and False otherwise. Second value is list of + training results from all workers if training was successful, + and None otherwise. + """ + raise NotImplementedError + + def validate(self, num_steps=None, profile=False, info=None): + """Runs validation for all workers. + Args: + See TorchTrainer.validate. + Return: + List of validation results for each worker. + """ + raise NotImplementedError + + +class RemoteWorkerGroup(WorkerGroupInterface): + """A group of TorchRunner workers that are all remote Ray actors. + + Args: + max_workers (int): Maximum number of workers to use. + params (dict): Parameters to pass into a TorchRunner worker. + dist_params (dict): Additional parameters for distributed training + to pass into a DistributedTorchRunner worker. + initialization_hook (Callable): See TorchTrainer.__init__. + timeout_s (float): See TorchTrainer.__init__. + num_cpus_per_worker (int): See TorchTrainer.__init__. + use_gpu (bool): See TorchTrainer.__init__. + + """ + + def __init__(self, max_workers, params, dist_params, initialization_hook, + timeout_s, num_cpus_per_worker, use_gpu): + # Invariant: These variables should never change state! + self._max_workers = max_workers + self._params = params + self._dist_params = dist_params + self._initialization_hook = initialization_hook + self._timeout_s = timeout_s + self._num_cpus_per_worker = num_cpus_per_worker + self._use_gpu = use_gpu + + self.remote_workers = [] + + # The last time when this worker group was resized. + self._last_resize = float("-inf") + + def _init_dist_workers(self, num_workers): + """Create `num_workers` remote workers.""" + # Generate actor class + RemoteRunner = ray.remote( + num_cpus=self._num_cpus_per_worker, + num_gpus=int(self._use_gpu))(DistributedTorchRunner) + + # Start workers + self.remote_workers = [ + RemoteRunner.remote(**{ + **self._params, + **self._dist_params + }) for _ in range(num_workers) + ] + + def _setup_process_group(self, address, world_size, starting_rank=0): + """Sets up process group for all workers. + + Args: + address (str): Address to use for TCP process group setup. The + provided address must use the IP address of the node that the + rank 0 worker is located on. + world_size (int): Total number of training workers in the + process group. This may differ from self.num_workers if + there are additional workers outside this worker group class. + starting_rank (int): The rank to use for the first worker. + Worker ranks will be in [starting_rank, + len(self.remote_workers)+starting_rank). This is useful if + you want to use worker outside of this group as the rank 0 + worker. + + Returns: + List of process group set up promises. + """ + # Setup the process group among all workers. + remote_pgroup_setups = [ + worker.setup_process_group.remote( + url=address, + world_rank=i + starting_rank, + world_size=world_size, + timeout=timedelta(self._timeout_s)) + for i, worker in enumerate(self.remote_workers) + ] + return remote_pgroup_setups + + def _setup_operator(self): + """Instantiates user provided TrainingOperator on all workers. + + Returns: + List of operator setup promises. + """ + # Runs code that requires all creator functions to have run. + remote_operator_setups = [ + worker.setup_operator.remote() for worker in self.remote_workers + ] + return remote_operator_setups + + def start_workers(self, num_workers): + logger.debug(f"start_workers: Setting %d workers." % num_workers) + if num_workers == 1: + RemoteRunner = ray.remote( + num_cpus=self._num_cpus_per_worker, + num_gpus=int(self._use_gpu))(TorchRunner) + self.remote_workers = [RemoteRunner.remote(**self._params)] + ray.get(self.remote_workers[0].setup_operator.remote()) + else: + self._init_dist_workers(num_workers) + + if self._initialization_hook: + self.apply_all_workers(self._initialization_hook) + + # Make sure to get the IP address of the rank 0 worker node. + address = ray.get(self.remote_workers[0].setup_address.remote()) + + ray.get( + self._setup_process_group( + address=address, world_size=num_workers)) + + ray.get(self._setup_operator()) + + def _apply_all_operators(self, fn): + remote_calls = [ + w.apply_operator.remote(fn) for w in self.remote_workers + ] + return remote_calls + + def apply_all_operators(self, fn): + return ray.get(self._apply_all_operators(fn)) + + def _apply_all_workers(self, fn): + return [w.apply.remote(fn) for w in self.remote_workers] + + def apply_all_workers(self, fn): + return ray.get(self._apply_all_workers(fn)) + + def get_local_operator(self): + raise NotImplementedError( + "Cannot return a local operators if all " + "workers are remote. Set use_local to True in" + "TorchTrainer to access a local operator.") + + def get_model(self): + ready, _ = ray.wait( + [r.get_models.remote() for r in self.remote_workers]) + models = ray.get(ready[0]) + return models + + def _load_state_id(self, state_id): + """Loads the object with id `state_id` to all workers.""" + remote_calls = [ + worker.load_state_stream.remote(state_id) + for worker in self.remote_workers + ] + return remote_calls + + def load_state_dict(self, state_dict, blocking=False): + _buffer = io.BytesIO() + torch.save(state_dict, _buffer) + stream = _buffer.getvalue() + state_id = ray.put(stream) + + remote_calls = self._load_state_id(state_id) + + if blocking: + ray.get(remote_calls) + + def state_dict(self): + # This is needed to handle calling ray.get on a dead actor. + buffer_object = None + futures = {r.state_stream.remote() for r in self.remote_workers} + while len(futures) > 0: + ready, _ = ray.wait(list(futures), num_returns=1) + object_ref = ready[0] + try: + buffer_object = ray.get(object_ref) + except RayActorError: + futures.remove(object_ref) + else: + break + if buffer_object is None: + raise RuntimeError("Obtaining state_dict from remote workers is " + "unsuccessful since all workers have died.") + to_gpu = self._use_gpu and torch.cuda.is_available() + _buffer = io.BytesIO(buffer_object) + state_dict = torch.load( + _buffer, + map_location=("cpu" if not to_gpu else + lambda storage, loc: storage.cuda())) + return state_dict + + def _train(self, num_steps, profile, info, dataset=None): + """Runs 1 epoch of training on all workers. + Returns training result for all workers as promises. + """ + remote_worker_stats = [] + for i, w in enumerate(self.remote_workers): + params = dict(num_steps=num_steps, profile=profile, info=info) + if dataset: + params["iterator"] = dataset.get_shard(i) + stats = w.train_epoch.remote(**params) + remote_worker_stats.append(stats) + return remote_worker_stats + + def train(self, num_steps=None, profile=False, info=None, dataset=None): + """Runs 1 epoch of training on all workers. + + Has additional logic to check for worker failure. + """ + if dataset: + dataset.set_num_shards(self.num_workers) + remote_worker_stats = self._train(num_steps, profile, info, dataset) + # Check if each worker has failed before calling ray.get. + success = check_for_failure(remote_worker_stats) + if success: + return success, ray.get(remote_worker_stats) + return success, None + + def _validate(self, params): + """Runs validation for each worker. Returns results as promises.""" + remote_worker_stats = [ + w.validate.remote(**params) for w in self.remote_workers + ] + return remote_worker_stats + + def validate(self, num_steps=None, profile=False, info=None): + params = dict(num_steps=num_steps, profile=profile, info=info) + remote_worker_stats = self._validate(params) + return ray.get(remote_worker_stats) + + def _shutdown_remote_workers(self): + """Shuts down each worker and returns a list of cleanup promises.""" + cleanup = [worker.shutdown.remote() for worker in self.remote_workers] + return cleanup + + def _terminate_remote_workers(self, cleanup): + """Blocks on worker shutdown and then terminates each worker actor. + + If graceful shutdown fails, forcefully kills all actors. + """ + try: + ray.get(cleanup) + [ + worker.__ray_terminate__.remote() + for worker in self.remote_workers + ] + except RayActorError: + logger.warning("Failed to shutdown gracefully, forcing a " + "shutdown.") + self.reset() + + def shutdown(self, force=False): + if not force: + cleanup = [ + worker.shutdown.remote() for worker in self.remote_workers + ] + self._terminate_remote_workers(cleanup) + else: + self.reset() + self.remote_workers = [] + + def reset(self): + for worker in self.remote_workers: + logger.debug(f"Killing worker {worker}.") + ray.kill(worker) + self.remote_workers = [] + + def should_scale_up(self): + worker_gap = self._max_workers - self.num_workers + past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S + if past_cooldown and worker_gap: + # Assume 1 resource is already reserved for local worker. + potential_remote_size = self._check_potential_remote_workers_size() + return potential_remote_size > 0 + return False + + def new_workers_size(self): + """Returns number of workers to create based on available resources.""" + remote_resources = ray.available_resources() + max_remote_workers = self._max_workers + new_remote_workers = min( + remote_resources.get("CPU", 0), max_remote_workers) + if self._use_gpu: + new_remote_workers = min( + remote_resources.get("GPU", 0), new_remote_workers) + return new_remote_workers + + @property + def num_workers(self): + """Current number of workers being used for training. + This may differ from self.num_workers if self.resize_workers has + been called. + """ + return len(self.remote_workers) + + +class LocalWorkerGroup(WorkerGroupInterface): + """A group of TorchRunner workers. + 1 worker runs locally, and all the other workers are remote actors. + + Args: + Same as RemoteWorkerGroup. + """ + + def __init__(self, max_workers, params, dist_params, initialization_hook, + timeout_s, num_cpus_per_worker, use_gpu): + + # Invariant: These variables should never change state! + self._max_workers = max_workers + self._params = params + self._dist_params = dist_params + self._initialization_hook = initialization_hook + self._timeout_s = timeout_s + self._num_cpus_per_worker = num_cpus_per_worker + self._use_gpu = use_gpu + + self.local_worker = None + self.remote_worker_group = RemoteWorkerGroup( + max_workers=max_workers - 1, + params=params, + dist_params=dist_params, + initialization_hook=initialization_hook, + timeout_s=timeout_s, + num_cpus_per_worker=num_cpus_per_worker, + use_gpu=use_gpu) + + def start_workers(self, num_workers): + logger.debug(f"start_workers: Setting %d workers." % num_workers) + + if num_workers == 1: + self.local_worker = TorchRunner(**self._params) + if self._initialization_hook: + self.apply_all_workers(self._initialization_hook) + self.local_worker.setup_operator() + else: + + # Start local worker + self.local_worker = LocalDistributedRunner( + num_cpus=self._num_cpus_per_worker, + num_gpus=int(self._use_gpu), + **{ + **self._params, + **self._dist_params + }) + self.remote_worker_group._init_dist_workers(num_workers - 1) + if self._initialization_hook: + self.apply_all_workers(self._initialization_hook) + + # Compute URL for initializing distributed PyTorch. + address = setup_address() + + remote_pgs = self.remote_worker_group._setup_process_group( + address=address, world_size=num_workers, starting_rank=1) + # Use the local worker as rank 0. This will help with debugging. + self.local_worker.setup_process_group( + url=address, + world_rank=0, + world_size=num_workers, + timeout=timedelta(self._timeout_s)) + ray.get(remote_pgs) + + remote_operators = self.remote_worker_group._setup_operator() + self.local_worker.setup_operator() + ray.get(remote_operators) + + def apply_all_operators(self, fn): + remote_calls = self.remote_worker_group._apply_all_operators(fn) + local_call = self.local_worker.apply_operator(fn) + return [local_call] + ray.get(remote_calls) + + def apply_all_workers(self, fn): + remote_calls = self.remote_worker_group.apply_all_workers(fn) + local_call = self.local_worker.apply(fn) + return [local_call] + ray.get(remote_calls) + + def get_local_operator(self): + return self.local_worker.training_operator + + def get_model(self): + return self.local_worker.models + + def load_state_dict(self, state_dict, blocking=False): + # This is not the most efficient because you have to wait for + # the local worker to save then dump to buffer. + self.local_worker.load_state_dict(state_dict) + state_id = ray.put(self.local_worker.state_stream()) + + remote_calls = self.remote_worker_group._load_state_id(state_id) + if blocking: + ray.get(remote_calls) + + def state_dict(self): + return self.local_worker.state_dict() + + def should_scale_up(self): + return self.remote_worker_group.should_scale_up() + + def reset(self): + """Terminates models without giving up local resource reservation.""" + self.local_worker.shutdown(cleanup=False) + self.remote_worker_group.reset() + + self.local_worker = None + self.remote_worker_group = RemoteWorkerGroup( + max_workers=self._max_workers - 1, + params=self._params, + dist_params=self._dist_params, + initialization_hook=self._initialization_hook, + num_cpus_per_worker=self._num_cpus_per_worker, + use_gpu=self._use_gpu, + timeout_s=self._timeout_s) + + def new_workers_size(self): + return self.remote_worker_group.new_workers_size() + 1 + + def train(self, num_steps=None, profile=False, info=None, dataset=None): + params = dict(num_steps=num_steps, profile=profile, info=info) + if dataset: + dataset.set_num_shards(self.num_workers) + + remote_worker_stats = self.remote_worker_group._train( + num_steps, profile, info, dataset) + try: + if dataset: + params["iterator"] = dataset.get_shard(self.num_workers - 1) + local_worker_stats = self.local_worker.train_epoch(**params) + except RuntimeError as err: + if "gloo" in err.args[0] and "Timed out" in err.args[0]: + logger.warning(err) + return False, None + if "NCCL" in err.args[0]: # there is no specific error message + logger.warning(err) + return False, None + if "Connection closed by peer" in err.args[0]: + logger.warning(err) + return False, None + + raise err + + success = check_for_failure(remote_worker_stats) + if success: + return success, [local_worker_stats] + ray.get(remote_worker_stats) + + return success, None + + def validate(self, num_steps=None, profile=False, info=None): + params = dict(num_steps=num_steps, profile=profile, info=info) + + remote_worker_stats = self.remote_worker_group._validate(params) + local_worker_stats = self.local_worker.validate(**params) + worker_stats = [local_worker_stats] + ray.get(remote_worker_stats) + return worker_stats + + def shutdown(self, force=False): + if not force: + cleanup = self.remote_worker_group._shutdown_remote_workers() + self.local_worker.shutdown() + self.remote_worker_group._terminate_remote_workers(cleanup) + else: + self.local_worker.shutdown() + self.remote_worker_group.reset() + + self.local_worker = None + self.remote_worker_group = DeactivatedWorkerGroup() + + @property + def num_workers(self): + return self.remote_worker_group.num_workers + 1 + + @property + def remote_workers(self): + return self.remote_worker_group.remote_workers + + +class DeactivatedWorkerGroup: + def __getattr__(self, *args, **kwargs): + raise RuntimeError( + "This TorchTrainer is not active (it is likely shutdown already). " + "Create a new TorchTrainer.")