mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 11:27:32 +08:00
[sgd] make ddp optional (#7875)
* loosen * devices * tryitout * fix * fix * fix * easy * test * fix * fix * better visibility * fix
This commit is contained in:
@@ -503,6 +503,46 @@ def test_save_and_restore(ray_start_2_cpus, num_workers,
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
def test_wrap_ddp(ray_start_2_cpus, tmp_path):  # noqa: F811
    """Check that ``wrap_ddp=False`` trainers checkpoint and restore cleanly.

    A two-worker trainer is trained and saved without the automatic
    DistributedDataParallel wrapper. A second trainer built with the same
    arguments then loads that checkpoint and must report identical weights.
    """
    if not dist.is_available():
        return

    def build_trainer():
        # Both trainers in this test use exactly the same configuration.
        return TorchTrainer(
            model_creator=model_creator,
            data_creator=data_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=lambda config: nn.MSELoss(),
            wrap_ddp=False,
            num_workers=2)

    checkpoint_path = os.path.join(tmp_path, "checkpoint")

    saving_trainer = build_trainer()
    saving_trainer.train()
    saving_trainer.save(checkpoint_path)
    saved_model = saving_trainer.get_model()

    operator = saving_trainer.local_worker.training_operator
    # With wrapping disabled the operator's model must not carry the
    # ``module`` attribute, while the device ids stay reachable on the
    # operator.
    assert not hasattr(operator.model, "module")
    assert hasattr(operator, "device_ids")
    saving_trainer.shutdown()

    restoring_trainer = build_trainer()
    restoring_trainer.load(checkpoint_path)
    restored_model = restoring_trainer.get_model()

    saved_state = saved_model.state_dict()
    restored_state = restored_model.state_dict()

    assert set(saved_state.keys()) == set(restored_state.keys())

    for key, saved_tensor in saved_state.items():
        assert torch.equal(saved_tensor, restored_state[key])
    restoring_trainer.shutdown()
|
||||
|
||||
|
||||
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from ray.ray_constants import env_integer
|
||||
|
||||
USE_FP16 = "__use_fp16__"
|
||||
NUM_STEPS = "__num_steps__"
|
||||
SCHEDULER_STEP = "scheduler_step"
|
||||
SCHEDULER_STEP_BATCH = "batch"
|
||||
SCHEDULER_STEP_EPOCH = "epoch"
|
||||
NCCL_TIMEOUT_IN_SECONDS = 10
|
||||
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10)
|
||||
|
||||
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner, _remind_gpu_usage
|
||||
@@ -26,15 +26,23 @@ class DistributedTorchRunner(TorchRunner):
|
||||
backend (str): Backend used by distributed PyTorch.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
DistributedSampler to all created dataloaders.
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
kwargs: Keyword arguments for TorchRunner.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
             *args,
             backend="gloo",
             add_dist_sampler=True,
             wrap_ddp=False,
             **kwargs):
    """Set up the base runner and record the distributed options.

    Args:
        *args: Positional arguments forwarded to TorchRunner.
        backend (str): Backend used by distributed PyTorch; must be
            "gloo" or "nccl".
        add_dist_sampler (bool): Whether to automatically add a
            DistributedSampler to all created dataloaders.
        wrap_ddp (bool): Whether to automatically wrap
            DistributedDataParallel over each model.
        **kwargs: Keyword arguments forwarded to TorchRunner.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    # The parent is initialized first, so an invalid backend is only
    # reported after the base runner has been constructed.
    super(DistributedTorchRunner, self).__init__(*args, **kwargs)

    supported_backends = ("gloo", "nccl")
    if backend not in supported_backends:
        raise ValueError("Backend must be one of 'gloo' or 'nccl'.")

    self.backend = backend
    self.wrap_ddp = wrap_ddp
    self.add_dist_sampler = add_dist_sampler
|
||||
|
||||
def setup(self, url, world_rank, world_size):
|
||||
@@ -61,7 +69,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
"To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
|
||||
os.environ["NCCL_BLOCKING_WAIT"] = "1"
|
||||
|
||||
timeout = timedelta(seconds=NCCL_TIMEOUT_IN_SECONDS)
|
||||
timeout = timedelta(seconds=NCCL_TIMEOUT_S)
|
||||
dist.init_process_group(
|
||||
backend=self.backend,
|
||||
init_method=url,
|
||||
@@ -73,9 +81,10 @@ class DistributedTorchRunner(TorchRunner):
|
||||
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
# https://github.com/allenai/allennlp/issues/1090
|
||||
self._set_cuda_device_id()
|
||||
self.set_cuda_device_id()
|
||||
|
||||
def _set_cuda_device_id(self):
|
||||
def set_cuda_device_id(self):
|
||||
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
|
||||
self.device_ids = [0]
|
||||
|
||||
def _setup_training(self):
|
||||
@@ -99,23 +108,27 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self._create_schedulers_if_available()
|
||||
self._try_setup_apex()
|
||||
|
||||
# This needs to happen after apex
|
||||
self.models = [
|
||||
DistributedDataParallel(model, device_ids=self.device_ids)
|
||||
for model in self.models
|
||||
]
|
||||
|
||||
self._create_loss()
|
||||
|
||||
training_models = self.models
|
||||
|
||||
if self.wrap_ddp:
|
||||
# This needs to happen after apex
|
||||
training_models = [
|
||||
DistributedDataParallel(model, device_ids=self.device_ids)
|
||||
for model in self.models
|
||||
]
|
||||
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=self.models,
|
||||
models=training_models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
train_loader=self.train_loader,
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=self.world_rank,
|
||||
schedulers=self.schedulers,
|
||||
device_ids=self.device_ids,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm)
|
||||
@@ -158,17 +171,6 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self.train_loader.sampler.set_epoch(self.epochs)
|
||||
return super(DistributedTorchRunner, self).train_epoch(**kwargs)
|
||||
|
||||
def _get_model_state_dicts(self):
|
||||
"""Fetch state from ``model.module`` instead of ``model``.
|
||||
|
||||
This is needed for PyTorch DistributedDataParallel models.
|
||||
"""
|
||||
return [model.module.state_dict() for model in self.models]
|
||||
|
||||
def _set_model_state_dicts(self, model_state_dicts):
|
||||
for model, model_state_dict in zip(self.models, model_state_dicts):
|
||||
model.module.load_state_dict(model_state_dict)
|
||||
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
# However, it seems to be harmless to remove permanently
|
||||
@@ -179,6 +181,14 @@ class DistributedTorchRunner(TorchRunner):
|
||||
super(DistributedTorchRunner, self).shutdown()
|
||||
|
||||
|
||||
def _init_cuda_context():
|
||||
# Force cuda initialization
|
||||
# Inspired by https://github.com/pytorch/pytorch/blob/
|
||||
# f050b16dd95b2bcce9853882fd3fb07a6fd80378/torch/testing/
|
||||
# _internal/common_cuda.py
|
||||
torch.cuda.is_available()
|
||||
|
||||
|
||||
class _DummyActor:
|
||||
def cuda_devices(self):
|
||||
return os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
@@ -203,7 +213,7 @@ class LocalDistributedRunner(DistributedTorchRunner):
|
||||
|
||||
# Reserve a local GPU or CPU for the local worker
|
||||
# TODO: we should make sure this NEVER dies.
|
||||
|
||||
self.local_device = "0"
|
||||
global _dummy_actor
|
||||
if not self.is_actor() and _dummy_actor is None:
|
||||
_dummy_actor = ray.remote(
|
||||
@@ -215,16 +225,25 @@ class LocalDistributedRunner(DistributedTorchRunner):
|
||||
|
||||
# This is a pretty annoying workaround. To enable SyncBatchNorm,
|
||||
# we need to signify that we are using only 1 CUDA device (via
|
||||
# the DDP constructor). However, on the local worker,
|
||||
# we set the CUDA_VISIBLE_DEVICES at runtime rather than at process
|
||||
# start. This means that we have to make sure that DDP knows which
|
||||
# specific device we are using.
|
||||
# the DDP constructor).
|
||||
# However, on the local worker, we have to set the
|
||||
# CUDA_VISIBLE_DEVICES at runtime rather than at process start.
|
||||
|
||||
# You can only call setdevice(int > 0) after you've interacted with
|
||||
# torch.cuda. But you can't guarantee that you _haven't_ interacted
|
||||
# with it (user can do arbitrary things), so we force an
|
||||
# interaction.
|
||||
_init_cuda_context()
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device
|
||||
if self.local_device:
|
||||
torch.cuda.set_device(int(self.local_device))
|
||||
try:
|
||||
torch.cuda.set_device(int(self.local_device))
|
||||
except RuntimeError:
|
||||
logger.error("This happens if cuda is not initialized.")
|
||||
raise
|
||||
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
|
||||
|
||||
def _set_cuda_device_id(self):
|
||||
def set_cuda_device_id(self):
|
||||
self.device_ids = [int(self.local_device)]
|
||||
|
||||
def shutdown(self, cleanup=True):
|
||||
|
||||
@@ -239,19 +239,12 @@ class TorchRunner:
|
||||
self.timers.disable()
|
||||
self.training_operator._set_timers(self.timers)
|
||||
|
||||
def _get_model_state_dicts(self):
|
||||
return [model.state_dict() for model in self.models]
|
||||
|
||||
def _set_model_state_dicts(self, models_state_dicts):
|
||||
for model, state_dict in zip(self.models, models_state_dicts):
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the runner."""
|
||||
state = {
|
||||
"epoch": self.epochs,
|
||||
"operator": self.training_operator.state_dict(),
|
||||
"models": self._get_model_state_dicts(),
|
||||
"models": [model.state_dict() for model in self.models],
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers]
|
||||
}
|
||||
if self.schedulers:
|
||||
@@ -267,7 +260,8 @@ class TorchRunner:
|
||||
|
||||
def load_state_dict(self, state):
|
||||
"""Sets the state of the model."""
|
||||
self._set_model_state_dicts(state["models"])
|
||||
for model, state_dict in zip(self.models, state["models"]):
|
||||
model.load_state_dict(state_dict)
|
||||
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
|
||||
optimizer.load_state_dict(state_dict)
|
||||
if self.schedulers:
|
||||
|
||||
@@ -117,6 +117,8 @@ class TorchTrainer:
|
||||
support "nccl", "gloo", and "auto". If "auto", RaySGD will
|
||||
automatically use "nccl" if `use_gpu` is True, and "gloo"
|
||||
otherwise.
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
DistributedSampler to all created dataloaders. Only applicable
|
||||
if num_workers > 1.
|
||||
@@ -154,6 +156,7 @@ class TorchTrainer:
|
||||
num_workers=1,
|
||||
use_gpu="auto",
|
||||
backend="auto",
|
||||
wrap_ddp=True,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
@@ -218,6 +221,7 @@ class TorchTrainer:
|
||||
self.use_gpu = use_gpu
|
||||
self.max_replicas = num_workers
|
||||
|
||||
self.wrap_ddp = wrap_ddp
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
@@ -292,7 +296,9 @@ class TorchTrainer:
|
||||
self.local_worker.setup()
|
||||
else:
|
||||
params.update(
|
||||
backend=self.backend, add_dist_sampler=self.add_dist_sampler)
|
||||
backend=self.backend,
|
||||
add_dist_sampler=self.add_dist_sampler,
|
||||
wrap_ddp=self.wrap_ddp)
|
||||
|
||||
# Start local worker
|
||||
self.local_worker = LocalDistributedRunner(
|
||||
|
||||
@@ -60,6 +60,7 @@ class TrainingOperator:
|
||||
world_rank,
|
||||
criterion=None,
|
||||
schedulers=None,
|
||||
device_ids=None,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
use_tqdm=False):
|
||||
@@ -81,6 +82,7 @@ class TrainingOperator:
|
||||
type(schedulers)))
|
||||
self._config = config
|
||||
self._use_fp16 = use_fp16
|
||||
self._device_ids = device_ids
|
||||
self._use_gpu = use_gpu and torch.cuda.is_available()
|
||||
self._device = torch.device("cuda" if self._use_gpu else "cpu")
|
||||
if tqdm is None and use_tqdm:
|
||||
@@ -327,21 +329,27 @@ class TrainingOperator:
|
||||
}
|
||||
|
||||
def state_dict(self):
|
||||
"""Override this to return a representation of the operator state."""
|
||||
"""Override this to return a representation of the operator state.
|
||||
|
||||
Returns:
|
||||
dict: The state dict of the operator."""
|
||||
pass
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Override this to load the representation of the operator state."""
|
||||
"""Override this to load the representation of the operator state.
|
||||
|
||||
Args:
|
||||
state_dict (dict): State dict as returned by the operator. """
|
||||
pass
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""The torch device, at your convenience."""
|
||||
"""torch.device: The appropriate torch device, at your convenience."""
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
"""Dictionary as provided into TorchTrainer."""
|
||||
"""dict: Provided into TorchTrainer."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
@@ -366,21 +374,18 @@ class TrainingOperator:
|
||||
|
||||
@property
|
||||
def train_loader(self):
|
||||
"""
|
||||
Data loader for the validation dataset created by the ``data_creator``.
|
||||
"""Iterable: 1st Dataloader from ``data_creator``.
|
||||
"""
|
||||
return self._train_loader
|
||||
|
||||
@property
|
||||
def validation_loader(self):
|
||||
"""
|
||||
Data loader for the train dataset created by the ``data_creator``.
|
||||
"""
|
||||
"""Iterable: 2nd Dataloader from ``data_creator``."""
|
||||
return self._validation_loader
|
||||
|
||||
@property
|
||||
def world_rank(self):
|
||||
"""The rank of the parent runner. Always 0 if not distributed."""
|
||||
"""int: The rank of the parent runner. Always 0 if not distributed."""
|
||||
return self._world_rank
|
||||
|
||||
@property
|
||||
@@ -401,14 +406,22 @@ class TrainingOperator:
|
||||
|
||||
@property
|
||||
def use_fp16(self):
|
||||
"""Whether the model and optimizer have been FP16 enabled."""
|
||||
"""bool: Whether the model and optimizer have been FP16 enabled."""
|
||||
return self._use_fp16
|
||||
|
||||
@property
|
||||
def use_tqdm(self):
|
||||
"""Whether tqdm progress bars are enabled."""
|
||||
"""bool: Whether tqdm progress bars are enabled."""
|
||||
return self._use_tqdm
|
||||
|
||||
@property
def device_ids(self):
    """List[int]: Device IDs this operator was constructed with.

    Exposed so that user code can pass them on itself, for example when
    using batch norm together with DistributedDataParallel.
    """
    return self._device_ids
|
||||
|
||||
|
||||
class _TestingOperator(TrainingOperator):
|
||||
def train_epoch(self, iterator, info):
|
||||
|
||||
Reference in New Issue
Block a user