[tune] PBT replay utility scheduler (#9953)

Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
krfricke
2020-08-07 21:41:49 +02:00
committed by GitHub
parent 326a470bc2
commit 0ef8224446
6 changed files with 390 additions and 26 deletions
@@ -77,19 +77,19 @@ During the training, we can constantly check the status of the models from conso
.. code-block:: bash
== Status ==
Memory usage on this node: 10.4/16.0 GiB
PopulationBasedTraining: 4 checkpoints, 1 perturbs
Resources requested: 4/12 CPUs, 0/0 GPUs, 0.0/3.42 GiB heap, 0.0/1.17 GiB objects
Number of trials: 4 ({'RUNNING': 4})
Result logdir: /Users/yuhao.yang/ray_results/pbt_test
+--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+
| Trial name | status | loc | lr | momentum | iter | total time (s) | acc |
|--------------------------+----------+---------------------+----------+------------+--------+------------------+----------|
| PytorchTrainble_3b42d914 | RUNNING | 30.57.180.224:49840 | 0.122032 | 0.302176 | 18 | 3.8689 | 0.8875 |
| PytorchTrainble_3b45091e | RUNNING | 30.57.180.224:49835 | 0.505325 | 0.628559 | 18 | 3.90404 | 0.134375 |
| PytorchTrainble_3b454c46 | RUNNING | 30.57.180.224:49843 | 0.490228 | 0.969013 | 17 | 3.72111 | 0.0875 |
| PytorchTrainble_3b458a9c | RUNNING | 30.57.180.224:49833 | 0.961861 | 0.169701 | 13 | 2.72594 | 0.1125 |
+--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+
Memory usage on this node: 11.6/16.0 GiB
PopulationBasedTraining: 5 checkpoints, 4 perturbs
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/3.96 GiB heap, 0.0/1.37 GiB objects
Result logdir: /Users/foo/ray_results/pbt_test
Number of trials: 4 (4 TERMINATED)
+------------------------------+------------+-------+-----------+------------+----------+--------+------------------+
| Trial name | status | loc | lr | momentum | acc | iter | total time (s) |
|------------------------------+------------+-------+-----------+------------+----------+--------+------------------|
| PytorchTrainable_ba982_00000 | TERMINATED | | 0.0457501 | 0.99 | 0.6375 | 25 | 5.35712 |
| PytorchTrainable_ba982_00001 | TERMINATED | | 0.175808 | 0.0667043 | 0.909375 | 29 | 6.18802 |
| PytorchTrainable_ba982_00002 | TERMINATED | | 0.21097 | 0.99 | 0.040625 | 29 | 6.19634 |
| PytorchTrainable_ba982_00003 | TERMINATED | | 0.0571876 | 0.852088 | 0.96875 | 30 | 6.37298 |
+------------------------------+------------+-------+-----------+------------+----------+--------+------------------+
In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt
and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs:
@@ -114,6 +114,38 @@ Checking the accuracy:
.. image:: /images/tune_advanced_plot1.png
.. _tune-advanced-tutorial-pbt-replay:
Replaying a PBT run
-------------------
A run of Population Based Training ends with fully trained models. However, sometimes
you might like to train the model from scratch, but use the same hyperparameter
schedule as obtained from PBT. Ray Tune offers a replay utility for this.
All you need to do is pass the policy log file for the trial you want to replay.
This is usually stored in the experiment directory, for instance
``~/ray_results/pbt_test/pbt_policy_ba982_00000.txt``.
The replay utility reads the original configuration of the trial and updates it
at every training iteration at which it was originally perturbed. You can
(and should) therefore use the same ``Trainable`` for the replay run.
.. code-block:: python
from ray import tune
from ray.tune.examples.pbt_convnet_example import PytorchTrainable
from ray.tune.schedulers import PopulationBasedTrainingReplay
replay = PopulationBasedTrainingReplay(
"~/ray_results/pbt_test/pbt_policy_ba982_00003.txt")
tune.run(
PytorchTrainable,
scheduler=replay,
stop={"training_iteration": 100})
DCGAN with Trainable and PBT
----------------------------
+24
View File
@@ -148,6 +148,30 @@ You can run this :doc:`toy PBT example </tune/examples/pbt_function>` to get an
.. autoclass:: ray.tune.schedulers.PopulationBasedTraining
.. _tune-scheduler-pbt-replay:
Population Based Training Replay (tune.schedulers.PopulationBasedTrainingReplay)
--------------------------------------------------------------------------------
Tune includes a utility to replay hyperparameter schedules of Population Based Training runs.
You just pass the PBT policy log file of the trial you would like to replay
(for instance ``~/ray_results/pbt_experiment/pbt_policy_XXXXX_00001.txt``).
The scheduler accepts only one trial, and it will update its
config according to the obtained schedule.
.. code-block:: python
replay = PopulationBasedTrainingReplay(
"~/ray_results/pbt_experiment/pbt_policy_XXXXX_00001.txt")
tune.run(
...,
scheduler=replay)
See :ref:`here for an example <tune-advanced-tutorial-pbt-replay>` on how to use the
replay utility in practice.
.. autoclass:: ray.tune.schedulers.PopulationBasedTrainingReplay
.. _tune-scheduler-bohb:
BOHB (tune.schedulers.HyperBandForBOHB)
@@ -1,5 +1,8 @@
#!/usr/bin/env python
# flake8: noqa
# yapf: disable
# __tutorial_imports_begin__
import argparse
import os
@@ -15,12 +18,11 @@ from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.utils import validate_save_restore
from ray.tune.trial import ExportFormat
# __tutorial_imports_end__
# __trainable_begin__
class PytorchTrainble(tune.Trainable):
class PytorchTrainable(tune.Trainable):
"""Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
scheduler. The example reuse some of the functions in mnist_pytorch,
and is a good demo for how to add the tuning function without
@@ -65,10 +67,9 @@ class PytorchTrainble(tune.Trainable):
self.config = new_config
return True
# __trainable_end__
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -79,8 +80,8 @@ if __name__ == "__main__":
datasets.MNIST("~/data", train=True, download=True)
# check if PytorchTrainable will save/restore correctly before execution
validate_save_restore(PytorchTrainble)
validate_save_restore(PytorchTrainble, use_object_store=True)
validate_save_restore(PytorchTrainable)
validate_save_restore(PytorchTrainable, use_object_store=True)
# __pbt_begin__
scheduler = PopulationBasedTraining(
@@ -94,7 +95,6 @@ if __name__ == "__main__":
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
# __pbt_end__
# __tune_begin__
@@ -114,7 +114,7 @@ if __name__ == "__main__":
stopper = CustomStopper()
analysis = tune.run(
PytorchTrainble,
PytorchTrainable,
name="pbt_test",
scheduler=scheduler,
reuse_actors=True,
@@ -134,7 +134,7 @@ if __name__ == "__main__":
best_trial = analysis.get_best_trial("mean_accuracy")
best_checkpoint = max(
analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy"))
restored_trainable = PytorchTrainble()
restored_trainable = PytorchTrainable()
restored_trainable.restore(best_checkpoint[0])
best_model = restored_trainable.model
# Note that test only runs on a small random set of the test data, thus the
+4 -2
View File
@@ -4,10 +4,12 @@ from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler,
ASHAScheduler)
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
from ray.tune.schedulers.pbt import PopulationBasedTraining
from ray.tune.schedulers.pbt import (PopulationBasedTraining,
PopulationBasedTrainingReplay)
__all__ = [
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
"PopulationBasedTraining", "HyperBandForBOHB"
"PopulationBasedTraining", "PopulationBasedTrainingReplay",
"HyperBandForBOHB"
]
+166
View File
@@ -393,3 +393,169 @@ class PopulationBasedTraining(FIFOScheduler):
def debug_string(self):
    """Return a one-line summary of the checkpoint and perturbation counts."""
    counts = (self._num_checkpoints, self._num_perturbations)
    return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
        *counts)
class PopulationBasedTrainingReplay(FIFOScheduler):
    """Replays a Population Based Training run.

    Population Based Training does not return a single hyperparameter
    configuration, but rather a schedule of configurations. For instance,
    PBT might discover that a larger learning rate leads to good results
    in the first training iterations, but that a smaller learning rate
    is preferable later.

    This scheduler enables replaying these parameter schedules from
    a finished PBT run. This requires that population based training has
    been run with ``log_config=True``, which is the default setting.

    The scheduler will only accept and train a single trial. It will
    start with the initial config of the existing trial and update the
    config according to the schedule.

    Args:
        policy_file (str): The PBT policy file. Usually this is
            stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt``
            where ``xxx`` is the trial ID.

    Raises:
        ValueError: If ``policy_file`` does not exist or contains a row
            that cannot be parsed.

    Example:

    .. code-block:: python

        # Replaying a result from ray.tune.examples.pbt_convnet_example
        from ray import tune

        from ray.tune.examples.pbt_convnet_example import PytorchTrainable
        from ray.tune.schedulers import PopulationBasedTrainingReplay

        replay = PopulationBasedTrainingReplay(
            "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt")

        tune.run(
            PytorchTrainable,
            scheduler=replay,
            stop={"training_iteration": 100})


    """

    def __init__(self, policy_file):
        policy_file = os.path.expanduser(policy_file)
        if not os.path.exists(policy_file):
            raise ValueError("Policy file not found: {}".format(policy_file))

        self.policy_file = policy_file

        # Find and read pbt policy file, potentially raise error
        initial_config, self._policy = self._load_policy(self.policy_file)

        self.experiment_tag = "replay_{}".format(
            os.path.basename(self.policy_file))
        self.config = initial_config
        self.current_config = self.config

        self._trial = None
        self._current_step = 0
        self._num_perturbations = 0

        # `_next_policy` is the pending `(step, config)` change, or None
        # once the whole schedule has been applied.
        self._policy_iter = iter(self._policy)
        self._next_policy = next(self._policy_iter, None)

    def _load_policy(self, policy_file):
        """Parse a policy log into an initial config and a change schedule.

        Every non-empty line of the file must be a JSON sequence that
        unpacks into ``(old_tag, new_tag, old_step, new_step, old_conf,
        new_conf)``.

        Args:
            policy_file (str): Path of the policy log to read.

        Returns:
            Tuple ``(initial_config, policy)``. ``policy`` is a list of
            ``(new_step, new_conf)`` tuples in file order, restricted to
            the rows that belong to the final tag chain.
            ``initial_config`` is the ``old_conf`` of the first row of
            that chain, or ``None`` if the file holds no rows.

        Raises:
            ValueError: If a line is not valid JSON or is not a sequence.
        """
        raw_policy = []
        with open(policy_file, "rt") as fp:
            for row in fp:
                if not row.strip():
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                try:
                    # tuple() raises TypeError for scalar JSON values;
                    # report those like any other unreadable row.
                    parsed_row = tuple(json.loads(row))
                except (json.JSONDecodeError, TypeError):
                    raise ValueError(
                        "Could not read PBT policy file: {}.".format(
                            policy_file)) from None
                raw_policy.append(parsed_row)

        # Loop through policy from end to start to obtain changepoints
        policy = []
        last_new_tag = None
        last_old_conf = None
        for (old_tag, new_tag, old_step, new_step, old_conf,
             new_conf) in reversed(raw_policy):
            if last_new_tag and old_tag != last_new_tag:
                # Tag chain ended. This means that previous changes were
                # overwritten by the last change and should be ignored.
                break
            last_new_tag = new_tag
            last_old_conf = old_conf

            policy.append((new_step, new_conf))

        return last_old_conf, list(reversed(policy))

    def on_trial_add(self, trial_runner, trial):
        """Register the single replay trial and install its initial config.

        Raises:
            ValueError: If a trial was already added, or if there is
                neither a replay policy nor a config on the trial.
        """
        if self._trial:
            raise ValueError(
                "More than one trial added to PBT replay run. This "
                "means the same schedule will be trained multiple "
                "times. Do you want to set `num_samples=1`?")
        self._trial = trial
        if self._trial.config and self._policy:
            logger.warning(
                "Trial was initialized with a config, which was overwritten. "
                "Did you start the PBT replay with a `config` parameter?")
        elif self._trial.config and not self._policy:
            # Only train with initial policy
            self.config = self._trial.config
        elif not self._trial.config and not self._policy:
            raise ValueError(
                "No replay policy found and trial initialized without a "
                "valid config. Either pass a `config` argument to "
                "`tune.run()` or consider not using PBT replay for this run.")
        self._trial.config = self.config

    def on_trial_result(self, trial_runner, trial, result):
        """Apply the next scheduled config once the trial reaches its step.

        Always returns ``TrialScheduler.CONTINUE``.
        """
        if TRAINING_ITERATION not in result:
            # No time reported
            return TrialScheduler.CONTINUE

        # Track the step before any early return so `debug_string` stays
        # current after the last scheduled change has been applied.
        step = result[TRAINING_ITERATION]
        self._current_step = step

        if not self._next_policy:
            # No more changes in the config
            return TrialScheduler.CONTINUE

        change_at, new_config = self._next_policy

        if step < change_at:
            # Don't change the policy just yet
            return TrialScheduler.CONTINUE

        logger.info("Population Based Training replay is now at step {}. "
                    "Configuration will be changed to {}.".format(
                        step, new_config))

        # Save the current state so it can be restored under the new config.
        checkpoint = trial_runner.trial_executor.save(
            trial, Checkpoint.MEMORY, result=result)

        new_tag = make_experiment_tag(self.experiment_tag, new_config,
                                      new_config)

        trial_executor = trial_runner.trial_executor
        reset_successful = trial_executor.reset_trial(trial, new_config,
                                                      new_tag)

        if reset_successful:
            trial_executor.restore(trial, checkpoint, block=True)
        else:
            # Reset failed: stop the trial and start it again with the
            # new config and tag, passing the saved checkpoint.
            trial_executor.stop_trial(trial, stop_logger=False)
            trial.config = new_config
            trial.experiment_tag = new_tag
            trial_executor.start_trial(trial, checkpoint, train=False)

        self.current_config = new_config
        self._num_perturbations += 1
        self._next_policy = next(self._policy_iter, None)

        return TrialScheduler.CONTINUE

    def debug_string(self):
        """Return a one-line summary of the replay progress."""
        return "PopulationBasedTraining replay: Step {}, perturb {}".format(
            self._current_step, self._num_perturbations)
+142 -2
View File
@@ -10,12 +10,13 @@ from unittest.mock import MagicMock
import ray
from ray import tune
from ray.tune import Trainable
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
TrialScheduler, HyperBandForBOHB)
from ray.tune.schedulers.pbt import explore
from ray.tune.schedulers.pbt import explore, PopulationBasedTrainingReplay
from ray.tune.trial import Trial, Checkpoint
from ray.tune.trial_executor import TrialExecutor
from ray.tune.resources import Resources
@@ -710,6 +711,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
_register_all() # re-register the evicted objects
def basicSetup(self,
num_trials=5,
resample_prob=0.0,
explore=None,
perturbation_interval=10,
@@ -731,7 +733,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
custom_explore_fn=explore,
log_config=log_config)
runner = _MockTrialRunner(pbt)
for i in range(5):
for i in range(num_trials):
trial_hyperparams = hyperparams or {
"float_factor": 2.0,
"const_factor": 3,
@@ -1072,6 +1074,144 @@ class PopulationBasedTestingSuite(unittest.TestCase):
check_policy(json.loads(line))
shutil.rmtree(tmpdir)
def testReplay(self):
    """Check that PBT replay reproduces the config history of a PBT run.

    Simulates a PBT run over four trials while recording, per trial, the
    config used at every step. Each trial's policy file is then replayed
    with ``PopulationBasedTrainingReplay`` and the replayed configs are
    compared step by step against the recorded history.
    """
    pbt, runner = self.basicSetup(
        num_trials=4,
        perturbation_interval=5,
        log_config=True,
        step_once=False)
    trials = runner.get_trials()
    tmpdir = tempfile.mkdtemp()

    # Internal trial state to collect the real PBT history
    class _TrialState:
        def __init__(self, config):
            self.step = 0
            self.config = config
            self.history = []

        def forward(self, t):
            # Record the current config once for every step up to `t`.
            while self.step < t:
                self.history.append(self.config)
                self.step += 1

    trial_state = []
    for i, trial in enumerate(trials):
        # The policy files are later read back from `tmpdir`.
        trial.local_dir = tmpdir
        trial.last_result = {TRAINING_ITERATION: 0}
        trial_state.append(_TrialState(trial.config))

    # Helper function to simulate stepping trial k a number of steps,
    # and reporting a score at the end
    def trial_step(k, steps, score):
        res = result(trial_state[k].step + steps, score)

        trials[k].last_result = res
        trial_state[k].forward(res[TRAINING_ITERATION])

        old_config = trials[k].config
        pbt.on_trial_result(runner, trials[k], res)
        new_config = trials[k].config
        trial_state[k].config = new_config.copy()

        # A changed config means the scheduler replaced this trial's
        # config during `on_trial_result`.
        if old_config != new_config:
            # Copy history from source trial
            source = -1
            for m, cand in enumerate(trials):
                if cand.trainable_name == trials[k].restored_checkpoint:
                    source = m
                    break
            assert source >= 0
            trial_state[k].history = trial_state[source].history.copy()
            trial_state[k].step = trial_state[source].step

    # Initial steps
    trial_step(0, 10, 0)
    trial_step(1, 11, 10)
    trial_step(2, 12, 0)
    trial_step(3, 13, 0)

    # Next block
    trial_step(0, 10, -10)  # 0 <-- 1, new_t=11
    trial_step(2, 8, -20)  # 2 <-- 1, new_t=11
    trial_step(3, 9, 0)
    trial_step(1, 7, 0)

    # Next block
    trial_step(1, 12, 0)
    trial_step(2, 13, 0)
    trial_step(3, 14, 10)
    trial_step(0, 11, 0)  # 0 <-- 3, new_t=13+9+14=36

    # Next block
    trial_step(0, 6, 20)
    trial_step(3, 9, -40)  # 3 <-- 0, new_t=42
    trial_step(2, 8, -50)  # 2 <-- 0, new_t=42
    trial_step(1, 7, 30)
    trial_step(2, 8, -60)  # 2 <-- 1, new_t=37

    # Next block
    trial_step(0, 10, 0)
    trial_step(1, 10, 0)
    trial_step(2, 10, 0)
    trial_step(3, 10, 0)

    # Playback trainable to collect configs at each step
    class Playback(Trainable):
        def setup(self, config):
            self.config = config
            self.replayed = []
            self.iter = 0

        def step(self):
            self.iter += 1
            self.replayed.append(self.config)
            # The full list of configs seen so far is reported with every
            # result, so the last result carries the whole replay history.
            return {
                "reward": 0,
                "done": False,
                "replayed": self.replayed,
                TRAINING_ITERATION: self.iter
            }

        def reset_config(self, new_config):
            self.config = new_config
            return True

        def save_checkpoint(self, tmp_checkpoint_dir):
            # Nothing is written; only the directory is handed back.
            return tmp_checkpoint_dir

        def load_checkpoint(self, checkpoint):
            # Nothing to load, since save_checkpoint stores no state.
            pass

    # Loop through all trials and check if PBT history is the
    # same as the playback history
    for i, trial in enumerate(trials):
        if trial.trial_id == "1":  # Did not exploit anything
            continue

        replay = PopulationBasedTrainingReplay(
            os.path.join(tmpdir,
                         "pbt_policy_{}.txt".format(trial.trial_id)))
        analysis = tune.run(
            Playback,
            scheduler=replay,
            stop={TRAINING_ITERATION: trial_state[i].step})

        replayed = analysis.trials[0].last_result["replayed"]
        self.assertSequenceEqual(trial_state[i].history, replayed)

    # Trial 1 did not exploit anything and should raise an error
    # NOTE(review): presumably no policy file was written for trial 1, so
    # the replay constructor raises ValueError - confirm.
    with self.assertRaises(ValueError):
        replay = PopulationBasedTrainingReplay(
            os.path.join(tmpdir,
                         "pbt_policy_{}.txt".format(trials[1].trial_id)))
        tune.run(
            Playback,
            scheduler=replay,
            stop={TRAINING_ITERATION: trial_state[1].step})

    shutil.rmtree(tmpdir)
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42