mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 21:05:07 +08:00
[SGD] Add PTL Docs (#12440)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -1,18 +1,29 @@
|
||||
import argparse
|
||||
|
||||
# __import_begin__
|
||||
import os
|
||||
|
||||
# Pytorch imports
|
||||
import torch
|
||||
from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from torch.nn import functional as F
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision.datasets import MNIST
|
||||
import os
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
# Ray imports
|
||||
from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
|
||||
# PTL imports
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
# __import_end__
|
||||
|
||||
|
||||
# __ptl_begin__
|
||||
class LitMNIST(LightningModule):
|
||||
# We take in an additional config parameter here. But this is not required.
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
@@ -77,6 +88,10 @@ class LitMNIST(LightningModule):
|
||||
return {"val_loss": loss.item(), "val_acc": num_correct / num_samples}
|
||||
|
||||
|
||||
# __ptl_end__
|
||||
|
||||
|
||||
# __train_begin__
|
||||
def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
|
||||
Operator = TrainingOperator.from_ptl(LitMNIST)
|
||||
trainer = TorchTrainer(
|
||||
@@ -101,6 +116,8 @@ def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
|
||||
print("success!")
|
||||
|
||||
|
||||
# __train_end__
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
||||
+19
-5
@@ -27,6 +27,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
TrainerOptimizersMixin):
|
||||
"""A subclass of TrainingOperator created from a PTL ``LightningModule``.
|
||||
|
||||
This class is returned by `TrainingOperator.from_ptl` and its training
|
||||
state is defined by the PyTorch Lightning ``LightningModule`` that is
|
||||
passed into `from_ptl`. Training and validation functionality have
|
||||
already been implemented according to
|
||||
PyTorch Lightning's Trainer. But if you need to modify training,
|
||||
you should subclass this class and override the appropriate methods
|
||||
before passing in the subclass to `TorchTrainer`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
MyLightningOperator = TrainingOperator.from_ptl(
|
||||
MyLightningModule)
|
||||
trainer = TorchTrainer(training_operator_cls=MyLightningOperator,
|
||||
...)
|
||||
"""
|
||||
|
||||
def _configure_amp(self, amp, models, optimizers, apex_args=None):
|
||||
assert len(models) == 1
|
||||
model = models[0]
|
||||
@@ -356,11 +374,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
model.on_after_backward()
|
||||
|
||||
with self.timers.record("apply"):
|
||||
model.optimizer_step(
|
||||
epoch=epoch_idx,
|
||||
batch_idx=batch_idx,
|
||||
optimizer=optimizer,
|
||||
optimizer_idx=0)
|
||||
optimizer.step()
|
||||
|
||||
model.on_before_zero_grad(optimizer)
|
||||
|
||||
@@ -777,7 +777,14 @@ class TrainingOperator:
|
||||
lightning_module_cls,
|
||||
train_dataloader=None,
|
||||
val_dataloader=None):
|
||||
"""Creates a TrainingOperator from a Pytorch Lightning Module.
|
||||
"""Create a custom TrainingOperator class from a LightningModule.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
MyLightningOperator = TrainingOperator.from_ptl(
|
||||
MyLightningModule)
|
||||
trainer = TorchTrainer(training_operator_cls=MyLightningOperator,
|
||||
...)
|
||||
|
||||
Args:
|
||||
lightning_module_cls: Your LightningModule class. An object of
|
||||
@@ -793,7 +800,7 @@ class TrainingOperator:
|
||||
A TrainingOperator class properly configured given the
|
||||
LightningModule.
|
||||
"""
|
||||
from ray.util.sgd.torch.ptl_operator import LightningOperator
|
||||
from ray.util.sgd.torch.lightning_operator import LightningOperator
|
||||
|
||||
class CustomLightningOperator(LightningOperator):
|
||||
_lightning_module_cls = lightning_module_cls
|
||||
@@ -810,12 +817,20 @@ class TrainingOperator:
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
serialize_data_creation=True):
|
||||
"""A utility method to create a custom TrainingOperator class from
|
||||
creator functions. This is useful for backwards compatibility with
|
||||
"""Create a custom TrainingOperator class from creator functions.
|
||||
|
||||
This method is useful for backwards compatibility with
|
||||
previous versions of Ray. To provide custom training and validation,
|
||||
you should subclass the class that is returned by this method instead
|
||||
of ``TrainingOperator``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
MyCreatorOperator = TrainingOperator.from_creators(
|
||||
model_creator, optimizer_creator)
|
||||
trainer = TorchTrainer(training_operator_cls=MyCreatorOperator,
|
||||
...)
|
||||
|
||||
Args:
|
||||
model_creator (dict -> Model(s)): Constructor function that takes
|
||||
in config and returns the model(s) to be optimized. These
|
||||
@@ -853,8 +868,8 @@ class TrainingOperator:
|
||||
system). Defaults to True.
|
||||
|
||||
Returns:
|
||||
A TrainingOperator class with a ``setup`` method that utilizes
|
||||
the passed in creator functions.
|
||||
A CreatorOperator class, a subclass of TrainingOperator with a
|
||||
``setup`` method that utilizes the passed in creator functions.
|
||||
"""
|
||||
|
||||
if not (callable(model_creator) and callable(optimizer_creator)):
|
||||
@@ -929,8 +944,21 @@ class TrainingOperator:
|
||||
|
||||
|
||||
class CreatorOperator(TrainingOperator):
|
||||
"""A subclass of TrainingOperator specifically for defining training
|
||||
state using creator functions.
|
||||
"""A subclass of TrainingOperator with training defined by creator funcs.
|
||||
|
||||
This class allows for backwards compatibility with pre-1.0 versions of Ray.
|
||||
|
||||
This class is returned by `TrainingOperator.from_creators(...)`. If you
|
||||
need to add custom functionality, you should subclass this class,
|
||||
implement the appropriate methods and pass the subclass into
|
||||
`TorchTrainer`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
MyCreatorOperator = TrainingOperator.from_creators(
|
||||
model_creator, optimizer_creator)
|
||||
trainer = TorchTrainer(training_operator_cls=MyCreatorOperator,
|
||||
...)
|
||||
"""
|
||||
|
||||
def _validate_loaders(self, loaders):
|
||||
|
||||
Reference in New Issue
Block a user