mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 05:05:21 +08:00
[SGD] Callback API for SGD+Tune (#11316)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -197,6 +197,16 @@ py_test(
|
||||
args = ["--num-workers=2", "--smoke-test"]
|
||||
)
|
||||
|
||||
# Smoke test for torch/examples/tune_example.py run with two workers and the
# ``--lr-reduce-on-plateau`` flag enabled.
py_test(
    name = "tune_example_3",
    size = "small",
    srcs = ["torch/examples/tune_example.py"],
    args = [
        "--num-workers=2",
        "--smoke-test",
        "--lr-reduce-on-plateau",
    ],
    main = "torch/examples/tune_example.py",
    tags = [
        "exclusive",
        "pytorch",
    ],
    deps = [":sgd_lib"],
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Tests from the python/ray/util/sgd/torch/examples/* directories.
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from ray.tune.utils import merge_dicts
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
@@ -659,6 +660,47 @@ def test_tune_train(ray_start_4_cpus, num_workers, use_local): # noqa: F811
|
||||
assert mean_val_loss2 <= mean_val_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_tune_custom_train(ray_start_4_cpus, num_workers,
                           use_local):  # noqa: F811
    """Run Tune with a user-supplied ``override_tune_step`` callback.

    The callback performs one training pass and one validation pass and
    reports the merged statistics. For every trial, the train and
    validation losses of the second iteration must not exceed those of
    the first.
    """

    def custom_train_func(trainer, info):
        # Train first, then validate; both stat dicts are reported together.
        training = trainer.train(profile=True)
        validation = trainer.validate(profile=True)
        return merge_dicts(training, validation)

    TorchTrainable = TorchTrainer.as_trainable(
        override_tune_step=custom_train_func,
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_gpu=False,
        backend="gloo",
        use_local=use_local,
        config={
            "batch_size": 512,
            "lr": 0.001,
        },
    )

    analysis = tune.run(
        TorchTrainable,
        num_samples=2,
        stop={"training_iteration": 2},
        verbose=1,
    )

    # Check that the loss does not increase between iterations in any trial.
    for trial_df in analysis.trial_dataframes.values():
        train_first = trial_df.loc[0, "train_loss"]
        train_second = trial_df.loc[1, "train_loss"]
        val_first = trial_df.loc[0, "val_loss"]
        val_second = trial_df.loc[1, "val_loss"]

        assert train_second <= train_first
        assert val_second <= val_first
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
|
||||
|
||||
@@ -9,6 +9,8 @@ in the documentation.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ray.tune.utils import merge_dicts
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
@@ -36,6 +38,12 @@ def data_creator(config):
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def scheduler_creator(optimizer, config):
    """Create the learning-rate scheduler for ``optimizer``.

    A ``ReduceLROnPlateau`` scheduler in ``"min"`` mode is returned.
    ``config`` is accepted as part of the creator signature but is not
    read here.
    """
    return ReduceLROnPlateau(optimizer, mode="min")
|
||||
|
||||
|
||||
# __torch_tune_example__
|
||||
def tune_example(operator_cls, num_workers=1, use_gpu=False):
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
@@ -56,6 +64,39 @@ def tune_example(operator_cls, num_workers=1, use_gpu=False):
|
||||
# __end_torch_tune_example__
|
||||
|
||||
|
||||
# __torch_tune_manual_lr_example__
|
||||
def tune_example_manual(operator_cls, num_workers=1, use_gpu=False):
    """Tune example in which the LR scheduler is stepped by hand.

    A custom step function is passed to ``TorchTrainer.as_trainable`` via
    ``override_tune_step`` and the trainer is created with
    ``scheduler_step_freq="manual"``, so the scheduler is only updated
    where the step function does so explicitly.

    Returns:
        The best config found by Tune, selected by minimal ``val_loss``.
    """

    def manual_lr_step(trainer, info: dict):
        """Custom per-iteration training loop for Tune.

        Needed here because the scheduler is updated manually with the
        validation loss.
        """
        training = trainer.train(profile=True)
        validation = trainer.validate(profile=True)
        # Feed the validation loss to the scheduler ourselves.
        trainer.update_scheduler(metric=validation["val_loss"])
        return merge_dicts(training, validation)

    TorchTrainable = TorchTrainer.as_trainable(
        override_tune_step=manual_lr_step,
        training_operator_cls=operator_cls,
        num_workers=num_workers,
        use_gpu=use_gpu,
        scheduler_step_freq="manual",
        config={BATCH_SIZE: 128},
    )

    search_space = {"lr": tune.grid_search([1e-4, 1e-3])}
    stop_criteria = {"training_iteration": 2}
    analysis = tune.run(
        TorchTrainable,
        num_samples=3,
        config=search_space,
        stop=stop_criteria,
        verbose=1,
    )

    return analysis.get_best_config(metric="val_loss", mode="min")
|
||||
# __end_torch_tune_manual_lr_example__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -76,6 +117,13 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
parser.add_argument(
|
||||
"--lr-reduce-on-plateau",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If enabled, use a ReduceLROnPlateau scheduler. If not set, "
|
||||
"no scheduler is used."
|
||||
)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
@@ -85,6 +133,12 @@ if __name__ == "__main__":
|
||||
ray.init(address=args.address)
|
||||
CustomTrainingOperator = TrainingOperator.from_creators(
|
||||
model_creator=model_creator, optimizer_creator=optimizer_creator,
|
||||
data_creator=data_creator, loss_creator=nn.MSELoss)
|
||||
tune_example(CustomTrainingOperator, num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu)
|
||||
data_creator=data_creator, loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator if args.lr_reduce_on_plateau
|
||||
else None)
|
||||
if not args.lr_reduce_on_plateau:
|
||||
tune_example(CustomTrainingOperator, num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu)
|
||||
else:
|
||||
tune_example_manual(CustomTrainingOperator,
|
||||
num_workers=args.num_workers, use_gpu=args.use_gpu)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
@@ -571,25 +572,58 @@ class TorchTrainer:
|
||||
self.worker_group = DeactivatedWorkerGroup()
|
||||
|
||||
@classmethod
|
||||
def as_trainable(cls, *args, **kwargs):
|
||||
def as_trainable(cls, *args, override_tune_step=None, **kwargs):
|
||||
"""Creates a BaseTorchTrainable class compatible with Tune.
|
||||
|
||||
Any configuration parameters will be overridden by the Tune
|
||||
Trial configuration. You can also subclass the provided Trainable
|
||||
to implement your own iterative optimization routine.
|
||||
Trial configuration. You can also pass in a custom
|
||||
``override_tune_step`` to implement your own iterative optimization
|
||||
routine and override the default implementation.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def step(trainer, info):
|
||||
# Implement custom objective function here.
|
||||
train_stats = trainer.train()
|
||||
...
|
||||
# Return the metrics to report to tune.
|
||||
# Do not call tune.report here.
|
||||
return train_stats
|
||||
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
num_gpus=2
|
||||
num_gpus=2,
|
||||
override_tune_step=step
|
||||
)
|
||||
analysis = tune.run(
|
||||
TorchTrainable,
|
||||
config={"lr": tune.grid_search([0.01, 0.1])}
|
||||
)
|
||||
|
||||
Args:
|
||||
override_tune_step (Callable[[TorchTrainer, Dict], Dict]): A
|
||||
function to override the default training step to be used
|
||||
for Ray Tune. It accepts two arguments: the first one is an
|
||||
instance of your TorchTrainer, and the second one is an info
|
||||
dictionary, containing information about the Trainer
|
||||
state. If None is passed in, the default step
|
||||
function will be
|
||||
used: run 1 epoch of training, 1 epoch of validation,
|
||||
and report both results to Tune. Passing in
|
||||
``override_tune_step`` is useful to define
|
||||
custom step functions, for example if you need to
|
||||
manually update the scheduler or want to run more than 1
|
||||
training epoch for each tune iteration.
|
||||
|
||||
"""
|
||||
if override_tune_step is not None:
|
||||
callback_args = inspect.signature(override_tune_step)
|
||||
if not len(callback_args.parameters) == 2:
|
||||
raise ValueError("override_tune_step must take in exactly 2 "
|
||||
"arguments. The passed in function "
|
||||
"currently takes in {} "
|
||||
"args".format(
|
||||
str(len(callback_args.parameters))))
|
||||
|
||||
class TorchTrainable(BaseTorchTrainable):
|
||||
@classmethod
|
||||
@@ -618,6 +652,14 @@ class TorchTrainer:
|
||||
extra_cpu=int(remote_worker_count * num_cpus_per_worker),
|
||||
extra_gpu=int(int(use_gpu) * remote_worker_count))
|
||||
|
||||
def step(self):
    """Run one Tune iteration.

    When ``override_tune_step`` was supplied to ``as_trainable``, it is
    called with the trainer and an info dict holding the current
    training iteration, and its return value is reported. Otherwise the
    parent class's ``step`` is used.
    """
    if override_tune_step is None:
        return super(TorchTrainable, self).step()

    trainer = self._trainer
    info = {"iteration": self.training_iteration}
    return override_tune_step(trainer, info)
|
||||
|
||||
def _create_trainer(self, tune_config):
|
||||
"""Overrides the provided config with Tune config."""
|
||||
provided_config = kwargs.get("config", {}).copy()
|
||||
@@ -634,27 +676,29 @@ class BaseTorchTrainable(Trainable):
|
||||
|
||||
This class is produced when you call ``TorchTrainer.as_trainable(...)``.
|
||||
|
||||
You can override the produced Trainable to implement custom iterative
|
||||
training procedures:
|
||||
By default one step of training runs ``trainer.train()`` once and
|
||||
``trainer.validate()`` once. You can implement custom iterative
|
||||
training procedures by passing in an ``override_tune_step`` function to
|
||||
``as_trainable``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def custom_step(trainer, info):
|
||||
for i in range(5):
|
||||
train_stats = trainer.train()
|
||||
validation_stats = trainer.validate()
|
||||
train_stats.update(validation_stats)
|
||||
return train_stats
|
||||
|
||||
# TorchTrainable is subclass of BaseTorchTrainable.
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
num_gpus=2
|
||||
num_gpus=2,
|
||||
override_tune_step=custom_step
|
||||
)
|
||||
# TorchTrainable is subclass of BaseTorchTrainable.
|
||||
|
||||
class CustomTrainable(TorchTrainable):
|
||||
def step(self):
|
||||
for i in range(5):
|
||||
train_stats = self.trainer.train()
|
||||
validation_stats = self.trainer.validate()
|
||||
train_stats.update(validation_stats)
|
||||
return train_stats
|
||||
|
||||
analysis = tune.run(
|
||||
CustomTrainable,
|
||||
TorchTrainable,
|
||||
config={"lr": tune.grid_search([0.01, 0.1])}
|
||||
)
|
||||
|
||||
@@ -665,10 +709,7 @@ class BaseTorchTrainable(Trainable):
|
||||
self._trainer = self._create_trainer(config)
|
||||
|
||||
def step(self):
|
||||
"""Calls `self.trainer.train()` and `self.trainer.validate()` once.
|
||||
|
||||
You may want to override this if using a custom LR scheduler.
|
||||
"""
|
||||
"""Calls `self.trainer.train()` and `self.trainer.validate()` once."""
|
||||
if self._is_overridden("_train"):
|
||||
raise DeprecationWarning(
|
||||
"Trainable._train is deprecated and will be "
|
||||
|
||||
Reference in New Issue
Block a user