[SGD] Fix process group timeout units (#12477)

This commit is contained in:
Amog Kamsetty
2020-12-19 21:46:33 -08:00
committed by GitHub
parent 4832b39066
commit 51139ed37c
3 changed files with 10 additions and 5 deletions
+1 -1
View File
@@ -6,7 +6,7 @@ SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
SCHEDULER_STEP_MANUAL = "manual"
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10)
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 1800)
VALID_SCHEDULER_STEP = {
SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL
+7 -2
View File
@@ -105,7 +105,9 @@ class TorchTrainer:
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
timeout_s (float): Seconds before the torch process group
times out. Useful when machines are unreliable.
times out. Useful when machines are unreliable. If not set, default
to 30 min, which is the same default as
``torch.init_process_group(...)``.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders. Only applicable
if num_workers > 1.
@@ -143,7 +145,7 @@ class TorchTrainer:
use_gpu="auto",
backend="auto",
wrap_ddp=True,
timeout_s=NCCL_TIMEOUT_S,
timeout_s=1800,
use_fp16=False,
use_tqdm=False,
add_dist_sampler=True,
@@ -230,6 +232,9 @@ class TorchTrainer:
if backend == "auto":
backend = "nccl" if use_gpu else "gloo"
if backend == "nccl":
timeout_s = NCCL_TIMEOUT_S
logger.debug(f"Using {backend} as backend.")
self.backend = backend
self.num_cpus_per_worker = num_cpus_per_worker
+2 -2
View File
@@ -175,7 +175,7 @@ class RemoteWorkerGroup(WorkerGroupInterface):
url=address,
world_rank=i + starting_rank,
world_size=world_size,
timeout=timedelta(self._timeout_s))
timeout=timedelta(seconds=self._timeout_s))
for i, worker in enumerate(self.remote_workers)
]
return remote_pgroup_setups
@@ -467,7 +467,7 @@ class LocalWorkerGroup(WorkerGroupInterface):
url=address,
world_rank=0,
world_size=num_workers,
timeout=timedelta(self._timeout_s))
timeout=timedelta(seconds=self._timeout_s))
ray.get(remote_pgs)
local_node_ip = ray.services.get_node_ip_address()