diff --git a/python/ray/util/sgd/torch/constants.py b/python/ray/util/sgd/torch/constants.py index cf3a7dc8f..1fd37c8d6 100644 --- a/python/ray/util/sgd/torch/constants.py +++ b/python/ray/util/sgd/torch/constants.py @@ -6,7 +6,7 @@ SCHEDULER_STEP = "scheduler_step" SCHEDULER_STEP_BATCH = "batch" SCHEDULER_STEP_EPOCH = "epoch" SCHEDULER_STEP_MANUAL = "manual" -NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10) +NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 1800) VALID_SCHEDULER_STEP = { SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 50f20e39f..79217bcc0 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -105,7 +105,9 @@ class TorchTrainer: wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself. timeout_s (float): Seconds before the torch process group - times out. Useful when machines are unreliable. + times out. Useful when machines are unreliable. If not set, defaults + to 30 min, which is the same default as + ``torch.distributed.init_process_group(...)``. add_dist_sampler (bool): Whether to automatically add a DistributedSampler to all created dataloaders. Only applicable if num_workers > 1. 
@@ -143,7 +145,7 @@ class TorchTrainer: use_gpu="auto", backend="auto", wrap_ddp=True, - timeout_s=NCCL_TIMEOUT_S, + timeout_s=1800, use_fp16=False, use_tqdm=False, add_dist_sampler=True, @@ -230,6 +232,9 @@ class TorchTrainer: if backend == "auto": backend = "nccl" if use_gpu else "gloo" + if backend == "nccl": + timeout_s = NCCL_TIMEOUT_S + logger.debug(f"Using {backend} as backend.") self.backend = backend self.num_cpus_per_worker = num_cpus_per_worker diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py index f1a82082a..390059c87 100644 --- a/python/ray/util/sgd/torch/worker_group.py +++ b/python/ray/util/sgd/torch/worker_group.py @@ -175,7 +175,7 @@ class RemoteWorkerGroup(WorkerGroupInterface): url=address, world_rank=i + starting_rank, world_size=world_size, - timeout=timedelta(self._timeout_s)) + timeout=timedelta(seconds=self._timeout_s)) for i, worker in enumerate(self.remote_workers) ] return remote_pgroup_setups @@ -467,7 +467,7 @@ class LocalWorkerGroup(WorkerGroupInterface): url=address, world_rank=0, world_size=num_workers, - timeout=timedelta(self._timeout_s)) + timeout=timedelta(seconds=self._timeout_s)) ray.get(remote_pgs) local_node_ip = ray.services.get_node_ip_address()