diff --git a/python/ray/util/sgd/BUILD b/python/ray/util/sgd/BUILD index f614c812e..31191ac7a 100644 --- a/python/ray/util/sgd/BUILD +++ b/python/ray/util/sgd/BUILD @@ -197,6 +197,16 @@ py_test( args = ["--num-workers=2", "--smoke-test"] ) +py_test( + name = "tune_example_3", + size = "small", + main = "torch/examples/tune_example.py", + srcs = ["torch/examples/tune_example.py"], + tags = ["exclusive", "pytorch"], + deps = [":sgd_lib"], + args = ["--num-workers=2", "--smoke-test", "--lr-reduce-on-plateau"] +) + # -------------------------------------------------------------------- # Tests from the python/ray/util/sgd/torch/examples/* directories. diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index cf57d326e..a34ac3534 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -4,6 +4,7 @@ import pytest import torch import torch.nn as nn import torch.distributed as dist +from ray.tune.utils import merge_dicts from torch.utils.data import DataLoader import ray @@ -659,6 +660,47 @@ def test_tune_train(ray_start_4_cpus, num_workers, use_local): # noqa: F811 assert mean_val_loss2 <= mean_val_loss1 +@pytest.mark.parametrize("num_workers", [2] if dist.is_available() else [1]) +@pytest.mark.parametrize("use_local", [True, False]) +def test_tune_custom_train(ray_start_4_cpus, num_workers, + use_local): # noqa: F811 + def custom_train_func(trainer, info): + train_stats = trainer.train(profile=True) + val_stats = trainer.validate(profile=True) + stats = merge_dicts(train_stats, val_stats) + return stats + + TorchTrainable = TorchTrainer.as_trainable( + **{ + "override_tune_step": custom_train_func, + "training_operator_cls": Operator, + "num_workers": num_workers, + "use_gpu": False, + "backend": "gloo", + "use_local": use_local, + "config": { + "batch_size": 512, + "lr": 0.001 + } + }) + + analysis = tune.run( + TorchTrainable, + num_samples=2, + stop={"training_iteration": 2}, + verbose=1) + + # checks loss decreasing for every trials + for path, df in analysis.trial_dataframes.items(): + mean_train_loss1 = df.loc[0, "train_loss"] + mean_train_loss2 = df.loc[1, "train_loss"] + mean_val_loss1 = df.loc[0, "val_loss"] + mean_val_loss2 = df.loc[1, "val_loss"] + + assert mean_train_loss2 <= mean_train_loss1 + assert mean_val_loss2 <= mean_val_loss1 + + @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) @pytest.mark.parametrize("use_local", [True, False]) def test_save_and_restore(ray_start_2_cpus, num_workers, use_local, diff --git a/python/ray/util/sgd/torch/examples/tune_example.py b/python/ray/util/sgd/torch/examples/tune_example.py index f5c226851..6f918a19a 100644 --- a/python/ray/util/sgd/torch/examples/tune_example.py +++ b/python/ray/util/sgd/torch/examples/tune_example.py @@ -9,6 +9,8 @@ in the documentation. import torch import torch.nn as nn +from ray.tune.utils import merge_dicts +from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader import ray @@ -36,6 +38,12 @@ def data_creator(config): return train_loader, validation_loader +def scheduler_creator(optimizer, config): + """Returns scheduler. We are using a ReduceLROnPleateau scheduler.""" + scheduler = ReduceLROnPlateau(optimizer, mode="min") + return scheduler + + # __torch_tune_example__ def tune_example(operator_cls, num_workers=1, use_gpu=False): TorchTrainable = TorchTrainer.as_trainable( @@ -56,6 +64,39 @@ def tune_example(operator_cls, num_workers=1, use_gpu=False): # __end_torch_tune_example__ +# __torch_tune_manual_lr_example__ +def tune_example_manual(operator_cls, num_workers=1, use_gpu=False): + def step(trainer, info: dict): + """Define a custom training loop for tune. + This is needed because we want to manually update our scheduler. + """ + train_stats = trainer.train(profile=True) + validation_stats = trainer.validate(profile=True) + # Manually update our scheduler with the given metric. + trainer.update_scheduler(metric=validation_stats["val_loss"]) + all_stats = merge_dicts(train_stats, validation_stats) + return all_stats + + TorchTrainable = TorchTrainer.as_trainable( + override_tune_step=step, + training_operator_cls=operator_cls, + num_workers=num_workers, + use_gpu=use_gpu, + scheduler_step_freq="manual", + config={BATCH_SIZE: 128} + ) + + analysis = tune.run( + TorchTrainable, + num_samples=3, + config={"lr": tune.grid_search([1e-4, 1e-3])}, + stop={"training_iteration": 2}, + verbose=1) + + return analysis.get_best_config(metric="val_loss", mode="min") +# __end_torch_tune_manual_lr_example__ + + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() @@ -76,6 +117,13 @@ if __name__ == "__main__": action="store_true", default=False, help="Enables GPU training") + parser.add_argument( + "--lr-reduce-on-plateau", + action="store_true", + default=False, + help="If enabled, use a ReduceLROnPlateau scheduler. If not set, " + "no scheduler is used." + ) args, _ = parser.parse_known_args() @@ -85,6 +133,12 @@ if __name__ == "__main__": ray.init(address=args.address) CustomTrainingOperator = TrainingOperator.from_creators( model_creator=model_creator, optimizer_creator=optimizer_creator, - data_creator=data_creator, loss_creator=nn.MSELoss) - tune_example(CustomTrainingOperator, num_workers=args.num_workers, - use_gpu=args.use_gpu) + data_creator=data_creator, loss_creator=nn.MSELoss, + scheduler_creator=scheduler_creator if args.lr_reduce_on_plateau + else None) + if not args.lr_reduce_on_plateau: + tune_example(CustomTrainingOperator, num_workers=args.num_workers, + use_gpu=args.use_gpu) + else: + tune_example_manual(CustomTrainingOperator, + num_workers=args.num_workers, use_gpu=args.use_gpu) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 19de678ee..e49d0c3aa 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -1,3 +1,4 @@ +import inspect import time import numpy as np @@ -571,25 +572,58 @@ class TorchTrainer: self.worker_group = DeactivatedWorkerGroup() @classmethod - def as_trainable(cls, *args, **kwargs): + def as_trainable(cls, *args, override_tune_step=None, **kwargs): """Creates a BaseTorchTrainable class compatible with Tune. Any configuration parameters will be overridden by the Tune - Trial configuration. You can also subclass the provided Trainable - to implement your own iterative optimization routine. + Trial configuration. You can also pass in a custom + ``override_tune_step`` to implement your own iterative optimization + routine and override the default implementation. .. code-block:: python + def step(trainer, info): + # Implement custom objective function here. + train_stats = trainer.train() + ... + # Return the metrics to report to tune. + # Do not call tune.report here. + return train_stats + TorchTrainable = TorchTrainer.as_trainable( training_operator_cls=MyTrainingOperator, - num_gpus=2 + num_gpus=2, + override_tune_step=step ) analysis = tune.run( TorchTrainable, config={"lr": tune.grid_search([0.01, 0.1])} ) + Args: + override_tune_step (Callable[[TorchTrainer, Dict], Dict]): A + function to override the default training step to be used + for Ray Tune. It accepts two arguments: the first one is an + instance of your TorchTrainer, and the second one is a info + dictionary, containing information about the Trainer + state. If None is passed in, the default step + function will be + used: run 1 epoch of training, 1 epoch of validation, + and report both results to Tune. Passing in + ``override_tune_step`` is useful to define + custom step functions, for example if you need to + manually update the scheduler or want to run more than 1 + training epoch for each tune iteration. + """ + if override_tune_step is not None: + callback_args = inspect.signature(override_tune_step) + if not len(callback_args.parameters) == 2: + raise ValueError("override_tune_step must take in exactly 2 " + "arguments. The passed in function " + "currently takes in {} " + "args".format( + str(len(callback_args.parameters)))) class TorchTrainable(BaseTorchTrainable): @classmethod @@ -618,6 +652,14 @@ class TorchTrainer: extra_cpu=int(remote_worker_count * num_cpus_per_worker), extra_gpu=int(int(use_gpu) * remote_worker_count)) + def step(self): + if override_tune_step is not None: + output = override_tune_step( + self._trainer, {"iteration": self.training_iteration}) + return output + else: + return super(TorchTrainable, self).step() + def _create_trainer(self, tune_config): """Overrides the provided config with Tune config.""" provided_config = kwargs.get("config", {}).copy() @@ -634,27 +676,29 @@ class BaseTorchTrainable(Trainable): This class is produced when you call ``TorchTrainer.as_trainable(...)``. - You can override the produced Trainable to implement custom iterative - training procedures: + By default one step of training runs ``trainer.train()`` once and + ``trainer.validate()`` once. You can implement custom iterative + training procedures by passing in a ``override_tune_step`` function to + ``as_trainable``: .. code-block:: python + def custom_step(trainer, info): + for i in range(5): + train_stats = trainer.train() + validation_stats = trainer.validate() + train_stats.update(validation_stats) + return train_stats + + # TorchTrainable is subclass of BaseTorchTrainable. TorchTrainable = TorchTrainer.as_trainable( training_operator_cls=MyTrainingOperator, - num_gpus=2 + num_gpus=2, + override_tune_step=custom_step ) - # TorchTrainable is subclass of BaseTorchTrainable. - - class CustomTrainable(TorchTrainable): - def step(self): - for i in range(5): - train_stats = self.trainer.train() - validation_stats = self.trainer.validate() - train_stats.update(validation_stats) - return train_stats analysis = tune.run( - CustomTrainable, + TorchTrainable, config={"lr": tune.grid_search([0.01, 0.1])} ) @@ -665,10 +709,7 @@ class BaseTorchTrainable(Trainable): self._trainer = self._create_trainer(config) def step(self): - """Calls `self.trainer.train()` and `self.trainer.validate()` once. - - You may want to override this if using a custom LR scheduler. - """ + """Calls `self.trainer.train()` and `self.trainer.validate()` once.""" if self._is_overridden("_train"): raise DeprecationWarning( "Trainable._train is deprecated and will be "