[tune] PBT replay utility scheduler (#9953)

Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
krfricke
2020-08-07 21:41:49 +02:00
committed by GitHub
parent 326a470bc2
commit 0ef8224446
6 changed files with 390 additions and 26 deletions
@@ -1,5 +1,8 @@
#!/usr/bin/env python
# flake8: noqa
# yapf: disable
# __tutorial_imports_begin__
import argparse
import os
@@ -15,12 +18,11 @@ from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.utils import validate_save_restore
from ray.tune.trial import ExportFormat
# __tutorial_imports_end__
# __trainable_begin__
class PytorchTrainble(tune.Trainable):
class PytorchTrainable(tune.Trainable):
"""Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
scheduler. The example reuse some of the functions in mnist_pytorch,
and is a good demo for how to add the tuning function without
@@ -65,10 +67,9 @@ class PytorchTrainble(tune.Trainable):
self.config = new_config
return True
# __trainable_end__
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -79,8 +80,8 @@ if __name__ == "__main__":
datasets.MNIST("~/data", train=True, download=True)
# check if PytorchTrainble will save/restore correctly before execution
validate_save_restore(PytorchTrainble)
validate_save_restore(PytorchTrainble, use_object_store=True)
validate_save_restore(PytorchTrainable)
validate_save_restore(PytorchTrainable, use_object_store=True)
# __pbt_begin__
scheduler = PopulationBasedTraining(
@@ -94,7 +95,6 @@ if __name__ == "__main__":
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
# __pbt_end__
# __tune_begin__
@@ -114,7 +114,7 @@ if __name__ == "__main__":
stopper = CustomStopper()
analysis = tune.run(
PytorchTrainble,
PytorchTrainable,
name="pbt_test",
scheduler=scheduler,
reuse_actors=True,
@@ -134,7 +134,7 @@ if __name__ == "__main__":
best_trial = analysis.get_best_trial("mean_accuracy")
best_checkpoint = max(
analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy"))
restored_trainable = PytorchTrainble()
restored_trainable = PytorchTrainable()
restored_trainable.restore(best_checkpoint[0])
best_model = restored_trainable.model
# Note that test only runs on a small random set of the test data, thus the
+4 -2
View File
@@ -4,10 +4,12 @@ from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler,
ASHAScheduler)
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
from ray.tune.schedulers.pbt import PopulationBasedTraining
from ray.tune.schedulers.pbt import (PopulationBasedTraining,
PopulationBasedTrainingReplay)
__all__ = [
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
"PopulationBasedTraining", "HyperBandForBOHB"
"PopulationBasedTraining", "PopulationBasedTrainingReplay",
"HyperBandForBOHB"
]
+166
View File
@@ -393,3 +393,169 @@ class PopulationBasedTraining(FIFOScheduler):
def debug_string(self):
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
self._num_checkpoints, self._num_perturbations)
class PopulationBasedTrainingReplay(FIFOScheduler):
"""Replays a Population Based Training run.
Population Based Training does not return a single hyperparameter
configuration, but rather a schedule of configurations. For instance,
PBT might discover that a larger learning rate leads to good results
in the first training iterations, but that a smaller learning rate
is preferable later.
This scheduler enables replaying these parameter schedules from
a finished PBT run. This requires that population based training has
been run with ``log_config=True``, which is the default setting.
The scheduler will only accept and train a single trial. It will
start with the initial config of the existing trial and update the
config according to the schedule.
Args:
policy_file (str): The PBT policy file. Usually this is
stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt``
where ``xxx`` is the trial ID.
Example:
.. code-block:: python
# Replaying a result from ray.tune.examples.pbt_convnet_example
from ray import tune
from ray.tune.examples.pbt_convnet_example import PytorchTrainable
from ray.tune.schedulers import PopulationBasedTrainingReplay
replay = PopulationBasedTrainingReplay(
"~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt")
tune.run(
PytorchTrainable,
scheduler=replay,
stop={"training_iteration": 100})
"""
def __init__(self, policy_file):
policy_file = os.path.expanduser(policy_file)
if not os.path.exists(policy_file):
raise ValueError("Policy file not found: {}".format(policy_file))
self.policy_file = policy_file
# Find and read pbt policy file, potentially raise error
initial_config, self._policy = self._load_policy(self.policy_file)
self.experiment_tag = "replay_{}".format(
os.path.basename(self.policy_file))
self.config = initial_config
self.current_config = self.config
self._trial = None
self._current_step = 0
self._num_perturbations = 0
self._policy_iter = iter(self._policy)
self._next_policy = next(self._policy_iter, None)
def _load_policy(self, policy_file):
raw_policy = []
with open(policy_file, "rt") as fp:
for row in fp.readlines():
try:
parsed_row = json.loads(row)
except json.JSONDecodeError:
raise ValueError(
"Could not read PBT policy file: {}.".format(
policy_file)) from None
raw_policy.append(tuple(parsed_row))
# Loop through policy from end to start to obtain changepoints
policy = []
last_new_tag = None
last_old_conf = None
for (old_tag, new_tag, old_step, new_step, old_conf,
new_conf) in reversed(raw_policy):
if last_new_tag and old_tag != last_new_tag:
# Tag chain ended. This means that previous changes were
# overwritten by the last change and should be ignored.
break
last_new_tag = new_tag
last_old_conf = old_conf
policy.append((new_step, new_conf))
return last_old_conf, list(reversed(policy))
def on_trial_add(self, trial_runner, trial):
if self._trial:
raise ValueError(
"More than one trial added to PBT replay run. This "
"means the same schedule will be trained multiple "
"times. Do you want to set `n_samples=1`?")
self._trial = trial
if self._trial.config and self._policy:
logger.warning(
"Trial was initialized with a config, which was overwritten. "
"Did you start the PBT replay with a `config` parameter?")
elif self._trial.config and not self._policy:
# Only train with initial policy
self.config = self._trial.config
elif not self._trial.config and not self._policy:
raise ValueError(
"No replay policy found and trial initialized without a "
"valid config. Either pass a `config` argument to `tune.run()`"
"or consider not using PBT replay for this run.")
self._trial.config = self.config
def on_trial_result(self, trial_runner, trial, result):
if TRAINING_ITERATION not in result:
# No time reported
return TrialScheduler.CONTINUE
if not self._next_policy:
# No more changes in the config
return TrialScheduler.CONTINUE
step = result[TRAINING_ITERATION]
self._current_step = step
change_at, new_config = self._next_policy
if step < change_at:
# Don't change the policy just yet
return TrialScheduler.CONTINUE
logger.info("Population Based Training replay is now at step {}. "
"Configuration will be changed to {}.".format(
step, new_config))
checkpoint = trial_runner.trial_executor.save(
trial, Checkpoint.MEMORY, result=result)
new_tag = make_experiment_tag(self.experiment_tag, new_config,
new_config)
trial_executor = trial_runner.trial_executor
reset_successful = trial_executor.reset_trial(trial, new_config,
new_tag)
if reset_successful:
trial_executor.restore(trial, checkpoint, block=True)
else:
trial_executor.stop_trial(trial, stop_logger=False)
trial.config = new_config
trial.experiment_tag = new_tag
trial_executor.start_trial(trial, checkpoint, train=False)
self.current_config = new_config
self._num_perturbations += 1
self._next_policy = next(self._policy_iter, None)
return TrialScheduler.CONTINUE
def debug_string(self):
return "PopulationBasedTraining replay: Step {}, perturb {}".format(
self._current_step, self._num_perturbations)
+142 -2
View File
@@ -10,12 +10,13 @@ from unittest.mock import MagicMock
import ray
from ray import tune
from ray.tune import Trainable
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
TrialScheduler, HyperBandForBOHB)
from ray.tune.schedulers.pbt import explore
from ray.tune.schedulers.pbt import explore, PopulationBasedTrainingReplay
from ray.tune.trial import Trial, Checkpoint
from ray.tune.trial_executor import TrialExecutor
from ray.tune.resources import Resources
@@ -710,6 +711,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
_register_all() # re-register the evicted objects
def basicSetup(self,
num_trials=5,
resample_prob=0.0,
explore=None,
perturbation_interval=10,
@@ -731,7 +733,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
custom_explore_fn=explore,
log_config=log_config)
runner = _MockTrialRunner(pbt)
for i in range(5):
for i in range(num_trials):
trial_hyperparams = hyperparams or {
"float_factor": 2.0,
"const_factor": 3,
@@ -1072,6 +1074,144 @@ class PopulationBasedTestingSuite(unittest.TestCase):
check_policy(json.loads(line))
shutil.rmtree(tmpdir)
def testReplay(self):
pbt, runner = self.basicSetup(
num_trials=4,
perturbation_interval=5,
log_config=True,
step_once=False)
trials = runner.get_trials()
tmpdir = tempfile.mkdtemp()
# Internal trial state to collect the real PBT history
class _TrialState:
def __init__(self, config):
self.step = 0
self.config = config
self.history = []
def forward(self, t):
while self.step < t:
self.history.append(self.config)
self.step += 1
trial_state = []
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {TRAINING_ITERATION: 0}
trial_state.append(_TrialState(trial.config))
# Helper function to simulate stepping trial k a number of steps,
# and reporting a score at the end
def trial_step(k, steps, score):
res = result(trial_state[k].step + steps, score)
trials[k].last_result = res
trial_state[k].forward(res[TRAINING_ITERATION])
old_config = trials[k].config
pbt.on_trial_result(runner, trials[k], res)
new_config = trials[k].config
trial_state[k].config = new_config.copy()
if old_config != new_config:
# Copy history from source trial
source = -1
for m, cand in enumerate(trials):
if cand.trainable_name == trials[k].restored_checkpoint:
source = m
break
assert source >= 0
trial_state[k].history = trial_state[source].history.copy()
trial_state[k].step = trial_state[source].step
# Initial steps
trial_step(0, 10, 0)
trial_step(1, 11, 10)
trial_step(2, 12, 0)
trial_step(3, 13, 0)
# Next block
trial_step(0, 10, -10) # 0 <-- 1, new_t=11
trial_step(2, 8, -20) # 2 <-- 1, new_t=11
trial_step(3, 9, 0)
trial_step(1, 7, 0)
# Next block
trial_step(1, 12, 0)
trial_step(2, 13, 0)
trial_step(3, 14, 10)
trial_step(0, 11, 0) # 0 <-- 3, new_t=13+9+14=36
# Next block
trial_step(0, 6, 20)
trial_step(3, 9, -40) # 3 <-- 0, new_t=42
trial_step(2, 8, -50) # 2 <-- 0, new_t=42
trial_step(1, 7, 30)
trial_step(2, 8, -60) # 2 <-- 1, new_t=37
# Next block
trial_step(0, 10, 0)
trial_step(1, 10, 0)
trial_step(2, 10, 0)
trial_step(3, 10, 0)
# Playback trainable to collect configs at each step
class Playback(Trainable):
def setup(self, config):
self.config = config
self.replayed = []
self.iter = 0
def step(self):
self.iter += 1
self.replayed.append(self.config)
return {
"reward": 0,
"done": False,
"replayed": self.replayed,
TRAINING_ITERATION: self.iter
}
def reset_config(self, new_config):
self.config = new_config
return True
def save_checkpoint(self, tmp_checkpoint_dir):
return tmp_checkpoint_dir
def load_checkpoint(self, checkpoint):
pass
# Loop through all trials and check if PBT history is the
# same as the playback history
for i, trial in enumerate(trials):
if trial.trial_id == "1": # Did not exploit anything
continue
replay = PopulationBasedTrainingReplay(
os.path.join(tmpdir,
"pbt_policy_{}.txt".format(trial.trial_id)))
analysis = tune.run(
Playback,
scheduler=replay,
stop={TRAINING_ITERATION: trial_state[i].step})
replayed = analysis.trials[0].last_result["replayed"]
self.assertSequenceEqual(trial_state[i].history, replayed)
# Trial 1 did not exploit anything and should raise an error
with self.assertRaises(ValueError):
replay = PopulationBasedTrainingReplay(
os.path.join(tmpdir,
"pbt_policy_{}.txt".format(trials[1].trial_id)))
tune.run(
Playback,
scheduler=replay,
stop={TRAINING_ITERATION: trial_state[1].step})
shutil.rmtree(tmpdir)
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42