diff --git a/doc/source/pbt.png b/doc/source/pbt.png index e05c1d89b..7ff4e03a4 100644 Binary files a/doc/source/pbt.png and b/doc/source/pbt.png differ diff --git a/doc/source/tune-schedulers.rst b/doc/source/tune-schedulers.rst index 4c6919fb3..cbb105ff5 100644 --- a/doc/source/tune-schedulers.rst +++ b/doc/source/tune-schedulers.rst @@ -36,7 +36,7 @@ Tune includes a distributed implementation of `Population Based Training (PBT) < When the PBT scheduler is enabled, each trial variant is treated as a member of the population. Periodically, top-performing trials are checkpointed (this requires your Trainable to support `checkpointing `__). Low-performing trials clone the checkpoints of top performers and perturb the configurations in the hope of discovering an even better variation. -You can run this `toy PBT example `__ to get an idea of how how PBT operates. When training in PBT mode, a single trial may see many different hyperparameters over its lifetime, which is recorded in its ``result.json`` file. The following figure generated by the example shows PBT discovering new hyperparams over the course of a single experiment: +You can run this `toy PBT example `__ to get an idea of how how PBT operates. When training in PBT mode, a single trial may see many different hyperparameters over its lifetime, which is recorded in its ``result.json`` file. The following figure generated by the example shows PBT with optimizing a LR schedule over the course of a single experiment: .. image:: pbt.png diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index 8ff0bea0c..7c8dda483 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -4,57 +4,83 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import argparse -import json -import os import random -import time import ray from ray.tune import Trainable, run from ray.tune.schedulers import PopulationBasedTraining -class MyTrainableClass(Trainable): - """Fake agent whose learning rate is determined by dummy factors.""" +class PBTBenchmarkExample(Trainable): + """Toy PBT problem for benchmarking adaptive learning rate. + + The goal is to optimize this trainable's accuracy. The accuracy increases + fastest at the optimal lr, which is a function of the current accuracy. + + The optimal lr schedule for this problem is the triangle wave as follows. + Note that many lr schedules for real models also follow this shape: + + best lr + ^ + | /\ + | / \ + | / \ + | / \ + ------------> accuracy + + In this problem, using PBT with a population of 2-4 is sufficient to + roughly approximate this lr schedule. Higher population sizes will yield + faster convergence. Training will not converge without PBT. + """ def _setup(self, config): - self.timestep = 0 - self.current_value = 0.0 + self.lr = config["lr"] + self.accuracy = 0.0 # end = 1000 def _train(self): - time.sleep(0.1) + midpoint = 100 # lr starts decreasing after acc > midpoint + q_tolerance = 3 # penalize exceeding lr by more than this multiple + noise_level = 2 # add gaussian noise to the acc increase + # triangle wave: + # - start at 0.001 @ t=0, + # - peak at 0.01 @ t=midpoint, + # - end at 0.001 @ t=midpoint * 2, + if self.accuracy < midpoint: + optimal_lr = 0.01 * self.accuracy / midpoint + else: + optimal_lr = 0.01 - 0.01 * (self.accuracy - midpoint) / midpoint + optimal_lr = min(0.01, max(0.001, optimal_lr)) - # Reward increase is parabolic as a function of factor_2, with a - # maxima around factor_1=10.0. - self.current_value += max( - 0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0)) + # compute accuracy increase + q_err = max(self.lr, optimal_lr) / min(self.lr, optimal_lr) + if q_err < q_tolerance: + self.accuracy += (1.0 / q_err) * random.random() + elif self.lr > optimal_lr: + self.accuracy -= (q_err - q_tolerance) * random.random() + self.accuracy += noise_level * np.random.normal() + self.accuracy = max(0, self.accuracy) - # Flat increase by factor_2 - self.current_value += random.gauss(self.config["factor_2"], 1.0) - - # Here we use `episode_reward_mean`, but you can also report other - # objectives such as loss or accuracy. - return {"episode_reward_mean": self.current_value} + return { + "mean_accuracy": self.accuracy, + "cur_lr": self.lr, + "optimal_lr": optimal_lr, # for debugging + "q_err": q_err, # for debugging + "done": self.accuracy > midpoint * 2, + } def _save(self, checkpoint_dir): - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write( - json.dumps({ - "timestep": self.timestep, - "value": self.current_value - })) - return path + return { + "accuracy": self.accuracy, + "lr": self.lr, + } - def _restore(self, checkpoint_path): - with open(checkpoint_path) as f: - data = json.loads(f.read()) - self.timestep = data["timestep"] - self.current_value = data["value"] + def _restore(self, checkpoint): + self.accuracy = checkpoint["accuracy"] def reset_config(self, new_config): - self.config = new_config + self.lr = new_config["lr"] return True @@ -64,35 +90,36 @@ if __name__ == "__main__": "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() if args.smoke_test: - ray.init(num_cpus=4) # force pausing to happen for test + ray.init(num_cpus=2) # force pausing to happen for test else: ray.init() pbt = PopulationBasedTraining( time_attr="training_iteration", - reward_attr="episode_reward_mean", - perturbation_interval=10, + reward_attr="mean_accuracy", + perturbation_interval=20, hyperparam_mutations={ - # Allow for scaling-based perturbations, with a uniform backing - # distribution for resampling. - "factor_1": lambda: random.uniform(0.0, 20.0), - # Allow perturbations within this set of categorical values. - "factor_2": [1, 2], + # distribution for resampling + "lr": lambda: random.uniform(0.0001, 0.02), + # allow perturbations within this set of categorical values + "some_other_factor": [1, 2], }) - # Try to find the best factor 1 and factor 2 - run(MyTrainableClass, + run( + PBTBenchmarkExample, name="pbt_test", scheduler=pbt, reuse_actors=True, verbose=False, **{ "stop": { - "training_iteration": 20 if args.smoke_test else 99999 + "training_iteration": 2000, }, - "num_samples": 10, + "num_samples": 4, "config": { - "factor_1": 4.0, - "factor_2": 1.0, + "lr": 0.0001, + # note: this parameter is perturbed but has no effect on + # the model training in this example + "some_other_factor": 1, }, })