From 7cad64837042dc1da4cae4a4f8d47b1891fbc136 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Sat, 5 Dec 2020 05:55:15 +0800 Subject: [PATCH] [SGD] Fixes TorchTrainer scales up (#12563) --- .../ray/util/sgd/tests/test_torch_failure.py | 20 ++++++++++++++----- python/ray/util/sgd/torch/torch_trainer.py | 2 +- python/ray/util/sgd/torch/worker_group.py | 7 +++++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/python/ray/util/sgd/tests/test_torch_failure.py b/python/ray/util/sgd/tests/test_torch_failure.py index 97324e0a5..4f48daaec 100644 --- a/python/ray/util/sgd/tests/test_torch_failure.py +++ b/python/ray/util/sgd/tests/test_torch_failure.py @@ -93,14 +93,22 @@ def test_resize(ray_start_2_cpus, use_local): # noqa: F811 use_local=use_local, num_workers=2) - @ray.remote - def try_test(): - import time - time.sleep(100) + @ray.remote(num_cpus=1) + class DummyActor: + def get(self): + return 1 - try_test.remote() + dummy_handler = DummyActor.remote() trainer1.train(max_retries=1) assert trainer1.worker_group.num_workers == 1 + assert trainer1._num_failures == 1 + + ray.get(dummy_handler.get.remote()) + ray.kill(dummy_handler) + time.sleep(1) + # trigger scale up + trainer1.train() + assert trainer1.worker_group.num_workers == 2 trainer1.shutdown(force=True) @@ -132,6 +140,8 @@ def test_fail_twice(ray_start_2_cpus, use_local): # noqa: F811 # MAX RETRIES SHOULD BE ON BY DEFAULT trainer1.train() + assert trainer1._num_failures == 2 + assert trainer1.worker_group.num_workers == 2 trainer1.shutdown(force=True) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 0602222c5..d17c23191 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -300,7 +300,7 @@ class TorchTrainer: wrap_ddp=self.wrap_ddp) worker_args = { - "max_workers": num_workers, + "max_workers": self.max_replicas, "params": params, "dist_params": dist_params, "initialization_hook": self.initialization_hook, diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py index 9a42bf46b..f1a82082a 100644 --- a/python/ray/util/sgd/torch/worker_group.py +++ b/python/ray/util/sgd/torch/worker_group.py @@ -381,7 +381,7 @@ class RemoteWorkerGroup(WorkerGroupInterface): past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S if past_cooldown and worker_gap: # Assume 1 resource is already reserved for local worker. - potential_remote_size = self._check_potential_remote_workers_size() + potential_remote_size = self.new_workers_size() return potential_remote_size > 0 return False @@ -514,7 +514,10 @@ class LocalWorkerGroup(WorkerGroupInterface): def reset(self): """Terminates models without giving up local resource reservation.""" - self.local_worker.shutdown(cleanup=False) + if not isinstance(self.local_worker, LocalDistributedRunner): + self.local_worker.shutdown() + else: + self.local_worker.shutdown(cleanup=False) self.remote_worker_group.reset() self.local_worker = None