[SGD] Add PTL Docs (#12440)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-11-28 10:09:38 -08:00
committed by GitHub
parent 60a545ab57
commit 8a406e1f9a
10 changed files with 269 additions and 28 deletions
@@ -1,18 +1,29 @@
import argparse
# __import_begin__
import os
# Pytorch imports
import torch
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
from torch.nn import functional as F
from pytorch_lightning.core.lightning import LightningModule
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
# Ray imports
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
# PTL imports
from pytorch_lightning.core.lightning import LightningModule
# __import_end__
# __ptl_begin__
class LitMNIST(LightningModule):
# We take in an additional config parameter here, but this is not required.
def __init__(self, config):
super().__init__()
@@ -77,6 +88,10 @@ class LitMNIST(LightningModule):
return {"val_loss": loss.item(), "val_acc": num_correct / num_samples}
# __ptl_end__
# __train_begin__
def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
Operator = TrainingOperator.from_ptl(LitMNIST)
trainer = TorchTrainer(
@@ -101,6 +116,8 @@ def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
print("success!")
# __train_end__
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -27,6 +27,24 @@ logger = logging.getLogger(__name__)
class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
TrainerOptimizersMixin):
"""A subclass of TrainingOperator created from a PTL ``LightningModule``.
This class is returned by `TrainingOperator.from_ptl` and its training
state is defined by the Pytorch Lightning ``LightningModule`` that is
passed into `from_ptl`. Training and validation functionality have
already been implemented according to
Pytorch Lightning's Trainer. But if you need to modify training,
you should subclass this class and override the appropriate methods
before passing in the subclass to `TorchTrainer`.
.. code-block:: python
MyLightningOperator = TrainingOperator.from_ptl(
MyLightningModule)
trainer = TorchTrainer(training_operator_cls=MyLightningOperator,
...)
"""
def _configure_amp(self, amp, models, optimizers, apex_args=None):
assert len(models) == 1
model = models[0]
@@ -356,11 +374,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
model.on_after_backward()
with self.timers.record("apply"):
model.optimizer_step(
epoch=epoch_idx,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=0)
optimizer.step()
model.on_before_zero_grad(optimizer)
+36 -8
View File
@@ -777,7 +777,14 @@ class TrainingOperator:
lightning_module_cls,
train_dataloader=None,
val_dataloader=None):
"""Creates a TrainingOperator from a Pytorch Lightning Module.
"""Create a custom TrainingOperator class from a LightningModule.
.. code-block:: python
MyLightningOperator = TrainingOperator.from_ptl(
MyLightningModule)
trainer = TorchTrainer(training_operator_cls=MyLightningOperator,
...)
Args:
lightning_module_cls: Your LightningModule class. An object of
@@ -793,7 +800,7 @@ class TrainingOperator:
A TrainingOperator class properly configured given the
LightningModule.
"""
from ray.util.sgd.torch.ptl_operator import LightningOperator
from ray.util.sgd.torch.lightning_operator import LightningOperator
class CustomLightningOperator(LightningOperator):
_lightning_module_cls = lightning_module_cls
@@ -810,12 +817,20 @@ class TrainingOperator:
loss_creator=None,
scheduler_creator=None,
serialize_data_creation=True):
"""A utility method to create a custom TrainingOperator class from
creator functions. This is useful for backwards compatibility with
"""Create a custom TrainingOperator class from creator functions.
This method is useful for backwards compatibility with
previous versions of Ray. To provide custom training and validation,
you should subclass the class that is returned by this method instead
of ``TrainingOperator``.
.. code-block:: python
MyCreatorOperator = TrainingOperator.from_creators(
model_creator, optimizer_creator)
trainer = TorchTrainer(training_operator_cls=MyCreatorOperator,
...)
Args:
model_creator (dict -> Model(s)): Constructor function that takes
in config and returns the model(s) to be optimized. These
@@ -853,8 +868,8 @@ class TrainingOperator:
system). Defaults to True.
Returns:
A TrainingOperator class with a ``setup`` method that utilizes
the passed in creator functions.
A CreatorOperator class: a subclass of TrainingOperator with a
``setup`` method that utilizes the passed in creator functions.
"""
if not (callable(model_creator) and callable(optimizer_creator)):
@@ -929,8 +944,21 @@ class TrainingOperator:
class CreatorOperator(TrainingOperator):
"""A subclass of TrainingOperator specifically for defining training
state using creator functions.
"""A subclass of TrainingOperator with training defined by creator funcs.
This class allows for backwards compatibility with pre-1.0 versions of Ray.
This class is returned by `TrainingOperator.from_creators(...)`. If you
need to add custom functionality, you should subclass this class,
implement the appropriate methods and pass the subclass into
`TorchTrainer`.
.. code-block:: python
MyCreatorOperator = TrainingOperator.from_creators(
model_creator, optimizer_creator)
trainer = TorchTrainer(training_operator_cls=MyCreatorOperator,
...)
"""
def _validate_loaders(self, loaders):