[SGD] Better support for custom DDP (#11771)

This commit is contained in:
Amog Kamsetty
2020-11-04 13:58:51 -08:00
committed by GitHub
parent 6147b6a1a3
commit 92718de40c
9 changed files with 106 additions and 41 deletions
+16
View File
@@ -766,6 +766,22 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
trainer2.shutdown()
def test_custom_ddp_args(ray_start_2_cpus):
    """Check that ``ddp_args`` passed to ``register`` apply to the model."""

    class TestTrainingOperator(TrainingOperator):
        def setup(self, config):
            model = model_creator(config)
            optimizer = optimizer_creator(model, config)
            # The loaders are created but not registered in this test;
            # only the registered model is asserted on below.
            train_loader, val_loader = data_creator(config)

            self.model, self.optimizer = self.register(
                models=model,
                optimizers=optimizer,
                ddp_args={"find_unused_parameters": True})
            assert self.model.find_unused_parameters

    trainer = TorchTrainer(
        training_operator_cls=TestTrainingOperator, num_workers=2)
    # Shut the trainer down explicitly, as the preceding test does,
    # instead of leaving the instance dangling.
    trainer.shutdown()
@pytest.mark.parametrize("use_local", [True, False])
def test_multi_input_model(ray_start_2_cpus, use_local):
def model_creator(config):
@@ -87,7 +87,6 @@ class DistributedTorchRunner(TorchRunner):
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
wrap_ddp=self.wrap_ddp,
add_dist_sampler=self.add_dist_sampler,
scheduler_step_freq=self.scheduler_step_freq)
@@ -142,7 +142,6 @@ def main():
training_operator_cls=CustomTrainingOperator,
use_tqdm=True,
use_fp16=args.amp,
apex_args={"opt_level": "O1"},
config={
"args": args,
BATCH_SIZE: args.batch_size
@@ -160,8 +160,9 @@ class TransformerOperator(TrainingOperator):
self.model, self.optimizer = self.register(
models=model,
optimizers=optimizer,
train_loader=train_loader,
validation_loader=None)
apex_args={"opt_level": args.fp16_opt_level})
self.register_data(train_loader=train_loader, validation_loader=None)
self.train_data_len = len(self.train_loader)
self._warmup_scheduler = get_linear_schedule_with_warmup(
@@ -331,7 +332,6 @@ def main():
trainer = TorchTrainer(
training_operator_cls=TransformerOperator,
use_fp16=args.fp16,
apex_args={"opt_level": args.fp16_opt_level},
num_workers=args.num_workers,
use_gpu=use_gpu,
use_tqdm=True,
+3 -4
View File
@@ -27,16 +27,15 @@ logger = logging.getLogger(__name__)
class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
TrainerOptimizersMixin):
def _configure_amp(self, amp, models, optimizers, apex_args=None):
    """Configure apex for the single wrapped LightningModule.

    Args:
        amp: The amp module/object forwarded to ``configure_apex``.
        models: List holding exactly one ``ptl.LightningModule``.
        optimizers: Optimizers forwarded to ``configure_apex``.
        apex_args (dict|None): Optional apex keyword args. Only
            ``"opt_level"`` is consulted here; it defaults to ``"O2"``.

    Returns:
        A one-element list with the configured model, and the optimizers
        returned by ``configure_apex``.
    """
    assert len(models) == 1
    model = models[0]
    assert isinstance(model, ptl.LightningModule)
    # Honor a caller-supplied opt_level rather than silently ignoring
    # apex_args; "O2" stays the default when none is given.
    amp_level = (apex_args or {}).get("opt_level", "O2")
    model, optimizers = model.configure_apex(
        amp, model, optimizers, amp_level=amp_level)
    return [model], optimizers
def _configure_ddp(self, models, device_ids):
def _configure_ddp(self, models, device_ids, ddp_args=None):
assert len(models) == 1
model = models[0]
assert isinstance(model, ptl.LightningModule)
@@ -28,7 +28,6 @@ class TorchRunner:
serialize_data_creation=True,
use_fp16=False,
use_tqdm=False,
apex_args=None,
scheduler_step_freq=None):
self.training_operator_cls = training_operator_cls
self.config = {} if config is None else config
@@ -40,7 +39,6 @@ class TorchRunner:
self.use_gpu = use_gpu
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.apex_args = apex_args or {}
if use_fp16 and not amp:
raise ImportError(
"Please install apex from "
@@ -64,7 +62,6 @@ class TorchRunner:
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
def get_iterator(self, training=True):
+7 -10
View File
@@ -113,10 +113,6 @@ class TorchTrainer:
is installed. This is automatically done after the model and
optimizers are constructed and will work for multi-model training.
Please see https://github.com/NVIDIA/apex for more details.
apex_args (dict|None): Dict containing keyword args for amp.initialize.
See https://nvidia.github.io/apex/amp.html#module-apex.amp. By
default, the models and optimizers are passed in. Consider using
"num_losses" if operating over multiple models and optimizers.
scheduler_step_freq: "batch", "epoch", "manual", or None. This will
determine when ``scheduler.step`` is called. If "batch",
``step`` will be called after every optimizer step. If "epoch",
@@ -150,7 +146,6 @@ class TorchTrainer:
timeout_s=NCCL_TIMEOUT_S,
use_fp16=False,
use_tqdm=False,
apex_args=None,
add_dist_sampler=True,
scheduler_step_freq=None,
use_local=False,
@@ -164,6 +159,7 @@ class TorchTrainer:
loss_creator=None,
serialize_data_creation=None,
data_loader_args=None,
apex_args=None,
):
if (model_creator or data_creator or optimizer_creator
or scheduler_creator or loss_creator):
@@ -201,6 +197,12 @@ class TorchTrainer:
"config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
"batch size to be used across all workers.")
if apex_args is not None:
raise DeprecationWarning(
"apex_args is deprecated. Pass in apex_args when calling "
"`register` in the `setup` method of your `TrainingOperator` "
"instead.")
if serialize_data_creation is True:
if log_once("serialize_data_creation"):
logging.warning(
@@ -242,10 +244,6 @@ class TorchTrainer:
self.add_dist_sampler = add_dist_sampler
self.use_local = use_local
if apex_args and not isinstance(apex_args, dict):
raise ValueError("apex_args needs to be a dict object.")
self.apex_args = apex_args
self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
self._num_failures = 0
self._last_resize = float("-inf")
@@ -294,7 +292,6 @@ class TorchTrainer:
use_fp16=self.use_fp16,
use_gpu=self.use_gpu,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
dist_params = dict(
+46 -10
View File
@@ -125,7 +125,6 @@ class TrainingOperator:
use_gpu=False,
use_fp16=False,
use_tqdm=False,
apex_args=None,
wrap_ddp=False,
add_dist_sampler=False,
scheduler_step_freq=None):
@@ -142,7 +141,6 @@ class TrainingOperator:
raise ValueError("tqdm must be installed to use tqdm in training.")
self._use_tqdm = use_tqdm
self.global_step = 0
self._apex_args = apex_args if apex_args else {}
self._wrap_ddp = wrap_ddp
self._add_dist_sampler = add_dist_sampler
self._scheduler_step_freq = scheduler_step_freq
@@ -154,14 +152,13 @@ class TrainingOperator:
"""Passes in the timers from the Runner."""
self.timers = timers
def _configure_amp(self, amp, models, optimizers, apex_args=None):
    """Initialize ``amp`` with the given models and optimizers.

    Args:
        amp: Object exposing ``initialize(models, optimizers, **kwargs)``.
        models: Model(s) handed to ``amp.initialize``.
        optimizers: Optimizer(s) handed to ``amp.initialize``.
        apex_args (dict|None): Extra keyword args for ``amp.initialize``.
            ``None`` is treated as "no extra arguments".

    Returns:
        The ``(models, optimizers)`` pair returned by ``amp.initialize``.
    """
    # Fall back to an empty dict so callers and overrides that use the
    # three-argument form (or pass None) do not fail on ``**None``.
    models, optimizers = amp.initialize(models, optimizers,
                                        **(apex_args or {}))
    return models, optimizers
def _configure_ddp(self, models, device_ids, ddp_args=None):
    """Wrap every model in ``DistributedDataParallel``.

    Args:
        models: Iterable of models to wrap.
        device_ids: Passed as ``device_ids`` to each wrapper.
        ddp_args (dict|None): Extra keyword args for
            ``DistributedDataParallel``. ``None`` is treated as "no
            extra arguments".

    Returns:
        A list with one ``DistributedDataParallel`` wrapper per model.
    """
    # Normalise once so callers and overrides that use the two-argument
    # form (or pass None) do not fail on ``**None``.
    extra_kwargs = ddp_args or {}
    return [
        DistributedDataParallel(model, device_ids=device_ids, **extra_kwargs)
        for model in models
    ]
@@ -188,7 +185,14 @@ class TrainingOperator:
"""
raise NotImplementedError
def register(self, *, models, optimizers, criterion=None, schedulers=None):
def register(self,
*,
models,
optimizers,
criterion=None,
schedulers=None,
ddp_args=None,
apex_args=None):
"""Registers parameters with Ray SGD and sets up training components.
By calling this method to register your models, optimizers,
@@ -200,6 +204,14 @@ class TrainingOperator:
If more than one model, optimizer, or scheduler is passed in,
you should implement your own custom training loop.
Calling register will perform the following steps in this order:
1. If using GPU, Move model(s) and criterion to the corresponding
Cuda device.
2. If using fp16, initializes amp with model(s), optimizer(s),
and apex_args.
3. If using distributed training and wrap_ddp is True,
wraps model(s) with DistributedDataParallel.
.. code-block:: python
class MyTrainingOperator(TrainingOperator):
@@ -238,11 +250,30 @@ class TrainingOperator:
schedulers (torch.optim.lr_scheduler or Iterable[
torch.optim.lr_scheduler], optional): A learning rate
scheduler or multiple learning rate schedulers.
ddp_args (dict|None): Dict containing keyword args for
DistributedDataParallel if distributed training is being
used. `module` and `device_ids` are automatically passed in,
but this dict is useful for passing in other args such as
`find_unused_parameters=True`.
apex_args (dict|None): Dict containing keyword args for
amp.initialize if fp16 is being used. See
https://nvidia.github.io/apex/amp.html#module-apex.amp.
By default, the models and optimizers are passed in.
Consider using "num_losses" if operating over multiple
models and optimizers.
Returns:
Tuple of model, optimizer, criterion if not None, and scheduler
if not None.
"""
if ddp_args and not isinstance(ddp_args, dict):
raise ValueError("ddp_args needs to be a dict object.")
ddp_args = ddp_args if ddp_args else {}
if apex_args and not isinstance(apex_args, dict):
raise ValueError("apex_args needs to be a dict object.")
apex_args = apex_args if apex_args else {}
return_vals = []
logger.debug("Registering models.")
self._original_models = models
@@ -285,12 +316,17 @@ class TrainingOperator:
logger.debug("Setting up Apex.")
self._amp = amp
self._original_models, self._optimizers = self._configure_amp(
self._amp, self._original_models, self._optimizers)
self._amp,
self._original_models,
self._optimizers,
apex_args=apex_args)
if self._wrap_ddp:
logging.debug("Setting up DDP for models.")
self._models = self._configure_ddp(
models=self._original_models, device_ids=self.device_ids)
models=self._original_models,
device_ids=self.device_ids,
ddp_args=ddp_args)
else:
self._models = self._original_models