From 415be78cc0d1275a29d0ceda550d0d7a7a5224ea Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 8 Sep 2020 15:19:40 -0700 Subject: [PATCH] [RaySGD] Simplify Builder Process (#10321) Co-authored-by: Richard Liaw --- doc/source/raysgd/raysgd.rst | 34 +- doc/source/raysgd/raysgd_pytorch.rst | 279 +++---- python/ray/util/sgd/tests/test_torch.py | 300 ++++---- .../ray/util/sgd/tests/test_torch_runner.py | 114 ++- python/ray/util/sgd/torch/__init__.py | 9 +- .../sgd/torch/distributed_torch_runner.py | 29 +- .../torch/examples/benchmarks/benchmark.py | 23 +- .../torch/examples/cifar_pytorch_example.py | 99 ++- .../sgd/torch/examples/cifar_pytorch_pbt.py | 93 ++- python/ray/util/sgd/torch/examples/dcgan.py | 93 +-- .../sgd/torch/examples/image_models/train.py | 9 +- .../torch/examples/raysgd_torch_signatures.py | 64 +- .../segmentation/train_segmentation.py | 131 ++-- .../util/sgd/torch/examples/train_example.py | 11 +- .../transformers/transformers_example.py | 132 ++-- .../sgd/torch/examples/transformers/utils.py | 1 + .../util/sgd/torch/examples/tune_example.py | 15 +- python/ray/util/sgd/torch/torch_runner.py | 231 +++--- python/ray/util/sgd/torch/torch_trainer.py | 174 ++--- .../ray/util/sgd/torch/training_operator.py | 708 ++++++++++++++---- 20 files changed, 1436 insertions(+), 1113 deletions(-) diff --git a/doc/source/raysgd/raysgd.rst b/doc/source/raysgd/raysgd.rst index e7b8fdf50..5ab6503e4 100644 --- a/doc/source/raysgd/raysgd.rst +++ b/doc/source/raysgd/raysgd.rst @@ -26,33 +26,41 @@ You can start a ``TorchTrainer`` with the following: import ray from ray.util.sgd import TorchTrainer + from ray.util.sgd.torch import TrainingOperator from ray.util.sgd.torch.examples.train_example import LinearDataset import torch from torch.utils.data import DataLoader + class CustomTrainingOperator(TrainingOperator): + def setup(self, config): + # Load data. 
+ train_loader = DataLoader(LinearDataset(2, 5), config["batch_size"]) + val_loader = DataLoader(LinearDataset(2, 5), config["batch_size"]) - def model_creator(config): - return torch.nn.Linear(1, 1) + # Create model. + model = torch.nn.Linear(1, 1) + # Create optimizer. + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) - def optimizer_creator(model, config): - """Returns optimizer.""" - return torch.optim.SGD(model.parameters(), lr=1e-2) + # Create loss. + loss = torch.nn.MSELoss() + # Register model, optimizer, and loss. + self.model, self.optimizer, self.criterion = self.register( + models=model, + optimizers=optimizer, + criterion=loss) + + # Register data loaders. + self.register_data(train_loader=train_loader, validation_loader=val_loader) - def data_creator(config): - train_loader = DataLoader(LinearDataset(2, 5), config["batch_size"]) - val_loader = DataLoader(LinearDataset(2, 5), config["batch_size"]) - return train_loader, val_loader ray.init() trainer1 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=torch.nn.MSELoss, + training_operator_cls=CustomTrainingOperator, num_workers=2, use_gpu=False, config={"batch_size": 64}) diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 138ef3791..f6ae551e1 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -15,109 +15,25 @@ For end to end examples leveraging RaySGD TorchTrainer, jump to :ref:`raysgd-tor .. contents:: :local: +Basic Usage +----------- + Setting up training -------------------- +~~~~~~~~~~~~~~~~~~~ .. tip:: If you want to leverage multi-node data parallel training with PyTorch while using RayTune *without* restructuring your code, check out the :ref:`Tune PyTorch user guide ` and Tune's :ref:`distributed pytorch integrations `. -The ``TorchTrainer`` can be constructed with functions that wrap components of the training script. 
Specifically, it requires constructors for the Model, Data, Optimizer, Loss, and ``lr_scheduler`` to create replicated copies across different devices and machines. +The :ref:`ref-torch-trainer` can be constructed from a custom :ref:`ref-torch-operator` subclass that defines training components like the model, data, optimizer, loss, and ``lr_scheduler``. These components are all automatically replicated across different machines and devices so that training can be executed in parallel. + +.. warning:: You should call ``self.register(...)`` and ``self.register_data(...)`` inside the ``setup`` method of your custom ``TrainingOperator`` to register the necessary training components with Ray SGD. + +.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py + :language: python + :start-after: __torch_operator_start__ + :end-before: __torch_operator_end__ Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor. One of the replicas will be on the main process, which can simplify the debugging and logging experience. -.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py - :language: python - :start-after: __torch_trainer_start__ - :end-before: __torch_trainer_end__ - -The below section covers the expected signatures of creator functions. Jump to :ref:`starting-torch-trainer`. - -Model Creator -~~~~~~~~~~~~~ - -This is the signature needed for ``TorchTrainer(model_creator=...)``. - -.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py - :language: python - :start-after: __torch_model_start__ - :end-before: __torch_model_end__ - - -Optimizer Creator -~~~~~~~~~~~~~~~~~ - -This is the signature needed for ``TorchTrainer(optimizer_creator=...)``. - -.. 
literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py - :language: python - :start-after: __torch_optimizer_start__ - :end-before: __torch_optimizer_end__ - - -Data Creator -~~~~~~~~~~~~ - -This is the signature needed for ``TorchTrainer(data_creator=...)``. - -.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py - :language: python - :start-after: __torch_data_start__ - :end-before: __torch_data_end__ - - -.. tip:: Setting the batch size: Using a provided ``ray.util.sgd.utils.BATCH_SIZE`` variable, you can provide a global batch size that will be divided among all workers automatically. - -.. code-block:: python - - from torch.utils.data import DataLoader - from ray.util.sgd.utils import BATCH_SIZE - - def data_creator(config): - # config[BATCH_SIZE] == provided BATCH_SIZE // num_workers - train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2, 5) - train_loader = DataLoader(train_dataset, batch_size=config[BATCH_SIZE]) - val_loader = DataLoader(val_dataset, batch_size=config[BATCH_SIZE]) - return train_loader, val_loader - - trainer = Trainer( - model_creator=model_creator, - optimizer_creator=optimizer_creator, - data_creator=batch_data_creator - config={BATCH_SIZE: 1024}, - num_workers=128 - ) - - # Each worker will process 1024 // 128 samples per batch - stats = Trainer.train() - - -Loss Creator -~~~~~~~~~~~~ - -This is the signature needed for ``TorchTrainer(loss_creator=...)``. - -.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py - :language: python - :start-after: __torch_loss_start__ - :end-before: __torch_loss_end__ - - -Scheduler Creator -~~~~~~~~~~~~~~~~~ - -Optionally, you can provide a creator function for the learning rate scheduler. This is the signature needed -for ``TorchTrainer(scheduler_creator=...)``. - -.. 
literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py - :language: python - :start-after: __torch_scheduler_start__ - :end-before: __torch_scheduler_end__ - - -.. _starting-torch-trainer: - -Putting things together -~~~~~~~~~~~~~~~~~~~~~~~ - Before instantiating the trainer, first start or connect to a Ray cluster: .. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py @@ -125,7 +41,7 @@ Before instantiating the trainer, first start or connect to a Ray cluster: :start-after: __torch_ray_start__ :end-before: __torch_ray_end__ -Instantiate the trainer object: +And then you can instantiate the trainer object using your custom ``TrainingOperator``: .. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py :language: python @@ -135,19 +51,16 @@ Instantiate the trainer object: You can also set the number of workers and whether the workers will use GPUs: .. code-block:: python - :emphasize-lines: 8,9 + :emphasize-lines: 4,5 trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, - scheduler_creator=scheduler_creator, + training_operator_cls=MyTrainingOperator, config={"lr": 0.001}, num_workers=100, use_gpu=True) - +Executing Training +~~~~~~~~~~~~~~~~~~ Now that the trainer is constructed, here's how to train the model. .. code-block:: python @@ -157,7 +70,34 @@ Now that the trainer is constructed, here's how to train the model. val_metrics = trainer.validate() -Each ``train`` call makes one pass over the training data (trains on 1 epoch), and each ``validate`` call runs the model on the validation data passed in by the ``data_creator``. +Each ``train`` call makes one pass over the training data (trains on 1 epoch), and each ``validate`` call runs the model on the validation data. 
+Override training and validation methods in your Training Operator (:ref:`raysgd-custom-training`) to calculate custom metrics or customize the training/validation process. + +.. tip:: Setting the batch size: Using a provided ``ray.util.sgd.utils.BATCH_SIZE`` variable, you can provide a global batch size that will be divided among all workers automatically. + +.. code-block:: python + + from torch.utils.data import DataLoader + from ray.util.sgd.utils import BATCH_SIZE + + class MyTrainingOperator(TrainingOperator): + def setup(self, config): + ... + # Create data loaders. + # config[BATCH_SIZE] == provided BATCH_SIZE // num_workers + train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2, 5) + train_loader = DataLoader(train_dataset, batch_size=config[BATCH_SIZE]) + val_loader = DataLoader(val_dataset, batch_size=config[BATCH_SIZE]) + ... + trainer = TorchTrainer( + training_operator_cls=MyTrainingOperator, + config={BATCH_SIZE: 1024}, + num_workers=128 + ) + + # Each worker will process 1024 // 128 samples per batch + stats = trainer.train() + You can also obtain profiling information: @@ -177,8 +117,6 @@ You can also obtain profiling information: mean_grad_s: 0.00016553401947021483 train_epoch_s: 0.023712158203125 -Provide a custom training operator (:ref:`raysgd-custom-training`) to calculate custom metrics or customize the training/validation process. - After training, you may want to reappropriate the Ray cluster. To release Ray resources obtained by the Trainer: .. code-block:: python @@ -189,20 +127,20 @@ After training, you may want to reappropriate the Ray cluster. To release Ray re See the documentation on the TorchTrainer here: :ref:`ref-torch-trainer`. +See the documentation on the TrainingOperator here: :ref:`ref-torch-operator`. .. 
_raysgd-custom-training: -Custom Training and Validation (Operators) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Custom Training and Validation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``TorchTrainer`` allows you to run a custom training and validation loops in parallel on each worker, providing a flexible interface similar to using PyTorch natively. -This is done via the :ref:`ref-torch-operator` interface. +If you would like to implement custom training and validation logic, you can do so by overriding the appropriate methods inside your :ref:`ref-torch-operator` subclass. For both training and validation, there are two granularities that you can provide customization - per epoch and per batch. These correspond to ``train_batch``, -``train_epoch``, ``validate``, and ``validate_batch``. Other useful methods to override include ``setup``, ``save`` and ``restore``. You can use these -to manage state (like a classifier neural network for calculating inception score, or a heavy tokenizer). +``train_epoch``, ``validate``, and ``validate_batch``. Other useful methods to override include ``state_dict`` and ``load_state_dict``. You can use these +to save and load additional state for your custom ``TrainingOperator``. -Providing a custom operator is necessary if creator functions return multiple models, optimizers, or schedulers. +Custom training is necessary if you are using multiple models, optimizers, or schedulers. Below is a partial example of a custom ``TrainingOperator`` that provides a ``train_batch`` implementation for a Deep Convolutional GAN. @@ -213,13 +151,17 @@ Below is a partial example of a custom ``TrainingOperator`` that provides a ``tr class GANOperator(TrainingOperator): def setup(self, config): - """Custom setup for this operator. + """Setup for this operator. + + This is where you define the training state and register it with Ray SGD. Args: config (dict): Custom configuration value to be passed to all creator and operator constructors. 
Same as ``self.config``. """ - pass + ... + self.models, self.optimizers, ... = self.register(...) + self.register_data(...) def train_batch(self, batch, batch_info): """Trains on one batch of data from the data creator. @@ -287,10 +229,6 @@ Below is a partial example of a custom ``TrainingOperator`` that provides a ``tr } trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.BCELoss, training_operator_cls=GANOperator, num_workers=num_workers, config=config, @@ -313,15 +251,19 @@ TorchTrainer automatically applies a DistributedDataParallel wrapper to your mod DistributedDataParallel(model, device_ids=self.device_ids) -By setting ``TorchTrainer(wrap_ddp=False)`` and providing your own custom training operator, you can change the parameters on the DistributedDataParallel wrapper or provide your own wrapper. +By setting ``TorchTrainer(wrap_ddp=False)``, you can change the parameters on the DistributedDataParallel wrapper or provide your own wrapper. + +.. note:: Make sure to register the model before it is wrapped in DistributedDataParallel or a custom wrapper. .. code-block:: python - :emphasize-lines: 20 + :emphasize-lines: 19 from ray.util.sgd.torch import TrainingOperator class CustomOperator(TrainingOperator): def setup(self, config): + ... + self.model, ... = self.register(...) self.new_model = CustomDataParallel(self.model, device_ids=self.device_ids) @@ -331,14 +273,23 @@ By setting ``TorchTrainer(wrap_ddp=False)`` and providing your own custom traini return {"loss": loss} trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, training_operator_cls=CustomOperator, num_workers=2, use_gpu=True - wrap_ddp=False, - ) + wrap_ddp=False) + +.. _backwards-compat: + +Backwards Compatibility +~~~~~~~~~~~~~~~~~~~~~~~ +In previous versions of Ray, *creator functions* (``model_creator``, ``optimizer_creator``, etc.) 
were necessary to set up the training components. +These creator functions are no longer used and instead training component setup should be specified inside the ``setup`` method of a ``TrainingOperator`` subclass. +However, if you have these creator functions already and do not want to change your code, you can easily use these creator functions to create a custom ``TrainingOperator``. + +.. literalinclude:: ../../../python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py + :language: python + :start-after: __backwards_compat__start + :end-before: __backwards_compat_end Initialization Functions ------------------------ @@ -355,10 +306,7 @@ Use the ``initialization_hook`` parameter to initialize state on each worker pro os.environ["NCCL_DEBUG"] = "INFO" trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, + training_operator_cls=MyTrainingOperator, initialization_hook=initialization_hook, config={"lr": 0.001} num_workers=100, @@ -370,6 +318,8 @@ Save and Load If you want to save or reload the training procedure, you can use ``trainer.save`` and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load`` calls. This should work across a distributed cluster even without a NFS because it takes advantage of Ray's distributed object store. +.. tip:: Make sure to override the ``state_dict`` and ``load_state_dict`` methods in your custom TrainingOperator if necessary. + .. 
code-block:: python checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint") @@ -378,10 +328,7 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load`` trainer_1.shutdown() trainer_2 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, + training_operator_cls=MyTrainingOperator, num_workers=num_workers) trainer_2.load(checkpoint_path) @@ -445,16 +392,12 @@ Mixed Precision (FP16) Training You can enable mixed precision training for PyTorch with the ``use_fp16`` flag. This automatically converts the model(s) and optimizer(s) to train using mixed-precision. This requires NVIDIA ``Apex``, which can be installed from `the NVIDIA/Apex repository `_: .. code-block:: python - :emphasize-lines: 7 + :emphasize-lines: 4 trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, + training_operator_cls=MyTrainingOperator, num_workers=4, - use_fp16=True - ) + use_fp16=True) ``Apex`` is a Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. When ``use_fp16=True``, you should not manually cast your model or data to ``.half()``. The flag informs the Trainer to call ``amp.initialize`` on the created models and optimizers and optimize using the scaled loss: ``amp.scale_loss(loss, optimizer)``. @@ -462,13 +405,10 @@ you should not manually cast your model or data to ``.half()``. The flag informs To specify particular parameters for ``amp.initialize``, you can use the ``apex_args`` field for the TorchTrainer constructor. Valid arguments can be found on the `Apex documentation `_: .. 
code-block:: python - :emphasize-lines: 7-12 + :emphasize-lines: 5-10 trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, + training_operator_cls=MyTrainingOperator, num_workers=4, use_fp16=True, apex_args={ @@ -478,7 +418,7 @@ To specify particular parameters for ``amp.initialize``, you can use the ``apex_ } ) -Note that if using a custom training operator (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually. +Note that if implementing custom training (:ref:`raysgd-custom-training`), you will need to manage loss scaling manually. Distributed Multi-node Training @@ -499,10 +439,7 @@ After connecting, you can scale up the number of workers seamlessly across multi .. code-block:: python trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, + training_operator_cls=MyTrainingOperator, num_workers=100 ) trainer.train() @@ -532,7 +469,6 @@ Note that we assume the Trainer itself is not on a pre-emptible node. To allow t Advanced: Hyperparameter Tuning ------------------------------- - ``TorchTrainer`` naturally integrates with Tune via the ``BaseTorchTrainable`` interface. Without changing any arguments, you can call ``TorchTrainer.as_trainable(model_creator...)`` to create a Tune-compatible class. See the documentation (:ref:`BaseTorchTrainable-doc`). .. literalinclude:: ../../../python/ray/util/sgd/torch/examples/tune_example.py @@ -546,7 +482,7 @@ You can see the `Tune example script `_ for an end-to-end example. @@ -587,6 +523,22 @@ You can see the `DCGAN script `_. 
-**My creator functions download data, and I don't want multiple processes downloading to the same path at once.** +**My setup function downloads data, and I don't want multiple processes downloading to the same path at once.** -Use ``filelock`` within the creator functions to create locks for critical regions. For example: +Use ``FileLock`` to create locks for critical regions. For example: .. code-block:: python diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index 02b823729..f7493bc91 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -11,8 +11,8 @@ from torch.utils.data import DataLoader import ray from ray import tune from ray.util.sgd.torch import TorchTrainer -from ray.util.sgd.torch.training_operator import (_TestingOperator, - _TestMetricsOperator) +from ray.util.sgd.torch.training_operator import ( + get_test_operator, get_test_metrics_operator, TrainingOperator) from ray.util.sgd.torch.constants import SCHEDULER_STEP from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE) @@ -44,13 +44,12 @@ def ray_start_4_cpus(): dist.destroy_process_group() +Operator = TrainingOperator.from_creators( + model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss) + + def test_single_step(ray_start_2_cpus): # noqa: F811 - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - num_workers=1) + trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1) metrics = trainer.train(num_steps=1) assert metrics[BATCH_COUNT] == 1 @@ -59,49 +58,9 @@ def test_single_step(ray_start_2_cpus): # noqa: F811 trainer.shutdown() -def test_resize(ray_start_2_cpus): # noqa: F811 - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda 
config: nn.MSELoss(), - num_workers=1) - trainer.train(num_steps=1) - trainer.max_replicas = 2 - results = trainer.train(num_steps=1, reduce_results=False) - assert len(results) == 2 - - -def test_non_serialized_data(ray_start_2_cpus): # noqa: F811 - duration = 10 - - def slow_data(func): - def slowed_func(*args, **kwargs): - time.sleep(duration) - return func(*args, **kwargs) - - return slowed_func - - start = time.time() - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=slow_data(data_creator), - optimizer_creator=optimizer_creator, - serialize_data_creation=False, - loss_creator=lambda config: nn.MSELoss(), - num_workers=2) - elapsed = time.time() - start - assert elapsed < duration * 2 - trainer.shutdown() - - def test_dead_trainer(ray_start_2_cpus): # noqa: F811 - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - num_workers=2) + TestOperator = get_test_operator(Operator) + trainer = TorchTrainer(training_operator_cls=TestOperator, num_workers=2) trainer.train(num_steps=1) trainer.shutdown() with pytest.raises(RuntimeError): @@ -111,11 +70,7 @@ def test_dead_trainer(ray_start_2_cpus): # noqa: F811 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) def test_train(ray_start_2_cpus, num_workers): # noqa: F811 trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - num_workers=num_workers) + training_operator_cls=Operator, num_workers=num_workers) for i in range(3): train_loss1 = trainer.train()["train_loss"] validation_loss1 = trainer.validate()["val_loss"] @@ -165,22 +120,26 @@ def test_multi_model(ray_start_2_cpus, num_workers): iterator=iter(data)) return result - def multi_model_creator(config): - return nn.Linear(1, 1), nn.Linear(1, 1) + class MultiModelOperator(TrainingOperator): 
+ def setup(self, config): + models = nn.Linear(1, 1), nn.Linear(1, 1) + opts = [ + torch.optim.SGD(model.parameters(), lr=0.0001) + for model in models + ] + loss = nn.MSELoss() + train_dataloader, val_dataloader = data_creator(config) + self.models, self.optimizers, self.criterion = self.register( + models=models, optimizers=opts, criterion=loss) + self.register_data( + train_loader=train_dataloader, + validation_loader=val_dataloader) - def multi_optimizer_creator(models, config): - opts = [ - torch.optim.SGD(model.parameters(), lr=0.0001) for model in models - ] - return opts[0], opts[1] + TestOperator = get_test_operator(MultiModelOperator) trainer1 = TorchTrainer( - model_creator=multi_model_creator, - data_creator=data_creator, - optimizer_creator=multi_optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), config={"custom_func": train_epoch}, - training_operator_cls=_TestingOperator, + training_operator_cls=TestOperator, num_workers=num_workers) trainer1.train() state = trainer1.state_dict() @@ -190,12 +149,8 @@ def test_multi_model(ray_start_2_cpus, num_workers): trainer1.shutdown() trainer2 = TorchTrainer( - model_creator=multi_model_creator, - data_creator=data_creator, - optimizer_creator=multi_optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), config={"custom_func": train_epoch}, - training_operator_cls=_TestingOperator, + training_operator_cls=TestOperator, num_workers=num_workers) trainer2.load_state_dict(state) @@ -252,17 +207,29 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811 ] return schedulers[0] if len(schedulers) == 1 else schedulers + class MultiModelOperator(TrainingOperator): + def setup(self, config): + models = multi_model_creator(config) + optimizers = multi_optimizer_creator(models, config) + schedulers = multi_scheduler_creator(optimizers, config) + train_loader, val_loader = data_creator(config) + loss = nn.MSELoss() + + self.models, self.optimizers, self.criterion, self.schedulers = \ + 
self.register(models=models, optimizers=optimizers, + schedulers=schedulers, + criterion=loss) + self.register_data( + train_loader=train_loader, validation_loader=val_loader) + + TestOperator = get_test_operator(MultiModelOperator) + for model_count in range(1, 3): for optimizer_count in range(1, 3): for scheduler_count in range(1, 3): trainer = TorchTrainer( - model_creator=multi_model_creator, - data_creator=data_creator, - optimizer_creator=multi_optimizer_creator, - loss_creator=nn.MSELoss, - scheduler_creator=multi_scheduler_creator, scheduler_step_freq="epoch", - training_operator_cls=_TestingOperator, + training_operator_cls=TestOperator, num_workers=num_workers, config={ "models": model_count, @@ -284,24 +251,31 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 return torch.optim.lr_scheduler.StepLR( optimizer, step_size=30, gamma=0.1) + class TestTrainingOperator(TrainingOperator): + def setup(self, config): + model = model_creator(config) + optimizer = optimizer_creator(model, config) + train_loader, val_loader = data_creator(config) + scheduler = scheduler_creator(optimizer, config) + loss = nn.MSELoss() + + self.model, self.optimizer, self.criterion, self.scheduler = \ + self.register( + models=model, optimizers=optimizer, + criterion=loss, schedulers=scheduler) + self.register_data( + train_loader=train_loader, validation_loader=val_loader) + if scheduler_freq is None: with pytest.raises(ValueError): trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - scheduler_creator=scheduler_creator, + config={"custom_func": train_epoch}, + training_operator_cls=TestTrainingOperator, scheduler_step_freq=scheduler_freq) else: trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), config={"custom_func": train_epoch}, 
- training_operator_cls=_TestingOperator, - scheduler_creator=scheduler_creator, + training_operator_cls=TestTrainingOperator, scheduler_step_freq=scheduler_freq) for i in range(3): @@ -310,11 +284,7 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811 def test_profiling(ray_start_2_cpus): # noqa: F811 - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss()) + trainer = TorchTrainer(training_operator_cls=Operator) stats = trainer.train(profile=True) assert "profile" in stats @@ -334,11 +304,14 @@ def test_dataset(ray_start_4_cpus): model_creator = mlp_identity.model_creator optimizer_creator = mlp_identity.optimizer_creator dataset_creator = mlp_identity.dataset_creator - trainer = TorchTrainer( + + DatasetOperator = TrainingOperator.from_creators( model_creator=model_creator, - data_creator=None, optimizer_creator=optimizer_creator, - loss_creator=torch.nn.MSELoss, + loss_creator=nn.MSELoss) + + trainer = TorchTrainer( + training_operator_cls=DatasetOperator, num_workers=2, ) @@ -366,12 +339,13 @@ def test_split_batch(ray_start_2_cpus): data_size = 600 batch_size = 21 - + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + data_creator, + loss_creator=lambda config: nn.MSELoss()) trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), + training_operator_cls=TestOperator, num_workers=2, config={ BATCH_SIZE: batch_size, @@ -398,11 +372,13 @@ def test_reduce_result(ray_start_2_cpus): data_size = 600 + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + data_creator, + loss_creator=lambda config: nn.MSELoss()) trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda 
config: nn.MSELoss(), + training_operator_cls=TestOperator, num_workers=2, config={"data_size": data_size}) list_stats = trainer.train(reduce_results=False, profile=True) @@ -426,11 +402,10 @@ def test_metrics(ray_start_2_cpus, num_workers): train_scores = [1] + ([0] * num_train_steps) val_scores = [1] + ([0] * num_val_steps) + + TestOperator = get_test_metrics_operator(Operator) trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), + training_operator_cls=TestOperator, num_workers=num_workers, config={ "scores": train_scores, @@ -439,8 +414,7 @@ def test_metrics(ray_start_2_cpus, num_workers): "batch_size": batch_size, "data_size": data_size, "val_size": val_size - }, - training_operator_cls=_TestMetricsOperator) + }) stats = trainer.train(num_steps=num_train_steps) # Test that we output mean and last of custom metrics in an epoch @@ -475,11 +449,9 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): train_scores = [np.nan] + ([0] * num_train_steps) val_scores = [np.nan] + ([0] * num_val_steps) + TestOperator = get_test_metrics_operator(Operator) trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), + training_operator_cls=TestOperator, num_workers=num_workers, config={ "scores": train_scores, @@ -488,8 +460,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): "batch_size": batch_size, "data_size": data_size, "val_size": val_size - }, - training_operator_cls=_TestMetricsOperator) + }) stats = trainer.train(num_steps=num_train_steps) assert "score" in stats @@ -506,19 +477,20 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 from torch.optim.lr_scheduler import ReduceLROnPlateau - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - 
optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + data_creator, scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer), - scheduler_step_freq="manual", - training_operator_cls=_TestingOperator) + loss_creator=lambda config: nn.MSELoss()) + TestOperator = get_test_operator(TestOperator) + trainer = TorchTrainer( + scheduler_step_freq="manual", training_operator_cls=TestOperator) trainer.update_scheduler(0.5) trainer.update_scheduler(0.5) assert all( trainer.apply_all_operators( - lambda op: op.schedulers[0].last_epoch == 2)) + lambda op: op._schedulers[0].last_epoch == 2)) trainer.shutdown() @@ -526,10 +498,7 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811 TorchTrainable = TorchTrainer.as_trainable( **{ - "model_creator": model_creator, - "data_creator": data_creator, - "optimizer_creator": optimizer_creator, - "loss_creator": lambda config: nn.MSELoss(), + "training_operator_cls": Operator, "num_workers": num_workers, "use_gpu": False, "backend": "gloo", @@ -560,11 +529,7 @@ def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811 def test_save_and_restore(ray_start_2_cpus, num_workers, tmp_path): # noqa: F811 trainer1 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - num_workers=num_workers) + training_operator_cls=Operator, num_workers=num_workers) trainer1.train() checkpoint_path = os.path.join(tmp_path, "checkpoint") trainer1.save(checkpoint_path) @@ -574,11 +539,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers, trainer1.shutdown() trainer2 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - 
num_workers=num_workers) + training_operator_cls=Operator, num_workers=num_workers) trainer2.load(checkpoint_path) model2 = trainer2.get_model() @@ -597,12 +558,7 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811 if not dist.is_available(): return trainer1 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - wrap_ddp=False, - num_workers=2) + training_operator_cls=Operator, wrap_ddp=False, num_workers=2) trainer1.train() checkpoint_path = os.path.join(tmp_path, "checkpoint") trainer1.save(checkpoint_path) @@ -613,12 +569,7 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811 trainer1.shutdown() trainer2 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - wrap_ddp=False, - num_workers=2) + training_operator_cls=Operator, wrap_ddp=False, num_workers=2) trainer2.load(checkpoint_path) model2 = trainer2.get_model() @@ -672,12 +623,14 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 step_with_fail = gen_step_with_fail(3) + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) with patch.object(TorchTrainer, "_train_epoch", step_with_fail): trainer1 = TorchTrainer( - model_creator=model_creator, - data_creator=single_loader, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), + training_operator_cls=TestOperator, config={"batch_size": 100000}, num_workers=2) @@ -697,13 +650,15 @@ def test_resize(ray_start_2_cpus): # noqa: F811 step_with_fail = gen_step_with_fail(1) + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) with patch.object(TorchTrainer, "_train_epoch", step_with_fail): trainer1 = TorchTrainer( - 
model_creator=model_creator, - data_creator=single_loader, - optimizer_creator=optimizer_creator, + training_operator_cls=TestOperator, config={"batch_size": 100000}, - loss_creator=lambda config: nn.MSELoss(), num_workers=2) @ray.remote @@ -728,13 +683,16 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 step_with_fail = gen_step_with_fail(2) + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) + with patch.object(TorchTrainer, "_train_epoch", step_with_fail): trainer1 = TorchTrainer( - model_creator=model_creator, - data_creator=single_loader, - optimizer_creator=optimizer_creator, + training_operator_cls=TestOperator, config={"batch_size": 100000}, - loss_creator=lambda config: nn.MSELoss(), num_workers=2) # MAX RETRIES SHOULD BE ON BY DEFAULT @@ -778,12 +736,13 @@ def test_multi_input_model(ray_start_2_cpus): ) return train_loader, None - trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=lambda config: nn.MSELoss(), - num_workers=1) + Operator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + data_creator, + loss_creator=lambda config: nn.MSELoss()) + + trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1) metrics = trainer.train(num_steps=1) assert metrics[BATCH_COUNT] == 1 @@ -794,4 +753,5 @@ def test_multi_input_model(ray_start_2_cpus): if __name__ == "__main__": import pytest import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/sgd/tests/test_torch_runner.py b/python/ray/util/sgd/tests/test_torch_runner.py index ca8292373..ef1bc0ca2 100644 --- a/python/ray/util/sgd/tests/test_torch_runner.py +++ b/python/ray/util/sgd/tests/test_torch_runner.py @@ -50,19 +50,22 @@ def create_dataloaders(config): class TestTorchRunner(unittest.TestCase): + def setUp(self): + self.Operator = 
TrainingOperator.from_creators( + model_creator, + optimizer_creator, + create_dataloaders, + loss_creator=loss_creator) + def testValidate(self): - class MockOperator(TrainingOperator): + class MockOperator(self.Operator): def setup(self, config): + super(MockOperator, self).setup(config) self.train_epoch = MagicMock(returns=dict(mean_accuracy=10)) self.validate = MagicMock(returns=dict(mean_accuracy=10)) - runner = TorchRunner( - model_creator, - create_dataloaders, - optimizer_creator, - loss_creator, - training_operator_cls=MockOperator) - runner.setup() + runner = TorchRunner(training_operator_cls=MockOperator) + runner.setup_operator() runner.train_epoch() runner.train_epoch() result = runner.train_epoch() @@ -72,21 +75,17 @@ class TestTorchRunner(unittest.TestCase): self.assertEqual(result["epoch"], 3) def testtrain_epoch(self): - class MockOperator(TrainingOperator): + class MockOperator(self.Operator): def setup(self, config): + super(MockOperator, self).setup(config) self.count = 0 def train_epoch(self, *args, **kwargs): self.count += 1 return {"count": self.count} - runner = TorchRunner( - model_creator, - create_dataloaders, - optimizer_creator, - loss_creator, - training_operator_cls=MockOperator) - runner.setup() + runner = TorchRunner(training_operator_cls=MockOperator) + runner.setup_operator() runner.train_epoch(num_steps=1) runner.train_epoch(num_steps=1) result = runner.train_epoch() @@ -95,11 +94,6 @@ class TestTorchRunner(unittest.TestCase): self.assertEqual(result["epoch"], 3) def testGivens(self): - class MockOperator(TrainingOperator): - def setup(self, config): - self.train_epoch = MagicMock(returns=dict(mean_accuracy=10)) - self.validate = MagicMock(returns=dict(mean_accuracy=10)) - def three_model_creator(config): return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1) @@ -109,20 +103,27 @@ class TestTorchRunner(unittest.TestCase): ] return opts[0], opts[1], opts[2] - runner = TorchRunner( - three_model_creator, - single_loader, - 
three_optimizer_creator, - loss_creator, - training_operator_cls=MockOperator) - runner.setup() + class MockOperator(TrainingOperator): + def setup(self, config): + models = three_model_creator(config) + optimizers = three_optimizer_creator(models, config) + loader = single_loader(config) + loss = loss_creator(config) + self.models, self.optimizers, self.criterion = \ + self.register(models=models, optimizers=optimizers, + criterion=loss) + self.register_data(train_loader=loader, validation_loader=None) + self.train_epoch = MagicMock(returns=dict(mean_accuracy=10)) + self.validate = MagicMock(returns=dict(mean_accuracy=10)) + + runner = TorchRunner(training_operator_cls=MockOperator) + runner.setup_operator() self.assertEqual(len(runner.given_models), 3) self.assertEqual(len(runner.given_optimizers), 3) - runner2 = TorchRunner(model_creator, single_loader, optimizer_creator, - loss_creator) - runner2.setup() + runner2 = TorchRunner(training_operator_cls=self.Operator) + runner2.setup_operator() self.assertNotEqual(runner2.given_models, runner2.models) self.assertNotEqual(runner2.given_optimizers, runner2.optimizers) @@ -132,49 +133,42 @@ class TestTorchRunner(unittest.TestCase): return (LinearDataset(2, 5), LinearDataset(2, 5, size=400), LinearDataset(2, 5, size=400)) - runner = TorchRunner(model_creator, three_data_loader, - optimizer_creator, loss_creator) - with self.assertRaises(ValueError): - runner.setup() + ThreeOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + three_data_loader, + loss_creator=loss_creator) - runner2 = TorchRunner(model_creator, three_data_loader, - optimizer_creator, loss_creator) + runner = TorchRunner(training_operator_cls=ThreeOperator) with self.assertRaises(ValueError): - runner2.setup() + runner.setup_operator() + + runner2 = TorchRunner(training_operator_cls=ThreeOperator) + with self.assertRaises(ValueError): + runner2.setup_operator() def testSingleLoader(self): - runner = 
TorchRunner(model_creator, single_loader, optimizer_creator, - loss_creator) - runner.setup() + SingleOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=loss_creator) + runner = TorchRunner(training_operator_cls=SingleOperator) + runner.setup_operator() runner.train_epoch() with self.assertRaises(ValueError): runner.validate() def testNativeLoss(self): - runner = TorchRunner( + NativeOperator = TrainingOperator.from_creators( model_creator, - single_loader, optimizer_creator, + single_loader, loss_creator=nn.MSELoss) - runner.setup() + runner = TorchRunner(training_operator_cls=NativeOperator) + runner.setup_operator() runner.train_epoch() - def testMultiModel(self): - def multi_model_creator(config): - return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1) - - def multi_optimizer_creator(models, config): - opts = [ - torch.optim.SGD(model.parameters(), lr=0.1) for model in models - ] - return opts[0], opts[1], opts[2] - - runner = TorchRunner(multi_model_creator, single_loader, - multi_optimizer_creator, loss_creator) - - with self.assertRaises(ValueError): - runner.setup() - class TestLocalDistributedRunner(unittest.TestCase): def setUp(self): diff --git a/python/ray/util/sgd/torch/__init__.py b/python/ray/util/sgd/torch/__init__.py index 45e0ff00a..608e67e97 100644 --- a/python/ray/util/sgd/torch/__init__.py +++ b/python/ray/util/sgd/torch/__init__.py @@ -4,6 +4,7 @@ logger = logging.getLogger(__name__) TorchTrainer = None TrainingOperator = None BaseTorchTrainable = None +CreatorOperator = None try: import torch # noqa: F401 @@ -11,9 +12,13 @@ try: from ray.util.sgd.torch.torch_trainer import (TorchTrainer, BaseTorchTrainable) - from ray.util.sgd.torch.training_operator import TrainingOperator + from ray.util.sgd.torch.training_operator import (TrainingOperator, + CreatorOperator) - __all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"] + __all__ = [ + "TorchTrainer", 
"BaseTorchTrainable", "TrainingOperator", + "CreatorOperator" + ] except ImportError as e: logger.warning(e) logger.warning("PyTorch not found. TorchTrainer will not be available") diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index 3ab4bd841..375116427 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -4,7 +4,6 @@ import os import torch import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, IterableDataset from torch.utils.data.distributed import DistributedSampler from ray.util.sgd.torch.utils import setup_process_group @@ -43,9 +42,6 @@ class DistributedTorchRunner(TorchRunner): self.add_dist_sampler = add_dist_sampler self.world_rank = None - def setup(self): - raise RuntimeError("Need to call setup commands separately.") - def setup_process_group(self, url, world_rank, world_size, timeout): """Connects the distributed PyTorch backend. @@ -60,7 +56,7 @@ class DistributedTorchRunner(TorchRunner): setup_process_group( url, world_rank, world_size, timeout, backend=self.backend) - def setup_ddp_and_operator(self): + def setup_operator(self): """Runs distributed coordination components. 
This helps avoid timeouts due to creator functions (perhaps @@ -70,29 +66,18 @@ class DistributedTorchRunner(TorchRunner): if self.use_gpu and torch.cuda.is_available(): device_ids = self.get_device_ids() - # Wrap dataloaders - self._wrap_dataloaders() - - training_models = self.models - if self.wrap_ddp: - # This needs to happen after apex - training_models = [ - DistributedDataParallel(model, device_ids=device_ids) - for model in self.models - ] self.training_operator = self.training_operator_cls( self.config, - models=training_models, - optimizers=self.optimizers, - criterion=self.criterion, - train_loader=self.train_loader, - validation_loader=self.validation_loader, world_rank=self.world_rank, - schedulers=self.schedulers, device_ids=device_ids, use_gpu=self.use_gpu, use_fp16=self.use_fp16, - use_tqdm=self.use_tqdm) + use_tqdm=self.use_tqdm, + apex_args=self.apex_args, + wrap_ddp=self.wrap_ddp, + wrap_distributed_sampler=True, + add_dist_sampler=self.add_dist_sampler, + scheduler_step_freq=self.scheduler_step_freq) def get_device_ids(self): """Needed for SyncBatchNorm, which needs 1 GPU per process.""" diff --git a/python/ray/util/sgd/torch/examples/benchmarks/benchmark.py b/python/ray/util/sgd/torch/examples/benchmarks/benchmark.py index 34c3fb87d..2cae22f7d 100644 --- a/python/ray/util/sgd/torch/examples/benchmarks/benchmark.py +++ b/python/ray/util/sgd/torch/examples/benchmarks/benchmark.py @@ -72,6 +72,17 @@ def init_hook(): class Training(TrainingOperator): def setup(self, config): + model = getattr(models, config.get("model"))() + optimizer = optim.SGD( + model.parameters(), lr=0.01 * config["lr_scaler"]) + train_data = LinearDataset(4, + 2) # Have to use dummy data for training. 
+ + self.model, self.optimizer = self.register( + models=model, + optimizers=optimizer, + ) + self.register_data(train_loader=train_data, validation_loader=None) data = torch.randn(args.batch_size, 3, 224, 224) target = torch.LongTensor(args.batch_size).random_() % 1000 if args.cuda: @@ -107,14 +118,12 @@ if __name__ == "__main__": print("Number of %ss: %d" % (device, num_workers)) trainer = TorchTrainer( - model_creator=lambda cfg: getattr(models, args.model)(), - optimizer_creator=lambda model, cfg: optim.SGD( - model.parameters(), lr=0.01 * cfg.get("lr_scaler")), - data_creator=lambda cfg: LinearDataset(4, 2), # Mock dataset. - initialization_hook=init_hook, - config=dict( - lr_scaler=num_workers), training_operator_cls=Training, + initialization_hook=init_hook, + config={ + "lr_scaler": num_workers, + "model": args.model + }, num_workers=num_workers, use_gpu=args.cuda, use_fp16=args.fp16, diff --git a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py index 94d286d45..f51549f23 100644 --- a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py @@ -2,6 +2,8 @@ import os import torch import torch.nn as nn import argparse + +from filelock import FileLock from torch.utils.data import DataLoader, Subset from torchvision.datasets import CIFAR10 import torchvision.transforms as transforms @@ -9,9 +11,9 @@ import torchvision.transforms as transforms from tqdm import trange import ray -from ray.util.sgd.torch import TorchTrainer +from ray.util.sgd.torch import TorchTrainer, TrainingOperator from ray.util.sgd.torch.resnet import ResNet18 -from ray.util.sgd.utils import BATCH_SIZE +from ray.util.sgd.utils import BATCH_SIZE, override def initialization_hook(): @@ -24,47 +26,66 @@ def initialization_hook(): # os.environ["NCCL_DEBUG"] = "INFO" -def cifar_creator(config): - transform_train = transforms.Compose([ - 
transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) # meanstd transformation +class CifarTrainingOperator(TrainingOperator): + @override(TrainingOperator) + def setup(self, config): + # Create model. + model = ResNet18(config) - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10( - root="~/data", train=True, download=True, transform=transform_train) - validation_dataset = CIFAR10( - root="~/data", train=False, download=False, transform=transform_test) + # Create optimizer. + optimizer = torch.optim.SGD( + model.parameters(), + lr=config.get("lr", 0.1), + momentum=config.get("momentum", 0.9)) - if config["test_mode"]: - train_dataset = Subset(train_dataset, list(range(64))) - validation_dataset = Subset(validation_dataset, list(range(64))) + # Load in training and validation data. 
+ transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) # meanstd transformation - train_loader = DataLoader( - train_dataset, batch_size=config[BATCH_SIZE], num_workers=2) - validation_loader = DataLoader( - validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2) - return train_loader, validation_loader + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) + with FileLock(".ray.lock"): + train_dataset = CIFAR10( + root="~/data", + train=True, + download=True, + transform=transform_train) + validation_dataset = CIFAR10( + root="~/data", + train=False, + download=False, + transform=transform_test) + if config["test_mode"]: + train_dataset = Subset(train_dataset, list(range(64))) + validation_dataset = Subset(validation_dataset, list(range(64))) -def optimizer_creator(model, config): - """Returns optimizer""" - return torch.optim.SGD( - model.parameters(), - lr=config.get("lr", 0.1), - momentum=config.get("momentum", 0.9)) + train_loader = DataLoader( + train_dataset, batch_size=config[BATCH_SIZE], num_workers=2) + validation_loader = DataLoader( + validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2) + # Create scheduler. + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=[150, 250, 350], gamma=0.1) -def scheduler_creator(optimizer, config): - return torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[150, 250, 350], gamma=0.1) + # Create loss. + criterion = nn.CrossEntropyLoss() + + # Register all components. 
+ self.model, self.optimizer, self.criterion, self.scheduler = \ + self.register(models=model, optimizers=optimizer, + criterion=criterion, schedulers=scheduler) + self.register_data( + train_loader=train_loader, validation_loader=validation_loader) if __name__ == "__main__": @@ -105,11 +126,7 @@ if __name__ == "__main__": ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True) trainer1 = TorchTrainer( - model_creator=ResNet18, - data_creator=cifar_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.CrossEntropyLoss, - scheduler_creator=scheduler_creator, + training_operator_cls=CifarTrainingOperator, initialization_hook=initialization_hook, num_workers=args.num_workers, config={ diff --git a/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py b/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py index f2b6a75c3..40833517c 100644 --- a/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py +++ b/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py @@ -3,6 +3,8 @@ import os import torch import torch.nn as nn import argparse + +from filelock import FileLock from ray import tune from ray.tune.schedulers import PopulationBasedTraining from torch.utils.data import DataLoader, Subset @@ -11,9 +13,9 @@ import torchvision.transforms as transforms import ray from ray.tune import CLIReporter -from ray.util.sgd.torch import TorchTrainer +from ray.util.sgd.torch import TorchTrainer, TrainingOperator from ray.util.sgd.torch.resnet import ResNet18 -from ray.util.sgd.utils import BATCH_SIZE +from ray.util.sgd.utils import BATCH_SIZE, override def initialization_hook(): @@ -26,42 +28,62 @@ def initialization_hook(): # os.environ["NCCL_DEBUG"] = "INFO" -def cifar_creator(config): - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) # meanstd transformation +class 
CifarTrainingOperator(TrainingOperator): + @override(TrainingOperator) + def setup(self, config): + # Create model. + model = ResNet18(config) - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10( - root="~/data", train=True, download=True, transform=transform_train) - validation_dataset = CIFAR10( - root="~/data", train=False, download=False, transform=transform_test) + # Create optimizer. + optimizer = torch.optim.SGD( + model.parameters(), + lr=config.get("lr", 0.1), + momentum=config.get("momentum", 0.9)) - if config.get("test_mode"): - train_dataset = Subset(train_dataset, list(range(64))) - validation_dataset = Subset(validation_dataset, list(range(64))) + # Load in training and validation data. + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) # meanstd transformation - train_loader = DataLoader( - train_dataset, batch_size=config[BATCH_SIZE], num_workers=2) - validation_loader = DataLoader( - validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2) - return train_loader, validation_loader + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), + ]) + with FileLock(".ray.lock"): + train_dataset = CIFAR10( + root="~/data", + train=True, + download=True, + transform=transform_train) + validation_dataset = CIFAR10( + root="~/data", + train=False, + download=False, + transform=transform_test) -def optimizer_creator(model, config): - """Returns optimizer""" - return torch.optim.SGD( - model.parameters(), - lr=config.get("lr", 0.1), - momentum=config.get("momentum", 0.9)) + if config.get("test_mode"): + train_dataset = Subset(train_dataset, list(range(64))) + validation_dataset 
= Subset(validation_dataset, list(range(64))) + + train_loader = DataLoader( + train_dataset, batch_size=config[BATCH_SIZE], num_workers=2) + validation_loader = DataLoader( + validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2) + + # Create loss. + criterion = nn.CrossEntropyLoss() + + self.model, self.optimizer, self.criterion = \ + self.register(models=model, optimizers=optimizer, + criterion=criterion,) + self.register_data( + train_loader=train_loader, validation_loader=validation_loader) if __name__ == "__main__": @@ -101,10 +123,7 @@ if __name__ == "__main__": ray.init(address=args.address, log_to_driver=True) TorchTrainable = TorchTrainer.as_trainable( - model_creator=ResNet18, - data_creator=cifar_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.CrossEntropyLoss, + training_operator_cls=CifarTrainingOperator, initialization_hook=initialization_hook, num_workers=args.num_workers, config={ diff --git a/python/ray/util/sgd/torch/examples/dcgan.py b/python/ray/util/sgd/torch/examples/dcgan.py index 4669d266b..90f59458f 100644 --- a/python/ray/util/sgd/torch/examples/dcgan.py +++ b/python/ray/util/sgd/torch/examples/dcgan.py @@ -9,6 +9,7 @@ import torch.utils.data import torchvision.datasets as datasets import torchvision.transforms as transforms import numpy as np +from filelock import FileLock from tqdm import trange @@ -24,22 +25,6 @@ from ray.util.sgd.torch import TrainingOperator MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt") -def data_creator(config): - dataset = datasets.MNIST( - root="~/mnist/", - download=True, - transform=transforms.Compose([ - transforms.Resize(32), - transforms.ToTensor(), - transforms.Normalize((0.5, ), (0.5, )), - ])) - if config.get("test_mode"): - dataset = torch.utils.data.Subset(dataset, list(range(64))) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=config.get("batch_size", 32)) - return dataloader - - class Generator(nn.Module): def __init__(self, 
latent_vector_size, features=32, num_channels=1): super(Generator, self).__init__() @@ -101,35 +86,57 @@ class LeNet(nn.Module): return F.log_softmax(x, dim=1) -def model_creator(config): - def weights_init(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) - - discriminator = Discriminator() - discriminator.apply(weights_init) - - generator = Generator( - latent_vector_size=config.get("latent_vector_size", 100)) - generator.apply(weights_init) - return discriminator, generator - - -def optimizer_creator(models, config): - net_d, net_g = models - discriminator_opt = optim.Adam( - net_d.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999)) - generator_opt = optim.Adam( - net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999)) - return discriminator_opt, generator_opt +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) class GANOperator(TrainingOperator): def setup(self, config): + discriminator = Discriminator() + discriminator.apply(weights_init) + + generator = Generator( + latent_vector_size=config.get("latent_vector_size", 100)) + generator.apply(weights_init) + models = (discriminator, generator) + + discriminator_opt = optim.Adam( + discriminator.parameters(), + lr=config.get("lr", 0.01), + betas=(0.5, 0.999)) + generator_opt = optim.Adam( + generator.parameters(), + lr=config.get("lr", 0.01), + betas=(0.5, 0.999)) + optimizers = (discriminator_opt, generator_opt) + + with FileLock(".ray.lock"): + dataset = datasets.MNIST( + root="~/mnist/", + download=True, + transform=transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + 
transforms.Normalize((0.5, ), (0.5, )), + ])) + if config.get("test_mode"): + dataset = torch.utils.data.Subset(dataset, list(range(64))) + train_dataloader = torch.utils.data.DataLoader( + dataset, batch_size=config.get("batch_size", 32)) + + self.models, self.optimizers, self.criterion = self.register( + models=models, optimizers=optimizers, criterion=nn.BCELoss()) + self.register_data( + train_loader=train_dataloader, validation_loader=None) + + self.model = self.models[0] + self.optimizer = self.optimizers[0] + self.classifier = LeNet() self.classifier.load_state_dict( torch.load(config["classification_model_path"])) @@ -232,10 +239,6 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False): "classification_model_path": MODEL_PATH } trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.BCELoss, training_operator_cls=GANOperator, num_workers=num_workers, config=config, diff --git a/python/ray/util/sgd/torch/examples/image_models/train.py b/python/ray/util/sgd/torch/examples/image_models/train.py index b8fbd218b..e9fe05176 100644 --- a/python/ray/util/sgd/torch/examples/image_models/train.py +++ b/python/ray/util/sgd/torch/examples/image_models/train.py @@ -9,6 +9,7 @@ from os.path import join +from ray.util.sgd.torch import TrainingOperator from tqdm import trange import torch.nn as nn @@ -130,11 +131,13 @@ def main(): ray.init(address=args.ray_address) - trainer = TorchTrainer( + CustomTrainingOperator = TrainingOperator.from_creators( model_creator=model_creator, - data_creator=data_creator, optimizer_creator=optimizer_creator, - loss_creator=loss_creator, + data_creator=data_creator, + loss_creator=loss_creator) + trainer = TorchTrainer( + training_operator_cls=CustomTrainingOperator, use_tqdm=True, use_fp16=args.amp, apex_args={"opt_level": "O1"}, diff --git a/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py 
b/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py index 7d0f819e1..b9f1bb721 100644 --- a/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py +++ b/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py @@ -7,6 +7,47 @@ in the documentation. """ # yapf: disable +# __torch_operator_start__ +from ray.util.sgd.torch import TrainingOperator + +class MyTrainingOperator(TrainingOperator): + def setup(self, config): + # Setup all components needed for training here. This could include + # data, models, optimizers, loss & schedulers. + + # Setup data loaders. + train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2, + 5) + train_loader = DataLoader(train_dataset, + batch_size=config["batch_size"]) + val_loader = DataLoader(val_dataset, + batch_size=config["batch_size"]) + + # Setup model. + model = nn.Linear(1, 1) + + # Setup optimizer. + optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4)) + + # Setup loss. + criterion = torch.nn.BCELoss() + + # Setup scheduler. + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9) + + # Register all of these components with Ray SGD. + # This allows Ray SGD to do framework level setup like Cuda, DDP, + # Distributed Sampling, FP16. + # We also assign the return values of self.register to instance + # attributes so we can access it in our custom training/validation + # methods. 
+ self.model, self.optimizer, self.criterion, self.scheduler = \ + self.register(models=model, optimizers=optimizer, + criterion=criterion, + schedulers=scheduler) + self.register_data(train_loader=train_loader, validation_loader=val_loader) +# __torch_operator_end__ + # __torch_model_start__ import torch.nn as nn @@ -103,6 +144,21 @@ def scheduler_creator(optimizer, config): # __torch_scheduler_end__ +# __backwards_compat__start +from ray.util.sgd import TorchTrainer + +MyTrainingOperator = TrainingOperator.from_creators( + model_creator=model_creator, optimizer_creator=optimizer_creator, + loss_creator=loss_creator, scheduler_creator=scheduler_creator, + data_creator=data_creator) + +trainer = TorchTrainer( + training_operator_cls=MyTrainingOperator, + scheduler_step_freq="epoch", # if scheduler_creator is passed in + config={"lr": 0.001, "batch_size": 64}) + +# __backwards_compat_end + # __torch_ray_start__ import ray @@ -114,12 +170,8 @@ ray.init() from ray.util.sgd import TorchTrainer trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, - scheduler_creator=scheduler_creator, - scheduler_step_freq="epoch", # if scheduler_creator is set + training_operator_cls=MyTrainingOperator, + scheduler_step_freq="epoch", # if scheduler is used config={"lr": 0.001, "batch_size": 64}) # __torch_trainer_end__ diff --git a/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py b/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py index 3e3c7c5a6..afefdac30 100644 --- a/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py +++ b/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py @@ -4,6 +4,7 @@ import time import torch import torch.utils.data +from filelock import FileLock from torch import nn import torchvision @@ -50,28 +51,6 @@ def get_dataset(name, return ds, num_classes -def data_creator(config): - # Within a
machine, this code runs synchronously. - dataset, num_classes = get_dataset( - args.dataset, "train", get_transform(train=True)) - config["num_classes"] = num_classes - dataset_test, _ = get_dataset( - args.dataset, "val", get_transform(train=False)) - data_loader = torch.utils.data.DataLoader( - dataset, - batch_size=args.batch_size, - num_workers=args.data_workers, - collate_fn=utils.collate_fn, - drop_last=True) - - data_loader_test = torch.utils.data.DataLoader( - dataset_test, - batch_size=1, - num_workers=args.data_workers, - collate_fn=utils.collate_fn) - return data_loader, data_loader_test - - def get_transform(train): base_size = 520 crop_size = 480 @@ -101,7 +80,75 @@ def criterion(inputs, target): return losses["out"] + 0.5 * losses["aux"] +def get_optimizer(model, aux_loss): + params_to_optimize = [ + { + "params": [ + p for p in model.backbone.parameters() if p.requires_grad + ] + }, + { + "params": [ + p for p in model.classifier.parameters() if p.requires_grad + ] + }, + ] + if aux_loss: + params = [ + p for p in model.aux_classifier.parameters() if p.requires_grad + ] + params_to_optimize.append({"params": params, "lr": args.lr * 10}) + optimizer = torch.optim.SGD( + params_to_optimize, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + return optimizer + + class SegOperator(TrainingOperator): + def setup(self, config): + args = config["args"] + # Create Data Loaders. + with FileLock(".ray.lock"): + # Within a machine, this code runs synchronously. 
+ dataset, num_classes = get_dataset( + args.dataset, "train", get_transform(train=True)) + config["num_classes"] = num_classes + dataset_test, _ = get_dataset( + args.dataset, "val", get_transform(train=False)) + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.data_workers, + collate_fn=utils.collate_fn, + drop_last=True) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, + batch_size=1, + num_workers=args.data_workers, + collate_fn=utils.collate_fn) + + # Create model. + model = torchvision.models.segmentation.__dict__[args.model]( + num_classes=config["num_classes"], + aux_loss=args.aux_loss, + pretrained=args.pretrained) + if config["num_workers"] > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + # Create optimizer. + optimizer = get_optimizer(model, aux_loss=args.aux_loss) + + # Register components. + self.model, self.optimizer = self.register( + models=model, + optimizers=optimizer, + train_loader=data_loader, + validation_loader=data_loader_test) + def train_batch(self, batch, batch_info): image, target = batch image, target = image.to(self.device), target.to(self.device) @@ -132,43 +179,6 @@ class SegOperator(TrainingOperator): return confmat -def optimizer_creator(model, config): - args = config["args"] - params_to_optimize = [ - { - "params": [ - p for p in model.backbone.parameters() if p.requires_grad - ] - }, - { - "params": [ - p for p in model.classifier.parameters() if p.requires_grad - ] - }, - ] - if args.aux_loss: - params = [ - p for p in model.aux_classifier.parameters() if p.requires_grad - ] - params_to_optimize.append({"params": params, "lr": args.lr * 10}) - return torch.optim.SGD( - params_to_optimize, - lr=args.lr, - momentum=args.momentum, - weight_decay=args.weight_decay) - - -def model_creator(config): - args = config["args"] - model = torchvision.models.segmentation.__dict__[args.model]( - num_classes=config["num_classes"], - 
aux_loss=args.aux_loss, - pretrained=args.pretrained) - if config["num_workers"] > 1: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - return model - - def main(args): os.makedirs(args.output_dir, exist_ok=True) @@ -176,9 +186,6 @@ def main(args): start_time = time.time() config = {"args": args, "num_workers": args.num_workers} trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, training_operator_cls=SegOperator, use_tqdm=True, use_fp16=True, diff --git a/python/ray/util/sgd/torch/examples/train_example.py b/python/ray/util/sgd/torch/examples/train_example.py index 2c6bde4da..71e8446ec 100644 --- a/python/ray/util/sgd/torch/examples/train_example.py +++ b/python/ray/util/sgd/torch/examples/train_example.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn from ray.util.sgd import TorchTrainer +from ray.util.sgd.torch import TrainingOperator class LinearDataset(torch.utils.data.Dataset): @@ -67,12 +68,12 @@ def data_creator(config): def train_example(num_workers=1, use_gpu=False): + CustomTrainingOperator = TrainingOperator.from_creators( + model_creator=model_creator, optimizer_creator=optimizer_creator, + data_creator=data_creator, scheduler_creator=scheduler_creator, + loss_creator=nn.MSELoss) trainer1 = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, - scheduler_creator=scheduler_creator, + training_operator_cls=CustomTrainingOperator, num_workers=num_workers, use_gpu=use_gpu, config={ diff --git a/python/ray/util/sgd/torch/examples/transformers/transformers_example.py b/python/ray/util/sgd/torch/examples/transformers/transformers_example.py index dcdab782c..1b3b0826f 100644 --- a/python/ray/util/sgd/torch/examples/transformers/transformers_example.py +++ b/python/ray/util/sgd/torch/examples/transformers/transformers_example.py @@ -92,81 +92,76 @@ def announce_training(args, 
dataset_len, t_total): logger.info(" Total optimization steps = %d", t_total) -def model_creator(config): - with FileLock(os.path.expanduser("~/.download.lock")): - args = config["args"] - processor = processors[args.task_name]() - label_list = processor.get_labels() - num_labels = len(label_list) - config = AutoConfig.from_pretrained( - args.config_name if args.config_name else args.model_name_or_path, - num_labels=num_labels, - finetuning_task=args.task_name, - cache_dir=args.cache_dir if args.cache_dir else None, - ) - model = AutoModelForSequenceClassification.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - cache_dir=args.cache_dir if args.cache_dir else None, - ) - return model - - -def optimizer_creator(model, cfg): - args = cfg["args"] - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [ - p for n, p in model.named_parameters() - if not any(nd in n for nd in no_decay) - ], - "weight_decay": args.weight_decay, - }, - { - "params": [ - p for n, p in model.named_parameters() - if any(nd in n for nd in no_decay) - ], - "weight_decay": 0.0 - }, - ] - - return AdamW( - optimizer_grouped_parameters, - lr=args.learning_rate, - eps=args.adam_epsilon) - - -def data_creator(config): - args = config["args"] - start = time.time() - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name - if args.tokenizer_name else args.model_name_or_path, - cache_dir=args.cache_dir if args.cache_dir else None, - ) - logger.info(f"tokenizer instantiation time: {time.time() - start}") - - train_dataset = load_and_cache_examples( - args, args.task_name, tokenizer, evaluate=False) - train_sampler = RandomSampler( - train_dataset) if not dist.is_initialized() else None - return DataLoader( - train_dataset, - sampler=train_sampler, - batch_size=args.per_device_train_batch_size) - - class TransformerOperator(TrainingOperator): def setup(self, config): self.args = args = 
config["args"] + start = time.time() self.tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None, ) + logger.info(f"tokenizer instantiation time: {time.time() - start}") + + # Load data. + train_dataset = load_and_cache_examples( + args, args.task_name, self.tokenizer, evaluate=False) + train_sampler = RandomSampler( + train_dataset) if not dist.is_initialized() else None + train_loader = DataLoader( + train_dataset, + sampler=train_sampler, + batch_size=args.per_device_train_batch_size) + + # Create model. + with FileLock(os.path.expanduser("~/.download.lock")): + processor = processors[args.task_name]() + label_list = processor.get_labels() + num_labels = len(label_list) + model_config = AutoConfig.from_pretrained( + args.config_name + if args.config_name else args.model_name_or_path, + num_labels=num_labels, + finetuning_task=args.task_name, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=model_config, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + + # Create optimizer. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + "weight_decay": args.weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0 + }, + ] + + optimizer = AdamW( + optimizer_grouped_parameters, + lr=args.learning_rate, + eps=args.adam_epsilon) + + # Register components. 
+ self.model, self.optimizer = self.register( + models=model, + optimizers=optimizer, + train_loader=train_loader, + validation_loader=None) self.train_data_len = len(self.train_loader) self._warmup_scheduler = get_linear_schedule_with_warmup( @@ -334,9 +329,6 @@ def main(): # Training trainer = TorchTrainer( - model_creator=model_creator, - data_creator=data_creator, - optimizer_creator=optimizer_creator, training_operator_cls=TransformerOperator, use_fp16=args.fp16, apex_args={"opt_level": args.fp16_opt_level}, diff --git a/python/ray/util/sgd/torch/examples/transformers/utils.py b/python/ray/util/sgd/torch/examples/transformers/utils.py index f3efca18d..97ca7b916 100644 --- a/python/ray/util/sgd/torch/examples/transformers/utils.py +++ b/python/ray/util/sgd/torch/examples/transformers/utils.py @@ -36,6 +36,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ), ) + # Use FileLock to prevent parallel writes that may corrupt data. with FileLock("/tmp/load_and_cache_examples.lock"): if os.path.exists(cached_features_file) and not args.overwrite_cache: logger.info("Loading features from cached file %s", diff --git a/python/ray/util/sgd/torch/examples/tune_example.py b/python/ray/util/sgd/torch/examples/tune_example.py index ce875e78d..ef6c32893 100644 --- a/python/ray/util/sgd/torch/examples/tune_example.py +++ b/python/ray/util/sgd/torch/examples/tune_example.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader import ray from ray import tune -from ray.util.sgd.torch import TorchTrainer +from ray.util.sgd.torch import TorchTrainer, TrainingOperator from ray.util.sgd.utils import BATCH_SIZE from ray.util.sgd.torch.examples.train_example import LinearDataset @@ -37,12 +37,9 @@ def data_creator(config): # __torch_tune_example__ -def tune_example(num_workers=1, use_gpu=False): +def tune_example(operator_cls, num_workers=1, use_gpu=False): TorchTrainable = TorchTrainer.as_trainable( - model_creator=model_creator, - data_creator=data_creator, 
- optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, # Note that we specify a Loss class. + training_operator_cls=operator_cls, num_workers=num_workers, use_gpu=use_gpu, config={BATCH_SIZE: 128} @@ -81,4 +78,8 @@ if __name__ == "__main__": args, _ = parser.parse_known_args() ray.init(address=args.address) - tune_example(num_workers=args.num_workers, use_gpu=args.use_gpu) + CustomTrainingOperator = TrainingOperator.from_creators( + model_creator=model_creator, optimizer_creator=optimizer_creator, + data_creator=data_creator, loss_creator=nn.MSELoss) + tune_example(CustomTrainingOperator, num_workers=args.num_workers, + use_gpu=args.use_gpu) diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 5bb2c8e06..82e47ceb9 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -1,26 +1,15 @@ -from filelock import FileLock import logging -import inspect import io import itertools -import os -import tempfile import torch -import torch.nn as nn import ray -from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS -from ray.util.sgd.torch.training_operator import TrainingOperator +from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS from ray.util.sgd import utils logger = logging.getLogger(__name__) amp = None -try: - from collections.abc import Iterable -except ImportError: - from collections import Iterable - try: from apex import amp except ImportError: @@ -32,12 +21,7 @@ class TorchRunner: """Manages a PyTorch model for training.""" def __init__(self, - model_creator, - data_creator, - optimizer_creator, - loss_creator=None, - scheduler_creator=None, - training_operator_cls=None, + training_operator_cls, config=None, use_gpu=False, serialize_data_creation=True, @@ -45,22 +29,11 @@ class TorchRunner: use_tqdm=False, apex_args=None, scheduler_step_freq=None): - self.model_creator = model_creator - self.optimizer_creator = 
optimizer_creator - self.loss_creator = loss_creator - self.data_creator = data_creator - self.scheduler_creator = scheduler_creator - self.training_operator_cls = training_operator_cls or TrainingOperator + self.training_operator_cls = training_operator_cls self.config = {} if config is None else config self.timers = utils.TimerCollection() self.epochs = 0 - self.models = None - self.optimizers = None - self.criterion = None - self.schedulers = None - self.train_loader = None - self.validation_loader = None self.training_operator = None self.serialize_data_creation = serialize_data_creation self.use_gpu = use_gpu @@ -73,107 +46,16 @@ class TorchRunner: "https://www.github.com/nvidia/apex to use fp16 training.") self.scheduler_step_freq = scheduler_step_freq - def _validate_loaders(self, loaders): - assert loaders, "Loaders need to be returned in data_creator." - if isinstance(loaders, (tuple, list)): - if len(loaders) == 1: - return loaders, None - elif len(loaders) == 2: - return loaders - else: - raise ValueError( - f"Number of loaders must be <= 2. 
Got {loaders}") - # No great way of checking type otherwise - return loaders, None - - def _initialize_dataloaders(self): - logger.debug("Instantiating dataloaders.") - loaders = None - if self.serialize_data_creation: - logger.debug("Serializing the dataloading process.") - with FileLock( - os.path.join(tempfile.gettempdir(), ".raydata.lock")): - loaders = self.data_creator(self.config) - else: - loaders = self.data_creator(self.config) - train_loader, val_loader = self._validate_loaders(loaders) - - self.train_loader, self.validation_loader = train_loader, val_loader - - def _create_loss(self): - if not self.loss_creator: - return - logger.debug("Creating loss.") - if inspect.isclass(self.loss_creator) and issubclass( - self.loss_creator, torch.nn.modules.loss._Loss): - self.criterion = self.loss_creator() - else: - self.criterion = self.loss_creator(self.config) - - if self.use_gpu and torch.cuda.is_available(): - if hasattr(self.criterion, "cuda"): - self.criterion = self.criterion.cuda() - - def _create_schedulers_if_available(self): - # Learning rate schedules are optional. 
- if not self.scheduler_creator: - return - self.schedulers = self.scheduler_creator(self.given_optimizers, - self.config) - - if not isinstance(self.schedulers, Iterable): - self.schedulers = [self.schedulers] - - def _try_setup_apex(self): - """Sets up the model for fp16 training via apex if available.""" - if self.use_fp16 and amp: - self.models, self.optimizers = amp.initialize( - self.models, self.optimizers, **self.apex_args) - - def setup(self): - """Merges setup_components and setup_operator in one call.""" - self.setup_components() - self.setup_operator() - - def setup_components(self): - """Runs the creator functions without any distributed coordination.""" - logger.debug("Loading data.") - if self.data_creator and callable(self.data_creator): - self._initialize_dataloaders() - - logger.debug("Creating model") - self.models = self.model_creator(self.config) - if not isinstance(self.models, Iterable): - self.models = [self.models] - assert all(isinstance(model, nn.Module) for model in self.models), ( - f"All models must be PyTorch models: {self.models}.") - if self.use_gpu and torch.cuda.is_available(): - self.models = [model.cuda() for model in self.models] - - logger.debug("Creating optimizer.") - self.optimizers = self.optimizer_creator(self.given_models, - self.config) - if not isinstance(self.optimizers, Iterable): - self.optimizers = [self.optimizers] - - self._create_schedulers_if_available() - self._try_setup_apex() - self._create_loss() - def setup_operator(self): """Create the training operator.""" self.training_operator = self.training_operator_cls( self.config, - models=self.models, - optimizers=self.optimizers, - criterion=self.criterion, - train_loader=self.train_loader, - validation_loader=self.validation_loader, world_rank=0, - schedulers=self.schedulers, use_gpu=self.use_gpu, use_fp16=self.use_fp16, - use_tqdm=self.use_tqdm) + use_tqdm=self.use_tqdm, + apex_args=self.apex_args, + scheduler_step_freq=self.scheduler_step_freq) def 
get_node_ip(self): """Returns the IP address of the current node.""" @@ -196,7 +78,6 @@ class TorchRunner: info.update({ NUM_STEPS: num_steps, USE_FP16: self.use_fp16, - SCHEDULER_STEP: self.scheduler_step_freq }) with self.timers.record("train_epoch"): if iterator is None: @@ -223,15 +104,17 @@ class TorchRunner: def validate(self, num_steps=None, profile=False, info=None): """Evaluates the model on the validation data set.""" if self.validation_loader is None: - raise ValueError("No validation dataloader provided.") + raise ValueError("No validation dataloader provided. Make sure " + "you pass in a validation_loader to " + "TrainingOperator.register_data.") info = info or {} self._toggle_profiling(profile=profile) + validation_loader = self.validation_loader with self.timers.record("validation"): - iterator = self.validation_loader + iterator = validation_loader if num_steps: - iterator = itertools.islice( - iter(self.validation_loader), num_steps) + iterator = itertools.islice(iterator, num_steps) validation_stats = self.training_operator.validate( iterator, info=info) if profile: @@ -255,32 +138,35 @@ class TorchRunner: "models": [model.state_dict() for model in self.models], "optimizers": [opt.state_dict() for opt in self.optimizers] } - if self.schedulers: + schedulers = self.schedulers + if schedulers: state.update({ "schedulers": [ - scheduler.state_dict() for scheduler in self.schedulers + scheduler.state_dict() for scheduler in schedulers ] }) # Check if fp16 is True and if NVIDIA Apex is imported. 
- if self.use_fp16 and amp: - state.update({"amp": amp.state_dict()}) + if self.use_fp16 and self.training_operator._amp: + state.update({"amp": self.training_operator._amp.state_dict()}) return state def load_state_dict(self, state): """Sets the state of the model.""" - for model, state_dict in zip(self.models, state["models"]): + models = self.models + for model, state_dict in zip(models, state["models"]): model.load_state_dict(state_dict) - for optimizer, state_dict in zip(self.optimizers, state["optimizers"]): + optimizers = self.optimizers + for optimizer, state_dict in zip(optimizers, state["optimizers"]): optimizer.load_state_dict(state_dict) - if self.schedulers: - for scheduler, state_dict in zip(self.schedulers, - state["schedulers"]): + schedulers = self.schedulers + if schedulers: + for scheduler, state_dict in zip(schedulers, state["schedulers"]): scheduler.load_state_dict(state_dict) - if self.use_fp16 and "amp" in state and amp: - amp.load_state_dict(state["amp"]) + if self.use_fp16 and "amp" in state and self.training_operator._amp: + self.training_operator._amp.load_state_dict(state["amp"]) self.epochs = state["epoch"] - self.training_operator.load_state_dict(state_dict) + self.training_operator.load_state_dict(state["operator"]) def state_stream(self): """Returns a bytes object for the state dict.""" @@ -304,14 +190,67 @@ class TorchRunner: def shutdown(self): """Attempts to shut down the worker.""" del self.training_operator - del self.validation_loader - del self.train_loader - del self.criterion - del self.optimizers - del self.models if torch.cuda.is_available(): torch.cuda.empty_cache() + @property + def models(self): + if not hasattr(self.training_operator, "_original_models"): + raise RuntimeError("Training Operator does not have any " + "registered models. Are you calling " + "self.register(...) 
inside the setup method " + "of your Training Operator?") + return self.training_operator._original_models + + @property + def optimizers(self): + if not hasattr(self.training_operator, "_optimizers"): + raise RuntimeError("Training Operator does not have any " + "registered optimizers. Are you calling " + "self.register(...) inside the setup method " + "of your Training Operator?") + return self.training_operator._optimizers + + @property + def schedulers(self): + if not hasattr(self.training_operator, "_schedulers"): + raise RuntimeError("Training Operator does not have any " + "registered schedulers. Are you calling " + "self.register(...) inside the setup method " + "of your Training Operator?") + return self.training_operator._schedulers + + @property + def train_loader(self): + if not hasattr(self.training_operator, "_train_loader"): + logger.warning("Training Operator does not have any " + "registered train loader. If this is " + "unexpected, make sure to call " + "self.register_data(...) inside the setup method " + "of your Training Operator.") + return None + return self.training_operator._train_loader + + @property + def validation_loader(self): + if not hasattr(self.training_operator, "_validation_loader"): + logger.warning("Training Operator does not have any " + "registered validation loader. If this is " + "unexpected, make sure to call " + "self.register_data(...) inside the setup method " + "of your Training Operator.") + return None + return self.training_operator._validation_loader + + @property + def criterion(self): + if not hasattr(self.training_operator, "_criterion"): + raise RuntimeError("Training Operator does not have any " + "registered criterion. Are you calling " + "self.register(...) 
inside the setup method " + "of your Training Operator?") + return self.training_operator._criterion + @property def given_models(self): if len(self.models) > 1: diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 9d58c9848..ed811d782 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -13,6 +13,7 @@ from ray.exceptions import RayActorError from ray.tune import Trainable from ray.tune.resources import Resources from ray.tune.utils.util import merge_dicts +from ray.util import log_once from ray.util.sgd.torch.distributed_torch_runner import ( DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner) from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE @@ -49,80 +50,44 @@ class TorchTrainer: .. code-block:: python - def model_creator(config): - return nn.Linear(1, 1) + class MyTrainingOperator(TrainingOperator): + def setup(self, config): + model = nn.Linear(1, 1) + optimizer = torch.optim.SGD( + model.parameters(), lr=config.get("lr", 1e-4)) + loss = torch.nn.MSELoss() - def optimizer_creator(model, config): - return torch.optim.SGD( - model.parameters(), lr=config.get("lr", 1e-4)) + batch_size = config["batch_size"] + train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5) + train_loader = DataLoader(train_data, batch_size=batch_size) + val_loader = DataLoader(val_data, batch_size=batch_size) + self.model, self.optimizer, self.criterion = self.register( + models=model, + optimizers=optimizer, + criterion=loss) - def data_creator(config): - batch_size = config["batch_size"] - train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5) - train_loader = DataLoader(train_data, batch_size=batch_size) - val_loader = DataLoader(val_data, batch_size=batch_size) - return train_loader, val_loader + self.register_data( + train_loader=train_loader, + validation_loader=val_loader) trainer = TorchTrainer( - model_creator=model_creator, - 
data_creator=data_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.MSELoss, + training_operator_cls=MyTrainingOperator, config={"batch_size": 32}, use_gpu=True ) for i in range(4): trainer.train() - The creator functions will execute before distributed coordination and - training is setup. This is so that creator functions that download - large datasets will not trigger any timeouts. - - The order of operations for creator functions are: - - ``data_creator`` -> ``model_creator`` -> ``optimizer_creator`` -> - ``scheduler_creator`` -> ``loss_creator``. - Args: - model_creator (dict -> Model(s)): Constructor function that takes in - config and returns the model(s) to be optimized. These must be - ``torch.nn.Module`` objects. If multiple models are returned, - a ``training_operator_cls`` must be specified. You do not need to - handle GPU/devices in this function; RaySGD will do that under - the hood. - data_creator (dict -> Iterable(s)): Constructor function - that takes in the passed config and returns one or - two Iterable objects. Note that even though two Iterable objects - can be returned, only one will be used for training, and the - other will be used for validation. If not provided, you must - provide a custom TrainingOperator. - optimizer_creator ((models, dict) -> optimizers): Constructor - function that takes in the return values from - ``model_creator`` and the passed config and returns One or - more Torch optimizer objects. You do not need to handle - GPU/devices in this function; ``RaySGD`` will do that for you. - loss_creator (torch.nn.*Loss class | dict -> loss): A constructor - function for the training loss. This can be either a function that - takes in the provided config for customization or a subclass - of ``torch.nn.modules.loss._Loss``, which is most Pytorch - loss classes. For example, ``loss_creator=torch.nn.BCELoss``. - If not provided, you must provide a custom TrainingOperator. 
- scheduler_creator ((optimizers, dict) -> scheduler): - A constructor function for the torch scheduler. This is - a function that takes in the generated optimizers (from - ``optimizer_creator``) provided config for customization. - Be sure to set ``scheduler_step_freq`` to increment the - scheduler correctly. training_operator_cls (type): Custom training operator class that subclasses the TrainingOperator class. This class will be copied onto all remote workers and used to specify - custom training and validation operations. Defaults to - TrainingOperator. + training components and custom training and validation operations. config (dict): Custom configuration value to be passed to - all creator and operator constructors. + all operator constructors. num_workers (int): the number of workers used in distributed training. If 1, the worker will not be wrapped with DistributedDataParallel. @@ -134,10 +99,6 @@ class TorchTrainer: support "nccl", "gloo", and "auto". If "auto", RaySGD will automatically use "nccl" if `use_gpu` is True, and "gloo" otherwise. - serialize_data_creation (bool): A filelock will be used - to ensure no race conditions in data downloading among - different workers on the same node (using the local file system). - Defaults to True. wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself. timeout_s (float): Seconds before the torch process group @@ -171,12 +132,7 @@ class TorchTrainer: def __init__( self, *, - model_creator, - data_creator, - optimizer_creator, - loss_creator=None, - scheduler_creator=None, - training_operator_cls=None, + training_operator_cls, initialization_hook=None, config=None, num_workers=1, @@ -185,16 +141,33 @@ class TorchTrainer: backend="auto", wrap_ddp=True, timeout_s=NCCL_TIMEOUT_S, - serialize_data_creation=True, use_fp16=False, use_tqdm=False, apex_args=None, add_dist_sampler=True, scheduler_step_freq=None, + # Deprecated Args. 
num_replicas=None, batch_size=None, + model_creator=None, + data_creator=None, + optimizer_creator=None, + scheduler_creator=None, + loss_creator=None, + serialize_data_creation=None, data_loader_args=None, ): + if (model_creator or data_creator or optimizer_creator + or scheduler_creator or loss_creator): + raise DeprecationWarning( + "Creator functions are deprecated. You should create a " + "custom TrainingOperator, override setup, and register all " + "training state there. See TrainingOperator for more info. " + "If you would still like to use creator functions, you can " + "do CustomOperator = TrainingOperator.from_creators(" + "model_creator, ...) and pass in CustomOperator into " + "TorchTrainer.") + if num_workers > 1 and not dist.is_available(): raise ValueError( ("Distributed PyTorch is not supported on macOS. " @@ -202,10 +175,6 @@ class TorchTrainer: "For more information, see " "https://github.com/pytorch/examples/issues/467.")) - if not (callable(model_creator) and callable(optimizer_creator)): - raise ValueError( - "Must provide a callable model_creator and optimizer_creator.") - if num_replicas is not None: raise DeprecationWarning( "num_replicas is deprecated. Use num_workers instead.") @@ -217,24 +186,23 @@ class TorchTrainer: "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a " "batch size to be used across all workers.") + if serialize_data_creation is True: + if log_once("serialize_data_creation"): + logging.warning( + "serialize_data_creation is deprecated and will be " + "ignored. If you require serialized data loading you " + "should implement this in TrainingOperator.setup. " + "You may find FileLock useful here.") + if data_loader_args: - raise ValueError( + raise DeprecationWarning( "data_loader_args is deprecated. You can return a " "torch.utils.data.DataLoader in data_creator. 
Ray will " "automatically set a DistributedSampler if a DataLoader is " "returned and num_workers > 1.") - self.model_creator = model_creator - self.optimizer_creator = optimizer_creator - self.loss_creator = loss_creator - self.data_creator = data_creator - self.scheduler_creator = scheduler_creator self.training_operator_cls = training_operator_cls - if not training_operator_cls and not loss_creator: - raise ValueError("If a loss_creator is not provided, you must " - "provide a custom training operator.") - self.initialization_hook = initialization_hook self.config = {} if config is None else config if use_gpu == "auto": @@ -269,7 +237,7 @@ class TorchTrainer: self.local_worker = DeactivatedRunner() self.remote_workers = [] - if scheduler_creator: + if scheduler_step_freq: _validate_scheduler_step_freq(scheduler_step_freq) self.scheduler_step_freq = scheduler_step_freq @@ -309,11 +277,6 @@ class TorchTrainer: worker_config[BATCH_SIZE] = batch_size_per_worker params = dict( - model_creator=self.model_creator, - data_creator=self.data_creator, - optimizer_creator=self.optimizer_creator, - loss_creator=self.loss_creator, - scheduler_creator=self.scheduler_creator, training_operator_cls=self.training_operator_cls, config=worker_config, serialize_data_creation=self.serialize_data_creation, @@ -328,7 +291,7 @@ class TorchTrainer: self.local_worker = TorchRunner(**params) if self.initialization_hook: self.apply_all_workers(self.initialization_hook) - self.local_worker.setup() + self.local_worker.setup_operator() else: params.update( backend=self.backend, @@ -355,15 +318,6 @@ class TorchTrainer: # Compute URL for initializing distributed PyTorch address = setup_address() - # Runs the creator functions. 
- remote_component_setup = [ - worker.setup_components.remote() - for i, worker in enumerate(self.remote_workers) - ] - self.local_worker.setup_components() - # Get setup tasks in order to throw errors on failure - ray.get(remote_component_setup) - # Setup the process group among all workers. remote_pgroup_setups = [ worker.setup_process_group.remote(address, i + 1, num_workers, @@ -377,10 +331,10 @@ class TorchTrainer: # Runs code that requires all creator functions to have run. remote_operator_setups = [ - worker.setup_ddp_and_operator.remote() + worker.setup_operator.remote() for worker in self.remote_workers ] - self.local_worker.setup_ddp_and_operator() + self.local_worker.setup_operator() # Get setup tasks in order to throw errors on failure ray.get(remote_operator_setups) @@ -421,10 +375,10 @@ class TorchTrainer: Returns: (dict | list) A dictionary of metrics for training. - You can provide custom metrics by passing in a custom - ``training_operator_cls``. If ``reduce_results=False``, - this will return a list of metric dictionaries whose - length will be equal to ``num_workers``. + You can provide custom metrics by implementing a custom + training loop. If ``reduce_results=False``, this will return a + list of metric dictionaries whose length will be equal to + ``num_workers``. """ assert max_retries >= 0, "`max_retries` must be non-negative." assert isinstance(dataset, Dataset) is not None \ @@ -577,12 +531,12 @@ class TorchTrainer: return worker_stats def update_scheduler(self, metric): - """Calls ``scheduler.step(metric)`` on all schedulers. + """Calls ``scheduler.step(metric)`` on all registered schedulers. This is useful for lr_schedulers such as ``ReduceLROnPlateau``. """ self.apply_all_operators( - lambda op: [sched.step(metric) for sched in op.schedulers]) + lambda op: [sched.step(metric) for sched in op._schedulers]) def get_model(self): """Returns the learned model(s).""" @@ -729,10 +683,7 @@ class TorchTrainer: .. 
code-block:: python TorchTrainable = TorchTrainer.as_trainable( - model_creator=ResNet18, - data_creator=cifar_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.CrossEntropyLoss, + training_operator_cls=MyTrainingOperator, num_gpus=2 ) analysis = tune.run( @@ -781,10 +732,7 @@ class BaseTorchTrainable(Trainable): .. code-block:: python TorchTrainable = TorchTrainer.as_trainable( - model_creator=ResNet18, - data_creator=cifar_creator, - optimizer_creator=optimizer_creator, - loss_creator=nn.CrossEntropyLoss, + training_operator_cls=MyTrainingOperator, num_gpus=2 ) # TorchTrainable is subclass of BaseTorchTrainable. diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 068d38895..bf9529f07 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -1,10 +1,23 @@ +import inspect +import logging +import os +import tempfile + import torch +import torch.nn as nn +from filelock import FileLock from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, NUM_SAMPLES) -from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS, - SCHEDULER_STEP_BATCH, SCHEDULER_STEP) +from ray.util.sgd.torch.constants import ( + SCHEDULER_STEP_EPOCH, + NUM_STEPS, + SCHEDULER_STEP_BATCH, +) +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DistributedSampler, DataLoader, IterableDataset +logger = logging.getLogger(__name__) amp = None try: @@ -18,6 +31,7 @@ except ImportError: # Apex library is not installed, so we cannot enable mixed precision. # We don't log here because logging happens in the torch_runner, # where amp is initialized. + logger.debug("apex is not installed.") pass tqdm = None @@ -33,15 +47,58 @@ def _is_multiple(component): class TrainingOperator: - """Abstract class for custom training or validation loops. + """Abstract class to define training and validation state and logic. 
- The scheduler will only be called at a batch or epoch frequency, depending - on the user parameter. Be sure to set ``scheduler_step_freq`` in - ``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler - correctly during training. If using a learning rate scheduler - that depends on validation loss, you can use ``trainer.update_scheduler``. + You must subclass this class and override the ``setup`` method to define + your training components such as the model, optimizer, data, loss, + and scheduler. When you pass this class to ``TorchTrainer``, a copy of + this class will be made on each worker. - For both training and validation, there are two granularities that + .. code-block:: python + + class MyTrainingOperator(TrainingOperator): + + def setup(self, config): + model = nn.Linear(1, 1) + optimizer = torch.optim.SGD( + model.parameters(), lr=config.get("lr", 1e-4)) + loss = torch.nn.MSELoss() + + batch_size = config["batch_size"] + train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5) + train_loader = DataLoader(train_data, batch_size=batch_size) + val_loader = DataLoader(val_data, batch_size=batch_size) + + self.model, self.optimizer = self.register( + models=model, + optimizers=optimizer, + criterion=loss) + + self.register_data( + train_loader=train_loader, + validation_loader=val_loader) + + + trainer = TorchTrainer( + training_operator_cls=MyTrainingOperator, + config={"batch_size": 32}, + use_gpu=True + ) + for i in range(4): + trainer.train() + + This class provides default implementations for training and validation. + Set ``self.model``, ``self.optimizer``, and + ``self.criterion`` to leverage the default training and validation loops. + If ``self.scheduler`` is set, it will only be called at a batch or epoch + frequency, depending on the user parameter. Set + ``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch" + to increment the scheduler correctly during training. 
If using a + learning rate scheduler that depends on validation loss, you can use + ``trainer.update_scheduler``. + + If you want to provide custom training and validation loops, you can do + so using this class as well. There are two granularities that you can provide customization: per epoch or per batch. You do not need to override both. @@ -49,41 +106,30 @@ class TrainingOperator: :scale: 80% :align: center + If you are using multiple models, optimizers, or schedulers, you must + implement custom training and validation. + Raises: - ValueError if multiple models/optimizers/schedulers are provided. - You are expected to subclass this class if you wish - to train over multiple models/optimizers/schedulers. + ValueError + You are expected to either set ``self.model``, + ``self.optimizer``, and ``self.criterion`` instance attributes in + setup or implement custom training & validation. """ def __init__(self, config, - models, - optimizers, - train_loader, - validation_loader, world_rank, - criterion=None, - schedulers=None, device_ids=None, use_gpu=False, use_fp16=False, - use_tqdm=False): + use_tqdm=False, + apex_args=None, + wrap_ddp=False, + wrap_distributed_sampler=False, + add_dist_sampler=False, + scheduler_step_freq=None): # You are not expected to override this method. - self._models = models # List of models - assert isinstance( - models, - Iterable), (f"Components need to be iterable. Got: {type(models)}") - self._optimizers = optimizers # List of optimizers - assert isinstance(optimizers, Iterable), ( - f"Components need to be iterable. Got: {type(optimizers)}") - self._train_loader = train_loader - self._validation_loader = validation_loader self._world_rank = world_rank - self._criterion = criterion - self._schedulers = schedulers - if schedulers: - assert isinstance(schedulers, Iterable), ( - f"Components need to be iterable. 
Got: {type(schedulers)}") self._config = config self._use_fp16 = use_fp16 self._device_ids = device_ids @@ -93,14 +139,12 @@ class TrainingOperator: raise ValueError("tqdm must be installed to use tqdm in training.") self._use_tqdm = use_tqdm self.global_step = 0 + self._apex_args = apex_args if apex_args else {} + self._wrap_ddp = wrap_ddp + self._wrap_distributed_sampler = wrap_distributed_sampler + self._add_dist_sampler = add_dist_sampler + self._scheduler_step_freq = scheduler_step_freq - if type(self) is TrainingOperator: - for component in (models, schedulers, optimizers): - if _is_multiple(component): - raise ValueError( - "Need to provide a custom operator subclassing " - "TrainingOperator if using multi-scheduler, " - "multi-model or multi-optimizer training/validation.") self.timers = TimerCollection() self.setup(config) @@ -109,13 +153,233 @@ class TrainingOperator: self.timers = timers def setup(self, config): - """Override this method to implement custom operator setup. + """Override this method to implement operator setup. + + You should call self.register and self.register_data here to + register training components and data loaders with Ray SGD. Args: config (dict): Custom configuration value to be passed to all creator and operator constructors. Same as ``self.config``. """ - pass + raise NotImplementedError + + def register(self, *, models, optimizers, criterion=None, schedulers=None): + """Registers parameters with Ray SGD and sets up training components. + + By calling this method to register your models, optimizers, + criterion, and schedulers, Ray SGD will automatically handle + necessary setup such as GPU/devices, Distributed Data Parallel, and + Fp16. The registered components are returned and should be set as + instance attributes to access during training/validation. + + If more than one model, optimizer, or scheduler is passed in, + you should implement your own custom training loop. + + .. 
code-block:: python + + class MyTrainingOperator(TrainingOperator): + def setup(self, config): + model = ... + optimizer = ... + train_loader = ... + val_loader = ... + loss = ... + + self.model, self.optimizer, self.criterion = self.register( + models=model, optimizers=optimizer, criterion=loss) + + # At this point DDP, Cuda, and Fp16 + # are set up for all our components. We then use + # self.model, self.optimizer, etc. in our training loop. + + self.register_data(train_loader=train_loader, + validation_loader=val_loader) + + + Args: + models (torch.nn.Module or Iterable[nn.Module]): Pytorch model or + multiple Pytorch models to use for training. If + `use_gpu=True` is passed into ``TorchTrainer``, and Cuda is + available, models will automatically be placed on GPU. + If ``wrap_ddp=True`` is passed into ``TorchTrainer``, + models will be wrapped in DDP. If wrap_ddp is False, + you should handle DDP for your models in setup. + optimizers (torch.optim.Optimizer or Iterable[ + torch.optim.Optimizer]): Pytorch optimizer or multiple Pytorch + optimizers to use for training. + criterion (Callable, optional): Function to return loss + metric given features and target. If not provided, + must implement a custom training loop. + schedulers (torch.optim.lr_scheduler or Iterable[ + torch.optim.lr_scheduler], optional): A learning rate + scheduler or multiple learning rate schedulers. + + Returns: + Tuple of model, optimizer, criterion if not None, and scheduler + if not None. 
+ + """ + return_vals = [] + logger.debug("Registering models.") + self._original_models = models + if not isinstance(self._original_models, Iterable): + self._original_models = [self._original_models] + assert all( + isinstance(model, nn.Module) for model in self._original_models), ( + f"All models must be PyTorch models: {self._original_models}.") + if self.use_gpu and torch.cuda.is_available(): + self._original_models = [ + model.cuda() for model in self._original_models + ] + + logger.debug("Registering optimizers.") + self._optimizers = optimizers + if not isinstance(self._optimizers, Iterable): + self._optimizers = [self._optimizers] + + if schedulers: + logger.debug("Registering scheduler.") + self._schedulers = schedulers + if not isinstance(self._schedulers, Iterable): + self._schedulers = [self._schedulers] + else: + self._schedulers = None + + if criterion: + logger.debug("Registering loss.") + self._criterion = criterion + if self.use_gpu and torch.cuda.is_available(): + if hasattr(self._criterion, "cuda"): + self._criterion = self._criterion.cuda() + else: + self._criterion = None + + logger.debug("Setting up Apex.") + if self.use_fp16 and amp: + self._models, self._optimizers = amp.initialize( + self._models, self._optimizers, **self._apex_args) + self._amp = amp + + if self._wrap_ddp: + logging.debug("Setting up DDP for models.") + self._models = [ + DistributedDataParallel(model, device_ids=self.device_ids) + for model in self._original_models + ] + else: + self._models = self._original_models + + if len(self._models) == 1: + return_vals.append(self._models[0]) + else: + return_vals.append(self._models) + + if len(self._optimizers) == 1: + return_vals.append(self._optimizers[0]) + else: + return_vals.append(self._optimizers) + + if self._criterion is not None: + return_vals.append(self._criterion) + + if self._schedulers is not None: + if self.scheduler_step_freq is None: + raise ValueError("scheduler_step_freq passed into " + "TorchTrainer cannot 
be None if you " + "are registering schedulers. Set this to " + "'manual' if you will be manually stepping " + "the schedulers.") + if len(self._schedulers) == 1: + return_vals.append(self._schedulers[0]) + else: + return_vals.append(self._schedulers) + + return tuple(return_vals) + + def register_data(self, *, train_loader=None, validation_loader=None): + """Registers data loaders with Ray SGD. + + Calling this method will automatically setup Distributed Sampler for + these data loaders if add_dist_sampler=True is passed into the + TorchTrainer. This method does not return the wrapped data loaders. + You should use the iterators passed into train_epoch and validate + instead. + + .. code-block:: python + + class MyTrainingOperator(TrainingOperator): + def setup(self, config): + model = ... + optimizer = ... + train_loader = ... + val_loader = ... + loss = ... + + self.model, self.optimizer, self.criterion = self.register( + models=model, optimizers=optimizer, criterion=loss) + + self.register_data(train_loader=train_loader, + validation_loader=val_loader) + + # At this point the data loaders are registered with + # Ray SGD and are wrapped with Distributed Samplers if + # applicable. + + + def train_epoch(self, iterator, info): + # If providing custom training or validation methods, + # the registered data loaders are passed in through the + # iterator parameter. + ... + + Args: + train_loader (Iterator): An iterator for training + data. If None is explicitly passed in, a Ray SGD Dataset + must be passed in through TorchTrainer.train. Ray SGD will + automatically use a Distributed Sampler if TorchTrainer(..., + add_dist_sampler=True). + validation_loader (Iterator): An iterator for validation + data. Ray SGD will automatically use a Distributed Sampler + if TorchTrainer(..., add_dist_sampler=True). 
+ """ + + logger.debug("Registering data loaders..") + self._train_loader = train_loader + self._validation_loader = validation_loader + + if self._wrap_distributed_sampler: + logging.debug("Wrapping data loaders with DistributedSampler.") + + def with_sampler(loader): + # Automatically set the DistributedSampler + data_loader_args = { + "dataset": loader.dataset, + "batch_size": loader.batch_size, + "shuffle": False, + "num_workers": loader.num_workers, + "collate_fn": loader.collate_fn, + "pin_memory": loader.pin_memory, + "drop_last": loader.drop_last, + "timeout": loader.timeout, + "worker_init_fn": loader.worker_init_fn, + "sampler": DistributedSampler(loader.dataset) + } + return DataLoader(**data_loader_args) + + def should_wrap_dataloader(loader): + return (isinstance(loader, DataLoader) + and not isinstance(loader.dataset, IterableDataset)) + + if should_wrap_dataloader(self._train_loader): + if self._add_dist_sampler: + self._train_loader = with_sampler(self._train_loader) + + if self._validation_loader is not None and should_wrap_dataloader( + self._validation_loader): + if self._add_dist_sampler: + self._validation_loader = with_sampler( + self._validation_loader) def train_epoch(self, iterator, info): """Runs one standard training pass over the training dataloader. @@ -156,6 +420,15 @@ class TrainingOperator: Returns: A dict of metrics from training. """ + if not hasattr(self, "model"): + raise RuntimeError("Either set self.model in setup function or " + "override this method to implement a custom " + "training loop.") + model = self.model + scheduler = None + if hasattr(self, "scheduler"): + scheduler = self.scheduler + if self.use_tqdm and self.world_rank == 0: desc = "" if info is not None and "epoch_idx" in info: @@ -163,15 +436,19 @@ class TrainingOperator: desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e" else: desc = f"{info['epoch_idx'] + 1}e" + + # TODO: Implement len for Dataset? 
+ total = info[NUM_STEPS] + if total is None: + if hasattr(iterator, "__len__"): + total = len(iterator) + _progress_bar = tqdm( - total=info[NUM_STEPS] or len(self.train_loader), - desc=desc, - unit="batch", - leave=False) + total=total, desc=desc, unit="batch", leave=False) metric_meters = AverageMeterCollection() - self.model.train() + model.train() for batch_idx, batch in enumerate(iterator): batch_info = { "batch_idx": batch_idx, @@ -187,15 +464,14 @@ class TrainingOperator: postfix.update(loss=metrics["train_loss"]) _progress_bar.set_postfix(postfix) - if self.scheduler and batch_info.get( - SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: - self.scheduler.step() + if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_BATCH: + scheduler.step() metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 - if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: - self.scheduler.step() + if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH: + scheduler.step() return metric_meters.summary() @@ -211,9 +487,7 @@ class TrainingOperator: automatically. You can provide custom loss metrics and training operations if you - override this method. If overriding this method, you can access model, - optimizer, criterion via ``self.model``, ``self.optimizer``, - and ``self.criterion``. + override this method. You do not need to override this method if you plan to override ``train_epoch``. @@ -232,6 +506,21 @@ class TrainingOperator: calculate averages. 
""" + if not hasattr(self, "model"): + raise RuntimeError("Either set self.model in setup function or " + "override this method to implement a custom " + "training loop.") + if not hasattr(self, "optimizer"): + raise RuntimeError("Either set self.optimizer in setup function " + "or override this method to implement a custom " + "training loop.") + if not hasattr(self, "criterion"): + raise RuntimeError("Either set self.criterion in setup function " + "or override this method to implement a custom " + "training loop.") + model = self.model + optimizer = self.optimizer + criterion = self.criterion # unpack features into list to support multiple inputs model *features, target = batch # Create non_blocking tensors for distributed training @@ -243,21 +532,21 @@ class TrainingOperator: # Compute output. with self.timers.record("fwd"): - output = self.model(*features) - loss = self.criterion(output, target) + output = model(*features) + loss = criterion(output, target) # Compute gradients in a backward pass. with self.timers.record("grad"): - self.optimizer.zero_grad() + optimizer.zero_grad() if self.use_fp16: - with amp.scale_loss(loss, self.optimizer) as scaled_loss: + with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() # Call step of optimizer to update model params. with self.timers.record("apply"): - self.optimizer.step() + optimizer.step() return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)} @@ -267,9 +556,8 @@ class TrainingOperator: This will call ``model.eval()`` and ``torch.no_grad`` when iterating over the validation dataloader. - If overriding this method, you can access model, criterion via - ``self.model`` and ``self.criterion``. You also do not need to call - ``validate_batch`` if overriding this method. + You also do not need to call ``validate_batch`` if overriding this + method. 
Args: val_iterator (iter): Iterable constructed from the @@ -284,10 +572,15 @@ class TrainingOperator: from ``validate_batch`` and dividing it by the sum of ``num_samples`` from all calls to ``self.validate_batch``. """ + if not hasattr(self, "model"): + raise RuntimeError("Either set self.model in setup function or " + "override this method to implement a custom " + "validation loop.") + model = self.model metric_meters = AverageMeterCollection() # switch to evaluate mode - self.model.eval() + model.eval() with torch.no_grad(): for batch_idx, batch in enumerate(val_iterator): batch_info = {"batch_idx": batch_idx} @@ -319,6 +612,16 @@ class TrainingOperator: by default, ``validate`` uses "num_samples" to calculate averages. """ + if not hasattr(self, "model"): + raise RuntimeError("Either set self.model in setup function or " + "override this method to implement a custom " + "training loop.") + if not hasattr(self, "criterion"): + raise RuntimeError("Either set self.criterion in setup function " + "or override this method to implement a custom " + "training loop.") + model = self.model + criterion = self.criterion # unpack features into list to support multiple inputs model *features, target = batch if self.use_gpu: @@ -330,8 +633,8 @@ class TrainingOperator: # compute output with self.timers.record("eval_fwd"): - output = self.model(*features) - loss = self.criterion(output, target) + output = model(*features) + loss = criterion(output, target) _, predicted = torch.max(output.data, 1) num_correct = (predicted == target).sum().item() @@ -344,6 +647,9 @@ class TrainingOperator: def state_dict(self): """Override this to return a representation of the operator state. + Any argument passed into self.register and self.register_data will + automatically be saved. + Use this method to save any additional state. 
Returns: dict: The state dict of the operator.""" @@ -351,11 +657,81 @@ class TrainingOperator: def load_state_dict(self, state_dict): """Override this to load the representation of the operator state. - + Anything passed into self.register and self.register_data will + automatically be loaded. Use this method to load any additional state. Args: state_dict (dict): State dict as returned by the operator. """ pass + @classmethod + def from_creators(cls, + model_creator, + optimizer_creator, + data_creator=None, + loss_creator=None, + scheduler_creator=None, + serialize_data_creation=True): + """A utility method to create a custom TrainingOperator class from + creator functions. This is useful for backwards compatibility with + previous versions of Ray. To provide custom training and validation, + you should subclass the class that is returned by this method instead + of ``TrainingOperator``. + + Args: + model_creator (dict -> Model(s)): Constructor function that takes + in config and returns the model(s) to be optimized. These + must be ``torch.nn.Module`` objects. If multiple models are + returned, a ``training_operator_cls`` must be specified. + You do not need to handle GPU/devices in this function; + RaySGD will do that under the hood. + data_creator (dict -> Iterable(s)): Constructor function + that takes in the passed config and returns one or + two Iterable objects. Note that even though two Iterable + objects can be returned, only one will be used for training, + and the other will be used for validation. If not provided, + you must pass in a Dataset to ``TorchTrainer.train``. + optimizer_creator ((models, dict) -> optimizers): Constructor + function that takes in the return values from + ``model_creator`` and the passed config and returns One or + more Torch optimizer objects. You do not need to handle + GPU/devices in this function; ``RaySGD`` will do that for you. 
+ loss_creator (torch.nn.*Loss class | dict -> loss): A constructor + function for the training loss. This can be either a function + that takes in the provided config for customization or a + subclass of ``torch.nn.modules.loss._Loss``, which is most + Pytorch loss classes. For example, + ``loss_creator=torch.nn.BCELoss``. If not provided, you must + provide a custom TrainingOperator. + scheduler_creator ((optimizers, dict) -> scheduler): + A constructor function for the torch scheduler. This is + a function that takes in the generated optimizers (from + ``optimizer_creator``) provided config for customization. + Be sure to set ``scheduler_step_freq`` to increment the + scheduler correctly. + serialize_data_creation (bool): A filelock will be used + to ensure no race conditions in data downloading among + different workers on the same node (using the local file + system). Defaults to True. + + Returns: + A TrainingOperator class with a ``setup`` method that utilizes + the passed in creator functions. 
+ """ + + if not (callable(model_creator) and callable(optimizer_creator)): + raise ValueError( + "Must provide a callable model_creator and optimizer_creator.") + + class CustomCreatorOperator(CreatorOperator): + _model_creator = model_creator + _optimizer_creator = optimizer_creator + _data_creator = data_creator + _loss_creator = loss_creator + _scheduler_creator = scheduler_creator + _serialize_data_creation = serialize_data_creation + + return CustomCreatorOperator + @property def device(self): """torch.device: The appropriate torch device, at your convenience.""" @@ -366,58 +742,11 @@ class TrainingOperator: """dict: Provided into TorchTrainer.""" return self._config - @property - def model(self): - """First or only model created by the provided ``model_creator``.""" - return self._models[0] - - @property - def models(self): - """List of models created by the provided ``model_creator``.""" - return self._models - - @property - def optimizer(self): - """First or only optimizer(s) created by the ``optimizer_creator``.""" - return self._optimizers[0] - - @property - def optimizers(self): - """List of optimizers created by the ``optimizer_creator``.""" - return self._optimizers - - @property - def train_loader(self): - """Iterable: 1st Dataloader from ``data_creator``. - """ - return self._train_loader - - @property - def validation_loader(self): - """Iterable: 2nd Dataloader from ``data_creator``.""" - return self._validation_loader - @property def world_rank(self): """int: The rank of the parent runner. 
Always 0 if not distributed.""" return self._world_rank - @property - def criterion(self): - """Criterion created by the provided ``loss_creator``.""" - return self._criterion - - @property - def scheduler(self): - """First or only scheduler(s) created by the ``scheduler_creator``.""" - if self._schedulers: - return self._schedulers[0] - - @property - def schedulers(self): - """List of schedulers created by the ``scheduler_creator``.""" - return self._schedulers - @property def use_gpu(self): """Returns True if cuda is available and use_gpu is True.""" @@ -441,31 +770,138 @@ class TrainingOperator: """ return self._device_ids + @property + def scheduler_step_freq(self): + """Optional[str]: The ``scheduler_step_freq`` passed into + ``TorchTrainer`` -class _TestingOperator(TrainingOperator): - def train_epoch(self, iterator, info): - func = self.config.get("custom_func") - if callable(func): - return func(self, iterator, info) - return {"done": 1} + This is useful to determine when to call scheduler.step. + """ + return self._scheduler_step_freq -class _TestMetricsOperator(TrainingOperator): +class CreatorOperator(TrainingOperator): + """A subclass of TrainingOperator specifically for defining training + state using creator functions. + """ + + def _validate_loaders(self, loaders): + assert loaders, "Loaders need to be returned in data_creator." + if isinstance(loaders, (tuple, list)): + if len(loaders) == 1: + return loaders, None + elif len(loaders) == 2: + return loaders + else: + raise ValueError( + f"Number of loaders must be <= 2. 
Got {loaders}") + # No great way of checking type otherwise + return loaders, None + + def _initialize_dataloaders(self, config): + logger.debug("Instantiating dataloaders.") + loaders = None + if self._serialize_data_creation: + logger.debug("Serializing the dataloading process.") + with FileLock( + os.path.join(tempfile.gettempdir(), ".raydata.lock")): + loaders = self.__class__._data_creator(config) + else: + loaders = self.__class__._data_creator(config) + train_loader, val_loader = self._validate_loaders(loaders) + + return train_loader, val_loader + def setup(self, config): - self._train_scores = config["scores"].copy() - self._val_scores = config["val_scores"].copy() - self.key = config["key"] + kwargs = {} + logger.debug("Loading data.") + train_loader = None + validation_loader = None + if self.__class__._data_creator and callable( + self.__class__._data_creator): + train_loader, validation_loader = self._initialize_dataloaders( + config) - def train_batch(self, batch, batch_info=None): - metrics = super(_TestMetricsOperator, self).train_batch( - batch, batch_info) - num_samples = metrics[NUM_SAMPLES] - metrics.update({self.key: self._train_scores.pop(0) / num_samples}) - return metrics + logger.debug("Creating model") + models = self.__class__._model_creator(config) - def validate_batch(self, batch, batch_info=None): - metrics = super(_TestMetricsOperator, self).validate_batch( - batch, batch_info) - num_samples = metrics[NUM_SAMPLES] - metrics.update({self.key: self._val_scores.pop(0) / num_samples}) - return metrics + kwargs["models"] = models + + logger.debug("Creating optimizer.") + optimizers = self.__class__._optimizer_creator(models, config) + + kwargs["optimizers"] = optimizers + + if self.__class__._scheduler_creator: + logger.debug("Creating scheduler.") + schedulers = self.__class__._scheduler_creator(optimizers, config) + kwargs["schedulers"] = schedulers + + if self.__class__._loss_creator: + logger.debug("Creating loss.") + if 
inspect.isclass(self.__class__._loss_creator) and issubclass( + self.__class__._loss_creator, torch.nn.modules.loss._Loss): + criterion = self.__class__._loss_creator() + else: + criterion = self.__class__._loss_creator(config) + kwargs["criterion"] = criterion + + state = self.register(**kwargs) + self.models, self.optimizers = state[:2] + if isinstance(self.models, tuple): + self.model = self.models[0] + else: + self.model = self.models + + if isinstance(self.optimizers, tuple): + self.optimizer = self.optimizers[0] + else: + self.optimizer = self.optimizers + + if len(state) >= 3: + self.criterion = state[2] + if len(state) == 4: + self.schedulers = state[3] + if isinstance(self.schedulers, tuple): + self.scheduler = self.schedulers[0] + else: + self.scheduler = self.schedulers + + self.register_data( + train_loader=train_loader, validation_loader=validation_loader) + + +def get_test_operator(operator_cls): + class _TestingOperator(operator_cls): + def train_epoch(self, iterator, info): + func = self.config.get("custom_func") + if callable(func): + return func(self, iterator, info) + return {"done": 1} + + return _TestingOperator + + +def get_test_metrics_operator(operator_cls): + class _TestMetricsOperator(operator_cls): + def setup(self, config): + super(_TestMetricsOperator, self).setup(config) + self._train_scores = config["scores"].copy() + self._val_scores = config["val_scores"].copy() + self.key = config["key"] + + def train_batch(self, batch, batch_info=None): + metrics = super(_TestMetricsOperator, self).train_batch( + batch, batch_info) + num_samples = metrics[NUM_SAMPLES] + metrics.update({self.key: self._train_scores.pop(0) / num_samples}) + return metrics + + def validate_batch(self, batch, batch_info=None): + metrics = super(_TestMetricsOperator, self).validate_batch( + batch, batch_info) + num_samples = metrics[NUM_SAMPLES] + metrics.update({self.key: self._val_scores.pop(0) / num_samples}) + return metrics + + return _TestMetricsOperator