mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 05:41:51 +08:00
[SGD] Better support for custom DDP (#11771)
This commit is contained in:
@@ -766,6 +766,22 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
def test_custom_ddp_args(ray_start_2_cpus):
|
||||
class TestTrainingOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
model = model_creator(config)
|
||||
optimizer = optimizer_creator(model, config)
|
||||
train_loader, val_loader = data_creator(config)
|
||||
|
||||
self.model, self.optimizer, = \
|
||||
self.register(
|
||||
models=model, optimizers=optimizer, ddp_args={
|
||||
"find_unused_parameters": True})
|
||||
assert self.model.find_unused_parameters
|
||||
|
||||
TorchTrainer(training_operator_cls=TestTrainingOperator, num_workers=2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_multi_input_model(ray_start_2_cpus, use_local):
|
||||
def model_creator(config):
|
||||
|
||||
@@ -87,7 +87,6 @@ class DistributedTorchRunner(TorchRunner):
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
wrap_ddp=self.wrap_ddp,
|
||||
add_dist_sampler=self.add_dist_sampler,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
@@ -142,7 +142,6 @@ def main():
|
||||
training_operator_cls=CustomTrainingOperator,
|
||||
use_tqdm=True,
|
||||
use_fp16=args.amp,
|
||||
apex_args={"opt_level": "O1"},
|
||||
config={
|
||||
"args": args,
|
||||
BATCH_SIZE: args.batch_size
|
||||
|
||||
@@ -160,8 +160,9 @@ class TransformerOperator(TrainingOperator):
|
||||
self.model, self.optimizer = self.register(
|
||||
models=model,
|
||||
optimizers=optimizer,
|
||||
train_loader=train_loader,
|
||||
validation_loader=None)
|
||||
apex_args={"opt_level": args.fp16_opt_level})
|
||||
|
||||
self.register_data(train_loader=train_loader, validation_loader=None)
|
||||
|
||||
self.train_data_len = len(self.train_loader)
|
||||
self._warmup_scheduler = get_linear_schedule_with_warmup(
|
||||
@@ -331,7 +332,6 @@ def main():
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TransformerOperator,
|
||||
use_fp16=args.fp16,
|
||||
apex_args={"opt_level": args.fp16_opt_level},
|
||||
num_workers=args.num_workers,
|
||||
use_gpu=use_gpu,
|
||||
use_tqdm=True,
|
||||
|
||||
@@ -27,16 +27,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
TrainerOptimizersMixin):
|
||||
def _configure_amp(self, amp, models, optimizers):
|
||||
def _configure_amp(self, amp, models, optimizers, apex_args=None):
|
||||
assert len(models) == 1
|
||||
model = models[0]
|
||||
assert isinstance(model, ptl.LightningModule)
|
||||
amp_level = self._apex_args.get("opt_level", "O2")
|
||||
model, optimizers = model.configure_apex(
|
||||
amp, model, optimizers, amp_level=amp_level)
|
||||
amp, model, optimizers, amp_level="O2")
|
||||
return [model], optimizers
|
||||
|
||||
def _configure_ddp(self, models, device_ids):
|
||||
def _configure_ddp(self, models, device_ids, ddp_args=None):
|
||||
assert len(models) == 1
|
||||
model = models[0]
|
||||
assert isinstance(model, ptl.LightningModule)
|
||||
|
||||
@@ -28,7 +28,6 @@ class TorchRunner:
|
||||
serialize_data_creation=True,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq=None):
|
||||
self.training_operator_cls = training_operator_cls
|
||||
self.config = {} if config is None else config
|
||||
@@ -40,7 +39,6 @@ class TorchRunner:
|
||||
self.use_gpu = use_gpu
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
self.apex_args = apex_args or {}
|
||||
if use_fp16 and not amp:
|
||||
raise ImportError(
|
||||
"Please install apex from "
|
||||
@@ -64,7 +62,6 @@ class TorchRunner:
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
def get_iterator(self, training=True):
|
||||
|
||||
@@ -113,10 +113,6 @@ class TorchTrainer:
|
||||
is installed. This is automatically done after the model and
|
||||
optimizers are constructed and will work for multi-model training.
|
||||
Please see https://github.com/NVIDIA/apex for more details.
|
||||
apex_args (dict|None): Dict containing keyword args for amp.initialize.
|
||||
See https://nvidia.github.io/apex/amp.html#module-apex.amp. By
|
||||
default, the models and optimizers are passed in. Consider using
|
||||
"num_losses" if operating over multiple models and optimizers.
|
||||
scheduler_step_freq: "batch", "epoch", "manual", or None. This will
|
||||
determine when ``scheduler.step`` is called. If "batch",
|
||||
``step`` will be called after every optimizer step. If "epoch",
|
||||
@@ -150,7 +146,6 @@ class TorchTrainer:
|
||||
timeout_s=NCCL_TIMEOUT_S,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
add_dist_sampler=True,
|
||||
scheduler_step_freq=None,
|
||||
use_local=False,
|
||||
@@ -164,6 +159,7 @@ class TorchTrainer:
|
||||
loss_creator=None,
|
||||
serialize_data_creation=None,
|
||||
data_loader_args=None,
|
||||
apex_args=None,
|
||||
):
|
||||
if (model_creator or data_creator or optimizer_creator
|
||||
or scheduler_creator or loss_creator):
|
||||
@@ -201,6 +197,12 @@ class TorchTrainer:
|
||||
"config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
|
||||
"batch size to be used across all workers.")
|
||||
|
||||
if apex_args is not None:
|
||||
raise DeprecationWarning(
|
||||
"apex_args is deprecated. Pass in apex_args when calling "
|
||||
"`register` in the `setup` method of your `TrainingOperator` "
|
||||
"instead.")
|
||||
|
||||
if serialize_data_creation is True:
|
||||
if log_once("serialize_data_creation"):
|
||||
logging.warning(
|
||||
@@ -242,10 +244,6 @@ class TorchTrainer:
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
self.use_local = use_local
|
||||
|
||||
if apex_args and not isinstance(apex_args, dict):
|
||||
raise ValueError("apex_args needs to be a dict object.")
|
||||
|
||||
self.apex_args = apex_args
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
|
||||
self._num_failures = 0
|
||||
self._last_resize = float("-inf")
|
||||
@@ -294,7 +292,6 @@ class TorchTrainer:
|
||||
use_fp16=self.use_fp16,
|
||||
use_gpu=self.use_gpu,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
dist_params = dict(
|
||||
|
||||
@@ -125,7 +125,6 @@ class TrainingOperator:
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
wrap_ddp=False,
|
||||
add_dist_sampler=False,
|
||||
scheduler_step_freq=None):
|
||||
@@ -142,7 +141,6 @@ class TrainingOperator:
|
||||
raise ValueError("tqdm must be installed to use tqdm in training.")
|
||||
self._use_tqdm = use_tqdm
|
||||
self.global_step = 0
|
||||
self._apex_args = apex_args if apex_args else {}
|
||||
self._wrap_ddp = wrap_ddp
|
||||
self._add_dist_sampler = add_dist_sampler
|
||||
self._scheduler_step_freq = scheduler_step_freq
|
||||
@@ -154,14 +152,13 @@ class TrainingOperator:
|
||||
"""Passes in the timers from the Runner."""
|
||||
self.timers = timers
|
||||
|
||||
def _configure_amp(self, amp, models, optimizers):
|
||||
models, optimizers = amp.initialize(models, optimizers,
|
||||
**self._apex_args)
|
||||
def _configure_amp(self, amp, models, optimizers, apex_args):
|
||||
models, optimizers = amp.initialize(models, optimizers, **apex_args)
|
||||
return models, optimizers
|
||||
|
||||
def _configure_ddp(self, models, device_ids):
|
||||
def _configure_ddp(self, models, device_ids, ddp_args):
|
||||
return [
|
||||
DistributedDataParallel(model, device_ids=device_ids)
|
||||
DistributedDataParallel(model, device_ids=device_ids, **ddp_args)
|
||||
for model in models
|
||||
]
|
||||
|
||||
@@ -188,7 +185,14 @@ class TrainingOperator:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def register(self, *, models, optimizers, criterion=None, schedulers=None):
|
||||
def register(self,
|
||||
*,
|
||||
models,
|
||||
optimizers,
|
||||
criterion=None,
|
||||
schedulers=None,
|
||||
ddp_args=None,
|
||||
apex_args=None):
|
||||
"""Registers parameters with Ray SGD and sets up training components.
|
||||
|
||||
By calling this method to register your models, optimizers,
|
||||
@@ -200,6 +204,14 @@ class TrainingOperator:
|
||||
If more than one model, optimizer, or scheduler is passed in,
|
||||
you should implement your own custom training loop.
|
||||
|
||||
Calling register will perform the following steps in this order:
|
||||
1. If using GPU, Move model(s) and criterion to the corresponding
|
||||
Cuda device.
|
||||
2. If using fp16, initializes amp with model(s), optimizer(s),
|
||||
and apex_args.
|
||||
3. If using distributed training and wrap_ddp is True,
|
||||
wraps model(s) with DistributedDataParallel.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
@@ -238,11 +250,30 @@ class TrainingOperator:
|
||||
schedulers (torch.optim.lr_scheduler or Iterable[
|
||||
torch.optim.lr_scheduler], optional): A learning rate
|
||||
scheduler or multiple learning rate schedulers.
|
||||
ddp_args (dict|None): Dict containing keyword args for
|
||||
DistributedDataParallel if distributed training is being
|
||||
used. `module` and `device_ids` are automatically passed in,
|
||||
but this dict is useful for passing in other args such as
|
||||
`find_unused_parameters=True`.
|
||||
apex_args (dict|None): Dict containing keyword args for
|
||||
amp.initialize if fp16 is being used. See
|
||||
https://nvidia.github.io/apex/amp.html#module-apex.amp.
|
||||
By default, the models and optimizers are passed in.
|
||||
Consider using "num_losses" if operating over multiple
|
||||
models and optimizers.
|
||||
|
||||
Returns:
|
||||
Tuple of model, optimizer, criterion if not None, and scheduler
|
||||
if not None.
|
||||
"""
|
||||
if ddp_args and not isinstance(ddp_args, dict):
|
||||
raise ValueError("ddp_args needs to be a dict object.")
|
||||
ddp_args = ddp_args if ddp_args else {}
|
||||
|
||||
if apex_args and not isinstance(apex_args, dict):
|
||||
raise ValueError("apex_args needs to be a dict object.")
|
||||
apex_args = apex_args if apex_args else {}
|
||||
|
||||
return_vals = []
|
||||
logger.debug("Registering models.")
|
||||
self._original_models = models
|
||||
@@ -285,12 +316,17 @@ class TrainingOperator:
|
||||
logger.debug("Setting up Apex.")
|
||||
self._amp = amp
|
||||
self._original_models, self._optimizers = self._configure_amp(
|
||||
self._amp, self._original_models, self._optimizers)
|
||||
self._amp,
|
||||
self._original_models,
|
||||
self._optimizers,
|
||||
apex_args=apex_args)
|
||||
|
||||
if self._wrap_ddp:
|
||||
logging.debug("Setting up DDP for models.")
|
||||
self._models = self._configure_ddp(
|
||||
models=self._original_models, device_ids=self.device_ids)
|
||||
models=self._original_models,
|
||||
device_ids=self.device_ids,
|
||||
ddp_args=ddp_args)
|
||||
else:
|
||||
self._models = self._original_models
|
||||
|
||||
|
||||
Reference in New Issue
Block a user