[SGD] Fixes TorchTrainer scales up (#12563)

This commit is contained in:
Xianyang Liu
2020-12-05 05:55:15 +08:00
committed by GitHub
parent f965537ae9
commit 7cad648370
3 changed files with 21 additions and 8 deletions
@@ -93,14 +93,22 @@ def test_resize(ray_start_2_cpus, use_local): # noqa: F811
use_local=use_local,
num_workers=2)
@ray.remote
def try_test():
import time
time.sleep(100)
@ray.remote(num_cpus=1)
class DummyActor:
def get(self):
return 1
try_test.remote()
dummy_handler = DummyActor.remote()
trainer1.train(max_retries=1)
assert trainer1.worker_group.num_workers == 1
assert trainer1._num_failures == 1
ray.get(dummy_handler.get.remote())
ray.kill(dummy_handler)
time.sleep(1)
# trigger scale up
trainer1.train()
assert trainer1.worker_group.num_workers == 2
trainer1.shutdown(force=True)
@@ -132,6 +140,8 @@ def test_fail_twice(ray_start_2_cpus, use_local): # noqa: F811
# MAX RETRIES SHOULD BE ON BY DEFAULT
trainer1.train()
assert trainer1._num_failures == 2
assert trainer1.worker_group.num_workers == 2
trainer1.shutdown(force=True)
+1 -1
View File
@@ -300,7 +300,7 @@ class TorchTrainer:
wrap_ddp=self.wrap_ddp)
worker_args = {
"max_workers": num_workers,
"max_workers": self.max_replicas,
"params": params,
"dist_params": dist_params,
"initialization_hook": self.initialization_hook,
+5 -2
View File
@@ -381,7 +381,7 @@ class RemoteWorkerGroup(WorkerGroupInterface):
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
if past_cooldown and worker_gap:
# Assume 1 resource is already reserved for local worker.
potential_remote_size = self._check_potential_remote_workers_size()
potential_remote_size = self.new_workers_size()
return potential_remote_size > 0
return False
@@ -514,7 +514,10 @@ class LocalWorkerGroup(WorkerGroupInterface):
def reset(self):
"""Terminates models without giving up local resource reservation."""
self.local_worker.shutdown(cleanup=False)
if not isinstance(self.local_worker, LocalDistributedRunner):
self.local_worker.shutdown()
else:
self.local_worker.shutdown(cleanup=False)
self.remote_worker_group.reset()
self.local_worker = None