From 0ef8224446ea3b1b65aa915c0e58905f150cc435 Mon Sep 17 00:00:00 2001 From: krfricke Date: Fri, 7 Aug 2020 21:41:49 +0200 Subject: [PATCH] [tune] PBT replay utility scheduler (#9953) Co-authored-by: Kai Fricke --- .../_tutorials/tune-advanced-tutorial.rst | 58 ++++-- doc/source/tune/api_docs/schedulers.rst | 24 +++ .../ray/tune/examples/pbt_convnet_example.py | 18 +- python/ray/tune/schedulers/__init__.py | 6 +- python/ray/tune/schedulers/pbt.py | 166 ++++++++++++++++++ python/ray/tune/tests/test_trial_scheduler.py | 144 ++++++++++++++- 6 files changed, 390 insertions(+), 26 deletions(-) diff --git a/doc/source/tune/_tutorials/tune-advanced-tutorial.rst b/doc/source/tune/_tutorials/tune-advanced-tutorial.rst index e1af9ce3e..ad62f59d1 100644 --- a/doc/source/tune/_tutorials/tune-advanced-tutorial.rst +++ b/doc/source/tune/_tutorials/tune-advanced-tutorial.rst @@ -77,19 +77,19 @@ During the training, we can constantly check the status of the models from conso .. code-block:: bash == Status == - Memory usage on this node: 10.4/16.0 GiB - PopulationBasedTraining: 4 checkpoints, 1 perturbs - Resources requested: 4/12 CPUs, 0/0 GPUs, 0.0/3.42 GiB heap, 0.0/1.17 GiB objects - Number of trials: 4 ({'RUNNING': 4}) - Result logdir: /Users/yuhao.yang/ray_results/pbt_test - +--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+ - | Trial name | status | loc | lr | momentum | iter | total time (s) | acc | - |--------------------------+----------+---------------------+----------+------------+--------+------------------+----------| - | PytorchTrainble_3b42d914 | RUNNING | 30.57.180.224:49840 | 0.122032 | 0.302176 | 18 | 3.8689 | 0.8875 | - | PytorchTrainble_3b45091e | RUNNING | 30.57.180.224:49835 | 0.505325 | 0.628559 | 18 | 3.90404 | 0.134375 | - | PytorchTrainble_3b454c46 | RUNNING | 30.57.180.224:49843 | 0.490228 | 0.969013 | 17 | 3.72111 | 0.0875 | - | PytorchTrainble_3b458a9c | RUNNING | 
30.57.180.224:49833 | 0.961861 | 0.169701 | 13 | 2.72594 | 0.1125 | - +--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+ + Memory usage on this node: 11.6/16.0 GiB + PopulationBasedTraining: 5 checkpoints, 4 perturbs + Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/3.96 GiB heap, 0.0/1.37 GiB objects + Result logdir: /Users/foo/ray_results/pbt_test + Number of trials: 4 (4 TERMINATED) + +------------------------------+------------+-------+-----------+------------+----------+--------+------------------+ + | Trial name | status | loc | lr | momentum | acc | iter | total time (s) | + |------------------------------+------------+-------+-----------+------------+----------+--------+------------------| + | PytorchTrainable_ba982_00000 | TERMINATED | | 0.0457501 | 0.99 | 0.6375 | 25 | 5.35712 | + | PytorchTrainable_ba982_00001 | TERMINATED | | 0.175808 | 0.0667043 | 0.909375 | 29 | 6.18802 | + | PytorchTrainable_ba982_00002 | TERMINATED | | 0.21097 | 0.99 | 0.040625 | 29 | 6.19634 | + | PytorchTrainable_ba982_00003 | TERMINATED | | 0.0571876 | 0.852088 | 0.96875 | 30 | 6.37298 | + +------------------------------+------------+-------+-----------+------------+----------+--------+------------------+ In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs: @@ -114,6 +114,38 @@ Checking the accuracy: .. image:: /images/tune_advanced_plot1.png +.. _tune-advanced-tutorial-pbt-replay: + +Replaying a PBT run +------------------- +A run of Population Based Training ends with fully trained models. However, sometimes +you might like to train the model from scratch, but use the same hyperparameter +schedule as obtained from PBT. Ray Tune offers a replay utility for this. + +All you need to do is pass the policy log file for the trial you want to replay. 
+This is usually stored in the experiment directory, for instance +``~/ray_results/pbt_test/pbt_policy_ba982_00000.txt``. + +The replay utility reads the original configuration for the trial and updates it +at every step where it was originally perturbed. You can (and should) +thus just use the same ``Trainable`` for the replay run. + +.. code-block:: python + + from ray import tune + + from ray.tune.examples.pbt_convnet_example import PytorchTrainable + from ray.tune.schedulers import PopulationBasedTrainingReplay + + replay = PopulationBasedTrainingReplay( + "~/ray_results/pbt_test/pbt_policy_ba982_00003.txt") + + tune.run( + PytorchTrainable, + scheduler=replay, + stop={"training_iteration": 100}) + + DCGAN with Trainable and PBT ---------------------------- diff --git a/doc/source/tune/api_docs/schedulers.rst b/doc/source/tune/api_docs/schedulers.rst index 6c623d869..811da9824 100644 --- a/doc/source/tune/api_docs/schedulers.rst +++ b/doc/source/tune/api_docs/schedulers.rst @@ -148,6 +148,30 @@ You can run this :doc:`toy PBT example ` to get an .. autoclass:: ray.tune.schedulers.PopulationBasedTraining +.. _tune-scheduler-pbt-replay: + +Population Based Training Replay (tune.schedulers.PopulationBasedTrainingReplay) +-------------------------------------------------------------------------------- + +Tune includes a utility to replay hyperparameter schedules of Population Based Training runs. +You just pass the PBT policy log file of the trial from a finished run that you would +like to replay. The scheduler accepts only one trial, and it will update its +config according to the obtained schedule. + +.. code-block:: python + + replay = PopulationBasedTrainingReplay( + policy_file="~/ray_results/pbt_experiment/" + "pbt_policy_XXXXX_00001.txt") + tune.run( + ..., + scheduler=replay) + +See :ref:`here for an example <tune-advanced-tutorial-pbt-replay>` on how to use the +replay utility in practice. + +.. autoclass:: ray.tune.schedulers.PopulationBasedTrainingReplay + .. 
_tune-scheduler-bohb: BOHB (tune.schedulers.HyperBandForBOHB) diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index 8aa6b1d9d..656899966 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -1,5 +1,8 @@ #!/usr/bin/env python +# flake8: noqa +# yapf: disable + # __tutorial_imports_begin__ import argparse import os @@ -15,12 +18,11 @@ from ray import tune from ray.tune.schedulers import PopulationBasedTraining from ray.tune.utils import validate_save_restore from ray.tune.trial import ExportFormat - # __tutorial_imports_end__ # __trainable_begin__ -class PytorchTrainble(tune.Trainable): +class PytorchTrainable(tune.Trainable): """Train a Pytorch ConvNet with Trainable and PopulationBasedTraining scheduler. The example reuse some of the functions in mnist_pytorch, and is a good demo for how to add the tuning function without @@ -65,10 +67,9 @@ class PytorchTrainble(tune.Trainable): self.config = new_config return True - - # __trainable_end__ + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -79,8 +80,8 @@ if __name__ == "__main__": datasets.MNIST("~/data", train=True, download=True) # check if PytorchTrainble will save/restore correctly before execution - validate_save_restore(PytorchTrainble) - validate_save_restore(PytorchTrainble, use_object_store=True) + validate_save_restore(PytorchTrainable) + validate_save_restore(PytorchTrainable, use_object_store=True) # __pbt_begin__ scheduler = PopulationBasedTraining( @@ -94,7 +95,6 @@ if __name__ == "__main__": # allow perturbations within this set of categorical values "momentum": [0.8, 0.9, 0.99], }) - # __pbt_end__ # __tune_begin__ @@ -114,7 +114,7 @@ if __name__ == "__main__": stopper = CustomStopper() analysis = tune.run( - PytorchTrainble, + PytorchTrainable, name="pbt_test", scheduler=scheduler, reuse_actors=True, @@ -134,7 +134,7 @@ if __name__ == 
"__main__": best_trial = analysis.get_best_trial("mean_accuracy") best_checkpoint = max( analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy")) - restored_trainable = PytorchTrainble() + restored_trainable = PytorchTrainable() restored_trainable.restore(best_checkpoint[0]) best_model = restored_trainable.model # Note that test only runs on a small random set of the test data, thus the diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index fdbe016c3..70554789d 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -4,10 +4,12 @@ from ray.tune.schedulers.hb_bohb import HyperBandForBOHB from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler, ASHAScheduler) from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule -from ray.tune.schedulers.pbt import PopulationBasedTraining +from ray.tune.schedulers.pbt import (PopulationBasedTraining, + PopulationBasedTrainingReplay) __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", - "PopulationBasedTraining", "HyperBandForBOHB" + "PopulationBasedTraining", "PopulationBasedTrainingReplay", + "HyperBandForBOHB" ] diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 495055c12..a8632de3d 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -393,3 +393,169 @@ class PopulationBasedTraining(FIFOScheduler): def debug_string(self): return "PopulationBasedTraining: {} checkpoints, {} perturbs".format( self._num_checkpoints, self._num_perturbations) + + +class PopulationBasedTrainingReplay(FIFOScheduler): + """Replays a Population Based Training run. + + Population Based Training does not return a single hyperparameter + configuration, but rather a schedule of configurations. 
For instance, + PBT might discover that a larger learning rate leads to good results + in the first training iterations, but that a smaller learning rate + is preferable later. + + This scheduler enables replaying these parameter schedules from + a finished PBT run. This requires that population based training has + been run with ``log_config=True``, which is the default setting. + + The scheduler will only accept and train a single trial. It will + start with the initial config of the existing trial and update the + config according to the schedule. + + Args: + policy_file (str): The PBT policy file. Usually this is + stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt`` + where ``xxx`` is the trial ID. + + Example: + + .. code-block:: python + + # Replaying a result from ray.tune.examples.pbt_convnet_example + from ray import tune + + from ray.tune.examples.pbt_convnet_example import PytorchTrainable + from ray.tune.schedulers import PopulationBasedTrainingReplay + + replay = PopulationBasedTrainingReplay( + "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt") + + tune.run( + PytorchTrainable, + scheduler=replay, + stop={"training_iteration": 100}) + + + """ + + def __init__(self, policy_file): + policy_file = os.path.expanduser(policy_file) + if not os.path.exists(policy_file): + raise ValueError("Policy file not found: {}".format(policy_file)) + + self.policy_file = policy_file + + # Find and read pbt policy file, potentially raise error + initial_config, self._policy = self._load_policy(self.policy_file) + + self.experiment_tag = "replay_{}".format( + os.path.basename(self.policy_file)) + self.config = initial_config + self.current_config = self.config + + self._trial = None + self._current_step = 0 + self._num_perturbations = 0 + + self._policy_iter = iter(self._policy) + self._next_policy = next(self._policy_iter, None) + + def _load_policy(self, policy_file): + raw_policy = [] + with open(policy_file, "rt") as fp: + for row in fp.readlines(): + 
try: + parsed_row = json.loads(row) + except json.JSONDecodeError: + raise ValueError( + "Could not read PBT policy file: {}.".format( + policy_file)) from None + raw_policy.append(tuple(parsed_row)) + + # Loop through policy from end to start to obtain changepoints + policy = [] + last_new_tag = None + last_old_conf = None + for (old_tag, new_tag, old_step, new_step, old_conf, + new_conf) in reversed(raw_policy): + if last_new_tag and old_tag != last_new_tag: + # Tag chain ended. This means that previous changes were + # overwritten by the last change and should be ignored. + break + last_new_tag = new_tag + last_old_conf = old_conf + + policy.append((new_step, new_conf)) + + return last_old_conf, list(reversed(policy)) + + def on_trial_add(self, trial_runner, trial): + if self._trial: + raise ValueError( + "More than one trial added to PBT replay run. This " + "means the same schedule will be trained multiple " + "times. Do you want to set `n_samples=1`?") + self._trial = trial + if self._trial.config and self._policy: + logger.warning( + "Trial was initialized with a config, which was overwritten. " + "Did you start the PBT replay with a `config` parameter?") + elif self._trial.config and not self._policy: + # Only train with initial policy + self.config = self._trial.config + elif not self._trial.config and not self._policy: + raise ValueError( + "No replay policy found and trial initialized without a " + "valid config. 
Either pass a `config` argument to `tune.run()`" + "or consider not using PBT replay for this run.") + self._trial.config = self.config + + def on_trial_result(self, trial_runner, trial, result): + if TRAINING_ITERATION not in result: + # No time reported + return TrialScheduler.CONTINUE + + if not self._next_policy: + # No more changes in the config + return TrialScheduler.CONTINUE + + step = result[TRAINING_ITERATION] + self._current_step = step + + change_at, new_config = self._next_policy + + if step < change_at: + # Don't change the policy just yet + return TrialScheduler.CONTINUE + + logger.info("Population Based Training replay is now at step {}. " + "Configuration will be changed to {}.".format( + step, new_config)) + + checkpoint = trial_runner.trial_executor.save( + trial, Checkpoint.MEMORY, result=result) + + new_tag = make_experiment_tag(self.experiment_tag, new_config, + new_config) + + trial_executor = trial_runner.trial_executor + reset_successful = trial_executor.reset_trial(trial, new_config, + new_tag) + + if reset_successful: + trial_executor.restore(trial, checkpoint, block=True) + else: + trial_executor.stop_trial(trial, stop_logger=False) + trial.config = new_config + trial.experiment_tag = new_tag + trial_executor.start_trial(trial, checkpoint, train=False) + + self.current_config = new_config + self._num_perturbations += 1 + self._next_policy = next(self._policy_iter, None) + + return TrialScheduler.CONTINUE + + def debug_string(self): + return "PopulationBasedTraining replay: Step {}, perturb {}".format( + self._current_step, self._num_perturbations) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index cf44ed736..d21f24d04 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -10,12 +10,13 @@ from unittest.mock import MagicMock import ray from ray import tune +from ray.tune import Trainable from ray.tune.result import 
TRAINING_ITERATION from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, PopulationBasedTraining, MedianStoppingRule, TrialScheduler, HyperBandForBOHB) -from ray.tune.schedulers.pbt import explore +from ray.tune.schedulers.pbt import explore, PopulationBasedTrainingReplay from ray.tune.trial import Trial, Checkpoint from ray.tune.trial_executor import TrialExecutor from ray.tune.resources import Resources @@ -710,6 +711,7 @@ class PopulationBasedTestingSuite(unittest.TestCase): _register_all() # re-register the evicted objects def basicSetup(self, + num_trials=5, resample_prob=0.0, explore=None, perturbation_interval=10, @@ -731,7 +733,7 @@ class PopulationBasedTestingSuite(unittest.TestCase): custom_explore_fn=explore, log_config=log_config) runner = _MockTrialRunner(pbt) - for i in range(5): + for i in range(num_trials): trial_hyperparams = hyperparams or { "float_factor": 2.0, "const_factor": 3, @@ -1072,6 +1074,144 @@ class PopulationBasedTestingSuite(unittest.TestCase): check_policy(json.loads(line)) shutil.rmtree(tmpdir) + def testReplay(self): + pbt, runner = self.basicSetup( + num_trials=4, + perturbation_interval=5, + log_config=True, + step_once=False) + trials = runner.get_trials() + tmpdir = tempfile.mkdtemp() + + # Internal trial state to collect the real PBT history + class _TrialState: + def __init__(self, config): + self.step = 0 + self.config = config + self.history = [] + + def forward(self, t): + while self.step < t: + self.history.append(self.config) + self.step += 1 + + trial_state = [] + for i, trial in enumerate(trials): + trial.local_dir = tmpdir + trial.last_result = {TRAINING_ITERATION: 0} + trial_state.append(_TrialState(trial.config)) + + # Helper function to simulate stepping trial k a number of steps, + # and reporting a score at the end + def trial_step(k, steps, score): + res = result(trial_state[k].step + steps, score) + + trials[k].last_result = res + trial_state[k].forward(res[TRAINING_ITERATION]) + + 
old_config = trials[k].config + pbt.on_trial_result(runner, trials[k], res) + new_config = trials[k].config + trial_state[k].config = new_config.copy() + + if old_config != new_config: + # Copy history from source trial + source = -1 + for m, cand in enumerate(trials): + if cand.trainable_name == trials[k].restored_checkpoint: + source = m + break + assert source >= 0 + trial_state[k].history = trial_state[source].history.copy() + trial_state[k].step = trial_state[source].step + + # Initial steps + trial_step(0, 10, 0) + trial_step(1, 11, 10) + trial_step(2, 12, 0) + trial_step(3, 13, 0) + + # Next block + trial_step(0, 10, -10) # 0 <-- 1, new_t=11 + trial_step(2, 8, -20) # 2 <-- 1, new_t=11 + trial_step(3, 9, 0) + trial_step(1, 7, 0) + + # Next block + trial_step(1, 12, 0) + trial_step(2, 13, 0) + trial_step(3, 14, 10) + trial_step(0, 11, 0) # 0 <-- 3, new_t=13+9+14=36 + + # Next block + trial_step(0, 6, 20) + trial_step(3, 9, -40) # 3 <-- 0, new_t=42 + trial_step(2, 8, -50) # 2 <-- 0, new_t=42 + trial_step(1, 7, 30) + trial_step(2, 8, -60) # 2 <-- 1, new_t=37 + + # Next block + trial_step(0, 10, 0) + trial_step(1, 10, 0) + trial_step(2, 10, 0) + trial_step(3, 10, 0) + + # Playback trainable to collect configs at each step + class Playback(Trainable): + def setup(self, config): + self.config = config + self.replayed = [] + self.iter = 0 + + def step(self): + self.iter += 1 + self.replayed.append(self.config) + return { + "reward": 0, + "done": False, + "replayed": self.replayed, + TRAINING_ITERATION: self.iter + } + + def reset_config(self, new_config): + self.config = new_config + return True + + def save_checkpoint(self, tmp_checkpoint_dir): + return tmp_checkpoint_dir + + def load_checkpoint(self, checkpoint): + pass + + # Loop through all trials and check if PBT history is the + # same as the playback history + for i, trial in enumerate(trials): + if trial.trial_id == "1": # Did not exploit anything + continue + + replay = PopulationBasedTrainingReplay( + 
os.path.join(tmpdir, + "pbt_policy_{}.txt".format(trial.trial_id))) + analysis = tune.run( + Playback, + scheduler=replay, + stop={TRAINING_ITERATION: trial_state[i].step}) + + replayed = analysis.trials[0].last_result["replayed"] + self.assertSequenceEqual(trial_state[i].history, replayed) + + # Trial 1 did not exploit anything and should raise an error + with self.assertRaises(ValueError): + replay = PopulationBasedTrainingReplay( + os.path.join(tmpdir, + "pbt_policy_{}.txt".format(trials[1].trial_id))) + tune.run( + Playback, + scheduler=replay, + stop={TRAINING_ITERATION: trial_state[1].step}) + + shutil.rmtree(tmpdir) + def testPostprocessingHook(self): def explore(new_config): new_config["id_factor"] = 42