mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 09:55:49 +08:00
[SGD] Fixes TorchTrainer scales up (#12563)
This commit is contained in:
@@ -93,14 +93,22 @@ def test_resize(ray_start_2_cpus, use_local): # noqa: F811
|
||||
use_local=use_local,
|
||||
num_workers=2)
|
||||
|
||||
@ray.remote
|
||||
def try_test():
|
||||
import time
|
||||
time.sleep(100)
|
||||
@ray.remote(num_cpus=1)
|
||||
class DummyActor:
|
||||
def get(self):
|
||||
return 1
|
||||
|
||||
try_test.remote()
|
||||
dummy_handler = DummyActor.remote()
|
||||
trainer1.train(max_retries=1)
|
||||
assert trainer1.worker_group.num_workers == 1
|
||||
assert trainer1._num_failures == 1
|
||||
|
||||
ray.get(dummy_handler.get.remote())
|
||||
ray.kill(dummy_handler)
|
||||
time.sleep(1)
|
||||
# trigger scale up
|
||||
trainer1.train()
|
||||
assert trainer1.worker_group.num_workers == 2
|
||||
|
||||
trainer1.shutdown(force=True)
|
||||
|
||||
@@ -132,6 +140,8 @@ def test_fail_twice(ray_start_2_cpus, use_local): # noqa: F811
|
||||
|
||||
# MAX RETRIES SHOULD BE ON BY DEFAULT
|
||||
trainer1.train()
|
||||
assert trainer1._num_failures == 2
|
||||
assert trainer1.worker_group.num_workers == 2
|
||||
trainer1.shutdown(force=True)
|
||||
|
||||
|
||||
|
||||
@@ -300,7 +300,7 @@ class TorchTrainer:
|
||||
wrap_ddp=self.wrap_ddp)
|
||||
|
||||
worker_args = {
|
||||
"max_workers": num_workers,
|
||||
"max_workers": self.max_replicas,
|
||||
"params": params,
|
||||
"dist_params": dist_params,
|
||||
"initialization_hook": self.initialization_hook,
|
||||
|
||||
@@ -381,7 +381,7 @@ class RemoteWorkerGroup(WorkerGroupInterface):
|
||||
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
|
||||
if past_cooldown and worker_gap:
|
||||
# Assume 1 resource is already reserved for local worker.
|
||||
potential_remote_size = self._check_potential_remote_workers_size()
|
||||
potential_remote_size = self.new_workers_size()
|
||||
return potential_remote_size > 0
|
||||
return False
|
||||
|
||||
@@ -514,7 +514,10 @@ class LocalWorkerGroup(WorkerGroupInterface):
|
||||
|
||||
def reset(self):
|
||||
"""Terminates models without giving up local resource reservation."""
|
||||
self.local_worker.shutdown(cleanup=False)
|
||||
if not isinstance(self.local_worker, LocalDistributedRunner):
|
||||
self.local_worker.shutdown()
|
||||
else:
|
||||
self.local_worker.shutdown(cleanup=False)
|
||||
self.remote_worker_group.reset()
|
||||
|
||||
self.local_worker = None
|
||||
|
||||
Reference in New Issue
Block a user