[sgd] make ddp optional (#7875)

* loosen

* devices

* tryitout

* fix

* fix

* fix

* easy

* test

* fix

* fix

* better visibility

* fix
This commit is contained in:
Richard Liaw
2020-04-06 11:41:36 -07:00
committed by GitHub
parent 203c077895
commit f63b4c1110
9 changed files with 239 additions and 74 deletions
+40
View File
@@ -503,6 +503,46 @@ def test_save_and_restore(ray_start_2_cpus, num_workers,
trainer2.shutdown()
def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
if not dist.is_available():
return
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
wrap_ddp=False,
num_workers=2)
trainer1.train()
checkpoint_path = os.path.join(tmp_path, "checkpoint")
trainer1.save(checkpoint_path)
model1 = trainer1.get_model()
assert not hasattr(trainer1.local_worker.training_operator.model, "module")
assert hasattr(trainer1.local_worker.training_operator, "device_ids")
trainer1.shutdown()
trainer2 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
wrap_ddp=False,
num_workers=2)
trainer2.load(checkpoint_path)
model2 = trainer2.get_model()
model1_state_dict = model1.state_dict()
model2_state_dict = model2.state_dict()
assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())
for k in model1_state_dict:
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
trainer2.shutdown()
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
return
+3 -1
View File
@@ -1,8 +1,10 @@
from ray.ray_constants import env_integer
USE_FP16 = "__use_fp16__"
NUM_STEPS = "__num_steps__"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
NCCL_TIMEOUT_IN_SECONDS = 10
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10)
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
@@ -9,7 +9,7 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
import ray
from ray.util.sgd.torch.torch_runner import TorchRunner, _remind_gpu_usage
@@ -26,15 +26,23 @@ class DistributedTorchRunner(TorchRunner):
backend (str): Backend used by distributed PyTorch.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders.
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
kwargs: Keyword arguments for TorchRunner.
"""
def __init__(self, *args, backend="gloo", add_dist_sampler=True, **kwargs):
def __init__(self,
*args,
backend="gloo",
add_dist_sampler=True,
wrap_ddp=False,
**kwargs):
super(DistributedTorchRunner, self).__init__(*args, **kwargs)
if backend not in ("gloo", "nccl"):
raise ValueError("Backend must be one of 'gloo' or 'nccl'.")
self.backend = backend
self.wrap_ddp = wrap_ddp
self.add_dist_sampler = add_dist_sampler
def setup(self, url, world_rank, world_size):
@@ -61,7 +69,7 @@ class DistributedTorchRunner(TorchRunner):
"To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
os.environ["NCCL_BLOCKING_WAIT"] = "1"
timeout = timedelta(seconds=NCCL_TIMEOUT_IN_SECONDS)
timeout = timedelta(seconds=NCCL_TIMEOUT_S)
dist.init_process_group(
backend=self.backend,
init_method=url,
@@ -73,9 +81,10 @@ class DistributedTorchRunner(TorchRunner):
if self.use_gpu and torch.cuda.is_available():
# https://github.com/allenai/allennlp/issues/1090
self._set_cuda_device_id()
self.set_cuda_device_id()
def _set_cuda_device_id(self):
def set_cuda_device_id(self):
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
self.device_ids = [0]
def _setup_training(self):
@@ -99,23 +108,27 @@ class DistributedTorchRunner(TorchRunner):
self._create_schedulers_if_available()
self._try_setup_apex()
# This needs to happen after apex
self.models = [
DistributedDataParallel(model, device_ids=self.device_ids)
for model in self.models
]
self._create_loss()
training_models = self.models
if self.wrap_ddp:
# This needs to happen after apex
training_models = [
DistributedDataParallel(model, device_ids=self.device_ids)
for model in self.models
]
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
models=training_models,
optimizers=self.optimizers,
criterion=self.criterion,
train_loader=self.train_loader,
validation_loader=self.validation_loader,
world_rank=self.world_rank,
schedulers=self.schedulers,
device_ids=self.device_ids,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
@@ -158,17 +171,6 @@ class DistributedTorchRunner(TorchRunner):
self.train_loader.sampler.set_epoch(self.epochs)
return super(DistributedTorchRunner, self).train_epoch(**kwargs)
def _get_model_state_dicts(self):
"""Fetch state from ``model.module`` instead of ``model``.
This is needed for PyTorch DistributedDataParallel models.
"""
return [model.module.state_dict() for model in self.models]
def _set_model_state_dicts(self, model_state_dicts):
for model, model_state_dict in zip(self.models, model_state_dicts):
model.module.load_state_dict(model_state_dict)
def shutdown(self):
"""Attempts to shut down the worker."""
# However, it seems to be harmless to remove permanently
@@ -179,6 +181,14 @@ class DistributedTorchRunner(TorchRunner):
super(DistributedTorchRunner, self).shutdown()
def _init_cuda_context():
# Force cuda initialization
# Inspired by https://github.com/pytorch/pytorch/blob/
# f050b16dd95b2bcce9853882fd3fb07a6fd80378/torch/testing/
# _internal/common_cuda.py
torch.cuda.is_available()
class _DummyActor:
def cuda_devices(self):
return os.environ["CUDA_VISIBLE_DEVICES"]
@@ -203,7 +213,7 @@ class LocalDistributedRunner(DistributedTorchRunner):
# Reserve a local GPU or CPU for the local worker
# TODO: we should make sure this NEVER dies.
self.local_device = "0"
global _dummy_actor
if not self.is_actor() and _dummy_actor is None:
_dummy_actor = ray.remote(
@@ -215,16 +225,25 @@ class LocalDistributedRunner(DistributedTorchRunner):
# This is a pretty annoying workaround. To enable SyncBatchNorm,
# we need to signify that we are using only 1 CUDA device (via
# the DDP constructor). However, on the local worker,
# we set the CUDA_VISIBLE_DEVICES at runtime rather at process
# start. This means that we have to make sure that DDP knows which
# specific device we are using.
# the DDP constructor).
# However, on the local worker, we have to set the
# CUDA_VISIBLE_DEVICES at runtime rather at process start.
# You can only call setdevice(int > 0) after you've interacted with
# torch.cuda. But you can't guarantee that you _haven't_ interacted
# with it (user can do arbitrary things), so we force an
# interaction.
_init_cuda_context()
os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device
if self.local_device:
torch.cuda.set_device(int(self.local_device))
try:
torch.cuda.set_device(int(self.local_device))
except RuntimeError:
logger.error("This happens if cuda is not initialized.")
raise
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
def _set_cuda_device_id(self):
def set_cuda_device_id(self):
self.device_ids = [int(self.local_device)]
def shutdown(self, cleanup=True):
+3 -9
View File
@@ -239,19 +239,12 @@ class TorchRunner:
self.timers.disable()
self.training_operator._set_timers(self.timers)
def _get_model_state_dicts(self):
return [model.state_dict() for model in self.models]
def _set_model_state_dicts(self, models_state_dicts):
for model, state_dict in zip(self.models, models_state_dicts):
model.load_state_dict(state_dict)
def state_dict(self):
"""Returns the state of the runner."""
state = {
"epoch": self.epochs,
"operator": self.training_operator.state_dict(),
"models": self._get_model_state_dicts(),
"models": [model.state_dict() for model in self.models],
"optimizers": [opt.state_dict() for opt in self.optimizers]
}
if self.schedulers:
@@ -267,7 +260,8 @@ class TorchRunner:
def load_state_dict(self, state):
"""Sets the state of the model."""
self._set_model_state_dicts(state["models"])
for model, state_dict in zip(self.models, state["models"]):
model.load_state_dict(state_dict)
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
optimizer.load_state_dict(state_dict)
if self.schedulers:
+7 -1
View File
@@ -117,6 +117,8 @@ class TorchTrainer:
support "nccl", "gloo", and "auto". If "auto", RaySGD will
automatically use "nccl" if `use_gpu` is True, and "gloo"
otherwise.
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders. Only applicable
if num_workers > 1.
@@ -154,6 +156,7 @@ class TorchTrainer:
num_workers=1,
use_gpu="auto",
backend="auto",
wrap_ddp=True,
use_fp16=False,
use_tqdm=False,
apex_args=None,
@@ -218,6 +221,7 @@ class TorchTrainer:
self.use_gpu = use_gpu
self.max_replicas = num_workers
self.wrap_ddp = wrap_ddp
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.add_dist_sampler = add_dist_sampler
@@ -292,7 +296,9 @@ class TorchTrainer:
self.local_worker.setup()
else:
params.update(
backend=self.backend, add_dist_sampler=self.add_dist_sampler)
backend=self.backend,
add_dist_sampler=self.add_dist_sampler,
wrap_ddp=self.wrap_ddp)
# Start local worker
self.local_worker = LocalDistributedRunner(
+25 -12
View File
@@ -60,6 +60,7 @@ class TrainingOperator:
world_rank,
criterion=None,
schedulers=None,
device_ids=None,
use_gpu=False,
use_fp16=False,
use_tqdm=False):
@@ -81,6 +82,7 @@ class TrainingOperator:
type(schedulers)))
self._config = config
self._use_fp16 = use_fp16
self._device_ids = device_ids
self._use_gpu = use_gpu and torch.cuda.is_available()
self._device = torch.device("cuda" if self._use_gpu else "cpu")
if tqdm is None and use_tqdm:
@@ -327,21 +329,27 @@ class TrainingOperator:
}
def state_dict(self):
"""Override this to return a representation of the operator state."""
"""Override this to return a representation of the operator state.
Returns:
dict: The state dict of the operator."""
pass
def load_state_dict(self, state_dict):
"""Override this to load the representation of the operator state."""
"""Override this to load the representation of the operator state.
Args:
state_dict (dict): State dict as returned by the operator. """
pass
@property
def device(self):
"""The torch device, at your convenience."""
"""torch.device: The appropriate torch device, at your convenience."""
return self._device
@property
def config(self):
"""Dictionary as provided into TorchTrainer."""
"""dict: Provided into TorchTrainer."""
return self._config
@property
@@ -366,21 +374,18 @@ class TrainingOperator:
@property
def train_loader(self):
"""
Data loader for the validation dataset created by the ``data_creator``.
"""Iterable: 1st Dataloader from ``data_creator``.
"""
return self._train_loader
@property
def validation_loader(self):
"""
Data loader for the train dataset created by the ``data_creator``.
"""
"""Iterable: 2nd Dataloader from ``data_creator``."""
return self._validation_loader
@property
def world_rank(self):
"""The rank of the parent runner. Always 0 if not distributed."""
"""int: The rank of the parent runner. Always 0 if not distributed."""
return self._world_rank
@property
@@ -401,14 +406,22 @@ class TrainingOperator:
@property
def use_fp16(self):
"""Whether the model and optimizer have been FP16 enabled."""
"""bool: Whether the model and optimizer have been FP16 enabled."""
return self._use_fp16
@property
def use_tqdm(self):
"""Whether tqdm progress bars are enabled."""
"""bool: Whether tqdm progress bars are enabled."""
return self._use_tqdm
@property
def device_ids(self):
"""List[int]: Device IDs for the model.
This is useful for using batch norm with DistributedDataParallel.
"""
return self._device_ids
class _TestingOperator(TrainingOperator):
def train_epoch(self, iterator, info):