mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 00:46:14 +08:00
[SGD] Fix process group timeout units (#12477)
This commit is contained in:
@@ -6,7 +6,7 @@ SCHEDULER_STEP = "scheduler_step"
|
||||
SCHEDULER_STEP_BATCH = "batch"
|
||||
SCHEDULER_STEP_EPOCH = "epoch"
|
||||
SCHEDULER_STEP_MANUAL = "manual"
|
||||
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10)
|
||||
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 1800)
|
||||
|
||||
VALID_SCHEDULER_STEP = {
|
||||
SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL
|
||||
|
||||
@@ -105,7 +105,9 @@ class TorchTrainer:
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
timeout_s (float): Seconds before the torch process group
|
||||
times out. Useful when machines are unreliable.
|
||||
times out. Useful when machines are unreliable. If not set, defaults
|
||||
to 30 minutes, which is the same default as
|
||||
``torch.distributed.init_process_group(...)``.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
DistributedSampler to all created dataloaders. Only applicable
|
||||
if num_workers > 1.
|
||||
@@ -143,7 +145,7 @@ class TorchTrainer:
|
||||
use_gpu="auto",
|
||||
backend="auto",
|
||||
wrap_ddp=True,
|
||||
timeout_s=NCCL_TIMEOUT_S,
|
||||
timeout_s=1800,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
add_dist_sampler=True,
|
||||
@@ -230,6 +232,9 @@ class TorchTrainer:
|
||||
if backend == "auto":
|
||||
backend = "nccl" if use_gpu else "gloo"
|
||||
|
||||
if backend == "nccl":
|
||||
timeout_s = NCCL_TIMEOUT_S
|
||||
|
||||
logger.debug(f"Using {backend} as backend.")
|
||||
self.backend = backend
|
||||
self.num_cpus_per_worker = num_cpus_per_worker
|
||||
|
||||
@@ -175,7 +175,7 @@ class RemoteWorkerGroup(WorkerGroupInterface):
|
||||
url=address,
|
||||
world_rank=i + starting_rank,
|
||||
world_size=world_size,
|
||||
timeout=timedelta(self._timeout_s))
|
||||
timeout=timedelta(seconds=self._timeout_s))
|
||||
for i, worker in enumerate(self.remote_workers)
|
||||
]
|
||||
return remote_pgroup_setups
|
||||
@@ -467,7 +467,7 @@ class LocalWorkerGroup(WorkerGroupInterface):
|
||||
url=address,
|
||||
world_rank=0,
|
||||
world_size=num_workers,
|
||||
timeout=timedelta(self._timeout_s))
|
||||
timeout=timedelta(seconds=self._timeout_s))
|
||||
ray.get(remote_pgs)
|
||||
|
||||
local_node_ip = ray.services.get_node_ip_address()
|
||||
|
||||
Reference in New Issue
Block a user