From 92718de40c6ee67e37b25988d7f3c9f72df4dcb4 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Wed, 4 Nov 2020 13:58:51 -0800 Subject: [PATCH] [SGD] Better support for custom DDP (#11771) --- doc/source/raysgd/raysgd_pytorch.rst | 40 ++++++++++--- python/ray/util/sgd/tests/test_torch.py | 16 ++++++ .../sgd/torch/distributed_torch_runner.py | 1 - .../sgd/torch/examples/image_models/train.py | 1 - .../transformers/transformers_example.py | 6 +- python/ray/util/sgd/torch/ptl_operator.py | 7 +-- python/ray/util/sgd/torch/torch_runner.py | 3 - python/ray/util/sgd/torch/torch_trainer.py | 17 +++--- .../ray/util/sgd/torch/training_operator.py | 56 +++++++++++++++---- 9 files changed, 106 insertions(+), 41 deletions(-) diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index bcd8379fd..093a24a79 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -251,7 +251,20 @@ TorchTrainer automatically applies a DistributedDataParallel wrapper to your mod DistributedDataParallel(model, device_ids=self.device_ids) -By setting ``TorchTrainer(wrap_ddp=False)``, you can change the parameters on the DistributedDataParallel wrapper or provide your own wrapper. +You can also pass in additional arguments to DistributedDataParallel by setting the `ddp_args` field in your `TrainingOperator`. + +.. code-block:: python + :emphasize-lines: 6 + + from ray.util.sgd.torch import TrainingOperator + + class CustomOperator(TrainingOperator): + def setup(self, config): + ... + self.model, ... = self.register(..., ddp_args={"find_unused_parameters": True}) + + +If you want to use a custom wrapper for distributed training or if you want to wrap in DistributedDataParallel yourself, you can do so by setting ``TorchTrainer(wrap_ddp=False)``. .. note:: Make sure to register the model before it is wrapped in DistributedDataParallel or a custom wrapper. @@ -402,20 +415,29 @@ You can enable mixed precision training for PyTorch with the ``use_fp16`` flag. ``Apex`` is a Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. When ``use_fp16=True``, you should not manually cast your model or data to ``.half()``. The flag informs the Trainer to call ``amp.initialize`` on the created models and optimizers and optimize using the scaled loss: ``amp.scale_loss(loss, optimizer)``. -To specify particular parameters for ``amp.initialize``, you can use the ``apex_args`` field for the TorchTrainer constructor. Valid arguments can be found on the `Apex documentation `_: +To specify particular parameters for ``amp.initialize``, you can use the ``apex_args`` field when calling `self.register` in your `TrainingOperator`. Valid arguments can be found on the `Apex documentation `_: .. code-block:: python - :emphasize-lines: 5-10 + :emphasize-lines: 8-12 + + class MyTrainingOperator(TrainingOperator): + def setup(self, config): + models = [...] + optimizers = [...] + model, optimizer = self.register( + models=models, + optimizers=optimizers, + apex_args={ + opt_level="03", + num_losses=2, + verbosity=0 + } + ) trainer = TorchTrainer( training_operator_cls=MyTrainingOperator, num_workers=4, - use_fp16=True, - apex_args={ - opt_level="O3", - num_losses=2, - verbosity=0 - } + use_fp16=True ) Note that if implementing custom training (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually. diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index a34ac3534..c84d3f7f5 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -766,6 +766,22 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811 trainer2.shutdown() +def test_custom_ddp_args(ray_start_2_cpus): + class TestTrainingOperator(TrainingOperator): + def setup(self, config): + model = model_creator(config) + optimizer = optimizer_creator(model, config) + train_loader, val_loader = data_creator(config) + + self.model, self.optimizer, = \ + self.register( + models=model, optimizers=optimizer, ddp_args={ + "find_unused_parameters": True}) + assert self.model.find_unused_parameters + + TorchTrainer(training_operator_cls=TestTrainingOperator, num_workers=2) + + @pytest.mark.parametrize("use_local", [True, False]) def test_multi_input_model(ray_start_2_cpus, use_local): def model_creator(config): diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index fd338c9cd..e09045072 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -87,7 +87,6 @@ class DistributedTorchRunner(TorchRunner): use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, - apex_args=self.apex_args, wrap_ddp=self.wrap_ddp, add_dist_sampler=self.add_dist_sampler, scheduler_step_freq=self.scheduler_step_freq) diff --git a/python/ray/util/sgd/torch/examples/image_models/train.py b/python/ray/util/sgd/torch/examples/image_models/train.py index 48fe4579c..5e0574b66 100644 --- a/python/ray/util/sgd/torch/examples/image_models/train.py +++ b/python/ray/util/sgd/torch/examples/image_models/train.py @@ -142,7 +142,6 @@ def main(): training_operator_cls=CustomTrainingOperator, use_tqdm=True, use_fp16=args.amp, - apex_args={"opt_level": "O1"}, config={ "args": args, BATCH_SIZE: args.batch_size diff --git a/python/ray/util/sgd/torch/examples/transformers/transformers_example.py b/python/ray/util/sgd/torch/examples/transformers/transformers_example.py index 1b3b0826f..c118a57a5 100644 --- a/python/ray/util/sgd/torch/examples/transformers/transformers_example.py +++ b/python/ray/util/sgd/torch/examples/transformers/transformers_example.py @@ -160,8 +160,9 @@ class TransformerOperator(TrainingOperator): self.model, self.optimizer = self.register( models=model, optimizers=optimizer, - train_loader=train_loader, - validation_loader=None) + apex_args={"opt_level": args.fp16_opt_level}) + + self.register_data(train_loader=train_loader, validation_loader=None) self.train_data_len = len(self.train_loader) self._warmup_scheduler = get_linear_schedule_with_warmup( @@ -331,7 +332,6 @@ def main(): trainer = TorchTrainer( training_operator_cls=TransformerOperator, use_fp16=args.fp16, - apex_args={"opt_level": args.fp16_opt_level}, num_workers=args.num_workers, use_gpu=use_gpu, use_tqdm=True, diff --git a/python/ray/util/sgd/torch/ptl_operator.py b/python/ray/util/sgd/torch/ptl_operator.py index 3403f2559..46d3a09d0 100644 --- a/python/ray/util/sgd/torch/ptl_operator.py +++ b/python/ray/util/sgd/torch/ptl_operator.py @@ -27,16 +27,15 @@ logger = logging.getLogger(__name__) class LightningOperator(TrainingOperator, TrainerModelHooksMixin, TrainerOptimizersMixin): - def _configure_amp(self, amp, models, optimizers): + def _configure_amp(self, amp, models, optimizers, apex_args=None): assert len(models) == 1 model = models[0] assert isinstance(model, ptl.LightningModule) - amp_level = self._apex_args.get("opt_level", "O2") model, optimizers = model.configure_apex( - amp, model, optimizers, amp_level=amp_level) + amp, model, optimizers, amp_level="O2") return [model], optimizers - def _configure_ddp(self, models, device_ids): + def _configure_ddp(self, models, device_ids, ddp_args=None): assert len(models) == 1 model = models[0] assert isinstance(model, ptl.LightningModule) diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index f7b3872c3..120b1a7ea 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -28,7 +28,6 @@ class TorchRunner: serialize_data_creation=True, use_fp16=False, use_tqdm=False, - apex_args=None, scheduler_step_freq=None): self.training_operator_cls = training_operator_cls self.config = {} if config is None else config @@ -40,7 +39,6 @@ class TorchRunner: self.use_gpu = use_gpu self.use_fp16 = use_fp16 self.use_tqdm = use_tqdm - self.apex_args = apex_args or {} if use_fp16 and not amp: raise ImportError( "Please install apex from " @@ -64,7 +62,6 @@ class TorchRunner: use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, - apex_args=self.apex_args, scheduler_step_freq=self.scheduler_step_freq) def get_iterator(self, training=True): diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index e49d0c3aa..0602222c5 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -113,10 +113,6 @@ class TorchTrainer: is installed. This is automatically done after the model and optimizers are constructed and will work for multi-model training. Please see https://github.com/NVIDIA/apex for more details. - apex_args (dict|None): Dict containing keyword args for amp.initialize. - See https://nvidia.github.io/apex/amp.html#module-apex.amp. By - default, the models and optimizers are passed in. Consider using - "num_losses" if operating over multiple models and optimizers. scheduler_step_freq: "batch", "epoch", "manual", or None. This will determine when ``scheduler.step`` is called. If "batch", ``step`` will be called after every optimizer step. If "epoch", @@ -150,7 +146,6 @@ class TorchTrainer: timeout_s=NCCL_TIMEOUT_S, use_fp16=False, use_tqdm=False, - apex_args=None, add_dist_sampler=True, scheduler_step_freq=None, use_local=False, @@ -164,6 +159,7 @@ class TorchTrainer: loss_creator=None, serialize_data_creation=None, data_loader_args=None, + apex_args=None, ): if (model_creator or data_creator or optimizer_creator or scheduler_creator or loss_creator): @@ -201,6 +197,12 @@ class TorchTrainer: "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a " "batch size to be used across all workers.") + if apex_args is not None: + raise DeprecationWarning( + "apex_args is deprecated. Pass in apex_args when calling " + "`register` in the `setup` method of your `TrainingOperator` " + "instead.") + if serialize_data_creation is True: if log_once("serialize_data_creation"): logging.warning( @@ -242,10 +244,6 @@ class TorchTrainer: self.add_dist_sampler = add_dist_sampler self.use_local = use_local - if apex_args and not isinstance(apex_args, dict): - raise ValueError("apex_args needs to be a dict object.") - - self.apex_args = apex_args self.temp_dir = tempfile.mkdtemp(prefix="raysgd") self._num_failures = 0 self._last_resize = float("-inf") @@ -294,7 +292,6 @@ class TorchTrainer: use_fp16=self.use_fp16, use_gpu=self.use_gpu, use_tqdm=self.use_tqdm, - apex_args=self.apex_args, scheduler_step_freq=self.scheduler_step_freq) dist_params = dict( diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 8e40d37f7..b8ca0e785 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -125,7 +125,6 @@ class TrainingOperator: use_gpu=False, use_fp16=False, use_tqdm=False, - apex_args=None, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None): @@ -142,7 +141,6 @@ class TrainingOperator: raise ValueError("tqdm must be installed to use tqdm in training.") self._use_tqdm = use_tqdm self.global_step = 0 - self._apex_args = apex_args if apex_args else {} self._wrap_ddp = wrap_ddp self._add_dist_sampler = add_dist_sampler self._scheduler_step_freq = scheduler_step_freq @@ -154,14 +152,13 @@ class TrainingOperator: """Passes in the timers from the Runner.""" self.timers = timers - def _configure_amp(self, amp, models, optimizers): - models, optimizers = amp.initialize(models, optimizers, - **self._apex_args) + def _configure_amp(self, amp, models, optimizers, apex_args): + models, optimizers = amp.initialize(models, optimizers, **apex_args) return models, optimizers - def _configure_ddp(self, models, device_ids): + def _configure_ddp(self, models, device_ids, ddp_args): return [ - DistributedDataParallel(model, device_ids=device_ids) + DistributedDataParallel(model, device_ids=device_ids, **ddp_args) for model in models ] @@ -188,7 +185,14 @@ class TrainingOperator: """ raise NotImplementedError - def register(self, *, models, optimizers, criterion=None, schedulers=None): + def register(self, + *, + models, + optimizers, + criterion=None, + schedulers=None, + ddp_args=None, + apex_args=None): """Registers parameters with Ray SGD and sets up training components. By calling this method to register your models, optimizers, @@ -200,6 +204,14 @@ class TrainingOperator: If more than one model, optimizer, or scheduler is passed in, you should implement your own custom training loop. + Calling register will perform the following steps in this order: + 1. If using GPU, Move model(s) and criterion to the corresponding + Cuda device. + 2. If using fp16, initializes amp with model(s), optimizer(s), + and apex_args. + 3. If using distributed training and wrap_ddp is True, + wraps model(s) with DistributedDataParallel. + .. code-block:: python class MyTrainingOperator(TrainingOperator): @@ -238,11 +250,30 @@ class TrainingOperator: schedulers (torch.optim.lr_scheduler or Iterable[ torch.optim.lr_scheduler], optional): A learning rate scheduler or multiple learning rate schedulers. + ddp_args (dict|None): Dict containing keyword args for + DistributedDataParallel if distributed training is being + used. `module` and `device_ids` are automatically passed in, + but this dict is useful for passing in other args such as + `find_unused_parameters=True`. + apex_args (dict|None): Dict containing keyword args for + amp.initialize if fp16 is being used. See + https://nvidia.github.io/apex/amp.html#module-apex.amp. + By default, the models and optimizers are passed in. + Consider using "num_losses" if operating over multiple + models and optimizers. Returns: Tuple of model, optimizer, criterion if not None, and scheduler if not None. """ + if ddp_args and not isinstance(ddp_args, dict): + raise ValueError("ddp_args needs to be a dict object.") + ddp_args = ddp_args if ddp_args else {} + + if apex_args and not isinstance(apex_args, dict): + raise ValueError("apex_args needs to be a dict object.") + apex_args = apex_args if apex_args else {} + return_vals = [] logger.debug("Registering models.") self._original_models = models @@ -285,12 +316,17 @@ class TrainingOperator: logger.debug("Setting up Apex.") self._amp = amp self._original_models, self._optimizers = self._configure_amp( - self._amp, self._original_models, self._optimizers) + self._amp, + self._original_models, + self._optimizers, + apex_args=apex_args) if self._wrap_ddp: logging.debug("Setting up DDP for models.") self._models = self._configure_ddp( - models=self._original_models, device_ids=self.device_ids) + models=self._original_models, + device_ids=self.device_ids, + ddp_args=ddp_args) else: self._models = self._original_models