[SGD] Callback API for SGD+Tune (#11316)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-10-15 15:22:14 -07:00
committed by GitHub
parent 414041c6dd
commit 38eb61442b
4 changed files with 171 additions and 24 deletions
+10
View File
@@ -197,6 +197,16 @@ py_test(
args = ["--num-workers=2", "--smoke-test"]
)
py_test(
name = "tune_example_3",
size = "small",
main = "torch/examples/tune_example.py",
srcs = ["torch/examples/tune_example.py"],
tags = ["exclusive", "pytorch"],
deps = [":sgd_lib"],
args = ["--num-workers=2", "--smoke-test", "--lr-reduce-on-plateau"]
)
# --------------------------------------------------------------------
# Tests from the python/ray/util/sgd/torch/examples/* directories.
+42
View File
@@ -4,6 +4,7 @@ import pytest
import torch
import torch.nn as nn
import torch.distributed as dist
from ray.tune.utils import merge_dicts
from torch.utils.data import DataLoader
import ray
@@ -659,6 +660,47 @@ def test_tune_train(ray_start_4_cpus, num_workers, use_local): # noqa: F811
assert mean_val_loss2 <= mean_val_loss1
@pytest.mark.parametrize("num_workers", [2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_tune_custom_train(ray_start_4_cpus, num_workers,
use_local): # noqa: F811
def custom_train_func(trainer, info):
train_stats = trainer.train(profile=True)
val_stats = trainer.validate(profile=True)
stats = merge_dicts(train_stats, val_stats)
return stats
TorchTrainable = TorchTrainer.as_trainable(
**{
"override_tune_step": custom_train_func,
"training_operator_cls": Operator,
"num_workers": num_workers,
"use_gpu": False,
"backend": "gloo",
"use_local": use_local,
"config": {
"batch_size": 512,
"lr": 0.001
}
})
analysis = tune.run(
TorchTrainable,
num_samples=2,
stop={"training_iteration": 2},
verbose=1)
# checks loss decreasing for every trials
for path, df in analysis.trial_dataframes.items():
mean_train_loss1 = df.loc[0, "train_loss"]
mean_train_loss2 = df.loc[1, "train_loss"]
mean_val_loss1 = df.loc[0, "val_loss"]
mean_val_loss2 = df.loc[1, "val_loss"]
assert mean_train_loss2 <= mean_train_loss1
assert mean_val_loss2 <= mean_val_loss1
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
@@ -9,6 +9,8 @@ in the documentation.
import torch
import torch.nn as nn
from ray.tune.utils import merge_dicts
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import ray
@@ -36,6 +38,12 @@ def data_creator(config):
return train_loader, validation_loader
def scheduler_creator(optimizer, config):
"""Returns scheduler. We are using a ReduceLROnPleateau scheduler."""
scheduler = ReduceLROnPlateau(optimizer, mode="min")
return scheduler
# __torch_tune_example__
def tune_example(operator_cls, num_workers=1, use_gpu=False):
TorchTrainable = TorchTrainer.as_trainable(
@@ -56,6 +64,39 @@ def tune_example(operator_cls, num_workers=1, use_gpu=False):
# __end_torch_tune_example__
# __torch_tune_manual_lr_example__
def tune_example_manual(operator_cls, num_workers=1, use_gpu=False):
def step(trainer, info: dict):
"""Define a custom training loop for tune.
This is needed because we want to manually update our scheduler.
"""
train_stats = trainer.train(profile=True)
validation_stats = trainer.validate(profile=True)
# Manually update our scheduler with the given metric.
trainer.update_scheduler(metric=validation_stats["val_loss"])
all_stats = merge_dicts(train_stats, validation_stats)
return all_stats
TorchTrainable = TorchTrainer.as_trainable(
override_tune_step=step,
training_operator_cls=operator_cls,
num_workers=num_workers,
use_gpu=use_gpu,
scheduler_step_freq="manual",
config={BATCH_SIZE: 128}
)
analysis = tune.run(
TorchTrainable,
num_samples=3,
config={"lr": tune.grid_search([1e-4, 1e-3])},
stop={"training_iteration": 2},
verbose=1)
return analysis.get_best_config(metric="val_loss", mode="min")
# __end_torch_tune_manual_lr_example__
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
@@ -76,6 +117,13 @@ if __name__ == "__main__":
action="store_true",
default=False,
help="Enables GPU training")
parser.add_argument(
"--lr-reduce-on-plateau",
action="store_true",
default=False,
help="If enabled, use a ReduceLROnPlateau scheduler. If not set, "
"no scheduler is used."
)
args, _ = parser.parse_known_args()
@@ -85,6 +133,12 @@ if __name__ == "__main__":
ray.init(address=args.address)
CustomTrainingOperator = TrainingOperator.from_creators(
model_creator=model_creator, optimizer_creator=optimizer_creator,
data_creator=data_creator, loss_creator=nn.MSELoss)
tune_example(CustomTrainingOperator, num_workers=args.num_workers,
use_gpu=args.use_gpu)
data_creator=data_creator, loss_creator=nn.MSELoss,
scheduler_creator=scheduler_creator if args.lr_reduce_on_plateau
else None)
if not args.lr_reduce_on_plateau:
tune_example(CustomTrainingOperator, num_workers=args.num_workers,
use_gpu=args.use_gpu)
else:
tune_example_manual(CustomTrainingOperator,
num_workers=args.num_workers, use_gpu=args.use_gpu)
+62 -21
View File
@@ -1,3 +1,4 @@
import inspect
import time
import numpy as np
@@ -571,25 +572,58 @@ class TorchTrainer:
self.worker_group = DeactivatedWorkerGroup()
@classmethod
def as_trainable(cls, *args, **kwargs):
def as_trainable(cls, *args, override_tune_step=None, **kwargs):
"""Creates a BaseTorchTrainable class compatible with Tune.
Any configuration parameters will be overridden by the Tune
Trial configuration. You can also subclass the provided Trainable
to implement your own iterative optimization routine.
Trial configuration. You can also pass in a custom
``override_tune_step`` to implement your own iterative optimization
routine and override the default implementation.
.. code-block:: python
def step(trainer, info):
# Implement custom objective function here.
train_stats = trainer.train()
...
# Return the metrics to report to tune.
# Do not call tune.report here.
return train_stats
TorchTrainable = TorchTrainer.as_trainable(
training_operator_cls=MyTrainingOperator,
num_gpus=2
num_gpus=2,
override_tune_step=step
)
analysis = tune.run(
TorchTrainable,
config={"lr": tune.grid_search([0.01, 0.1])}
)
Args:
override_tune_step (Callable[[TorchTrainer, Dict], Dict]): A
function to override the default training step to be used
for Ray Tune. It accepts two arguments: the first one is an
instance of your TorchTrainer, and the second one is a info
dictionary, containing information about the Trainer
state. If None is passed in, the default step
function will be
used: run 1 epoch of training, 1 epoch of validation,
and report both results to Tune. Passing in
``override_tune_step`` is useful to define
custom step functions, for example if you need to
manually update the scheduler or want to run more than 1
training epoch for each tune iteration.
"""
if override_tune_step is not None:
callback_args = inspect.signature(override_tune_step)
if not len(callback_args.parameters) == 2:
raise ValueError("override_tune_step must take in exactly 2 "
"arguments. The passed in function "
"currently takes in {} "
"args".format(
str(len(callback_args.parameters))))
class TorchTrainable(BaseTorchTrainable):
@classmethod
@@ -618,6 +652,14 @@ class TorchTrainer:
extra_cpu=int(remote_worker_count * num_cpus_per_worker),
extra_gpu=int(int(use_gpu) * remote_worker_count))
def step(self):
if override_tune_step is not None:
output = override_tune_step(
self._trainer, {"iteration": self.training_iteration})
return output
else:
return super(TorchTrainable, self).step()
def _create_trainer(self, tune_config):
"""Overrides the provided config with Tune config."""
provided_config = kwargs.get("config", {}).copy()
@@ -634,27 +676,29 @@ class BaseTorchTrainable(Trainable):
This class is produced when you call ``TorchTrainer.as_trainable(...)``.
You can override the produced Trainable to implement custom iterative
training procedures:
By default one step of training runs ``trainer.train()`` once and
``trainer.validate()`` once. You can implement custom iterative
training procedures by passing in a ``override_tune_step`` function to
``as_trainable``:
.. code-block:: python
def custom_step(trainer, info):
for i in range(5):
train_stats = trainer.train()
validation_stats = trainer.validate()
train_stats.update(validation_stats)
return train_stats
# TorchTrainable is subclass of BaseTorchTrainable.
TorchTrainable = TorchTrainer.as_trainable(
training_operator_cls=MyTrainingOperator,
num_gpus=2
num_gpus=2,
override_tune_step=custom_step
)
# TorchTrainable is subclass of BaseTorchTrainable.
class CustomTrainable(TorchTrainable):
def step(self):
for i in range(5):
train_stats = self.trainer.train()
validation_stats = self.trainer.validate()
train_stats.update(validation_stats)
return train_stats
analysis = tune.run(
CustomTrainable,
TorchTrainable,
config={"lr": tune.grid_search([0.01, 0.1])}
)
@@ -665,10 +709,7 @@ class BaseTorchTrainable(Trainable):
self._trainer = self._create_trainer(config)
def step(self):
"""Calls `self.trainer.train()` and `self.trainer.validate()` once.
You may want to override this if using a custom LR scheduler.
"""
"""Calls `self.trainer.train()` and `self.trainer.validate()` once."""
if self._is_overridden("_train"):
raise DeprecationWarning(
"Trainable._train is deprecated and will be "