mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[tune] PBT replay utility scheduler (#9953)
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# flake8: noqa
|
||||
# yapf: disable
|
||||
|
||||
# __tutorial_imports_begin__
|
||||
import argparse
|
||||
import os
|
||||
@@ -15,12 +18,11 @@ from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from ray.tune.utils import validate_save_restore
|
||||
from ray.tune.trial import ExportFormat
|
||||
|
||||
# __tutorial_imports_end__
|
||||
|
||||
|
||||
# __trainable_begin__
|
||||
class PytorchTrainble(tune.Trainable):
|
||||
class PytorchTrainable(tune.Trainable):
|
||||
"""Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
|
||||
scheduler. The example reuses some of the functions in mnist_pytorch,
|
||||
and is a good demo for how to add the tuning function without
|
||||
@@ -65,10 +67,9 @@ class PytorchTrainble(tune.Trainable):
|
||||
|
||||
self.config = new_config
|
||||
return True
|
||||
|
||||
|
||||
# __trainable_end__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -79,8 +80,8 @@ if __name__ == "__main__":
|
||||
datasets.MNIST("~/data", train=True, download=True)
|
||||
|
||||
# check if PytorchTrainable will save/restore correctly before execution
|
||||
validate_save_restore(PytorchTrainble)
|
||||
validate_save_restore(PytorchTrainble, use_object_store=True)
|
||||
validate_save_restore(PytorchTrainable)
|
||||
validate_save_restore(PytorchTrainable, use_object_store=True)
|
||||
|
||||
# __pbt_begin__
|
||||
scheduler = PopulationBasedTraining(
|
||||
@@ -94,7 +95,6 @@ if __name__ == "__main__":
|
||||
# allow perturbations within this set of categorical values
|
||||
"momentum": [0.8, 0.9, 0.99],
|
||||
})
|
||||
|
||||
# __pbt_end__
|
||||
|
||||
# __tune_begin__
|
||||
@@ -114,7 +114,7 @@ if __name__ == "__main__":
|
||||
stopper = CustomStopper()
|
||||
|
||||
analysis = tune.run(
|
||||
PytorchTrainble,
|
||||
PytorchTrainable,
|
||||
name="pbt_test",
|
||||
scheduler=scheduler,
|
||||
reuse_actors=True,
|
||||
@@ -134,7 +134,7 @@ if __name__ == "__main__":
|
||||
best_trial = analysis.get_best_trial("mean_accuracy")
|
||||
best_checkpoint = max(
|
||||
analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy"))
|
||||
restored_trainable = PytorchTrainble()
|
||||
restored_trainable = PytorchTrainable()
|
||||
restored_trainable.restore(best_checkpoint[0])
|
||||
best_model = restored_trainable.model
|
||||
# Note that test only runs on a small random set of the test data, thus the
|
||||
|
||||
@@ -4,10 +4,12 @@ from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
|
||||
from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler,
|
||||
ASHAScheduler)
|
||||
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.schedulers.pbt import PopulationBasedTraining
|
||||
from ray.tune.schedulers.pbt import (PopulationBasedTraining,
|
||||
PopulationBasedTrainingReplay)
|
||||
|
||||
__all__ = [
|
||||
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
|
||||
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
|
||||
"PopulationBasedTraining", "HyperBandForBOHB"
|
||||
"PopulationBasedTraining", "PopulationBasedTrainingReplay",
|
||||
"HyperBandForBOHB"
|
||||
]
|
||||
|
||||
@@ -393,3 +393,169 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
def debug_string(self):
|
||||
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
|
||||
self._num_checkpoints, self._num_perturbations)
|
||||
|
||||
|
||||
class PopulationBasedTrainingReplay(FIFOScheduler):
    """Replays a Population Based Training run.

    Population Based Training does not return a single hyperparameter
    configuration, but rather a schedule of configurations. For instance,
    PBT might discover that a larger learning rate leads to good results
    in the first training iterations, but that a smaller learning rate
    is preferable later.

    This scheduler enables replaying these parameter schedules from
    a finished PBT run. This requires that population based training has
    been run with ``log_config=True``, which is the default setting.

    The scheduler will only accept and train a single trial. It will
    start with the initial config of the existing trial and update the
    config according to the schedule.

    Args:
        policy_file (str): The PBT policy file. Usually this is
            stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt``
            where ``xxx`` is the trial ID.

    Raises:
        ValueError: If ``policy_file`` does not exist or cannot be parsed.

    Example:

    .. code-block:: python

        # Replaying a result from ray.tune.examples.pbt_convnet_example
        from ray import tune

        from ray.tune.examples.pbt_convnet_example import PytorchTrainable
        from ray.tune.schedulers import PopulationBasedTrainingReplay

        replay = PopulationBasedTrainingReplay(
            "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt")

        tune.run(
            PytorchTrainable,
            scheduler=replay,
            stop={"training_iteration": 100})

    """

    def __init__(self, policy_file):
        policy_file = os.path.expanduser(policy_file)
        if not os.path.exists(policy_file):
            raise ValueError("Policy file not found: {}".format(policy_file))

        self.policy_file = policy_file

        # Find and read pbt policy file, potentially raise error
        initial_config, self._policy = self._load_policy(self.policy_file)

        self.experiment_tag = "replay_{}".format(
            os.path.basename(self.policy_file))
        self.config = initial_config
        self.current_config = self.config

        self._trial = None
        self._current_step = 0
        self._num_perturbations = 0

        # ``_next_policy`` is the upcoming ``(step, config)`` change point,
        # or ``None`` once the schedule is exhausted.
        self._policy_iter = iter(self._policy)
        self._next_policy = next(self._policy_iter, None)

    def _load_policy(self, policy_file):
        """Parse a PBT policy log into an initial config and a schedule.

        Each non-blank line of the file must be a JSON list with the six
        fields ``(old_tag, new_tag, old_step, new_step, old_conf,
        new_conf)``.

        Args:
            policy_file (str): Path of the policy file to read.

        Returns:
            Tuple ``(initial_config, policy)`` where ``policy`` is a list
            of ``(new_step, new_conf)`` change points in chronological
            order and ``initial_config`` is the ``old_conf`` of the
            earliest change point that was kept (``None`` if the file
            holds no rows).

        Raises:
            ValueError: If a line is not valid JSON or is not a list of
                six fields.
        """
        raw_policy = []
        with open(policy_file, "rt") as fp:
            # Iterate lazily instead of materializing all lines at once.
            for row in fp:
                if not row.strip():
                    # Blank lines carry no information; tolerate them
                    # instead of rejecting the whole file.
                    continue
                try:
                    parsed_row = json.loads(row)
                except json.JSONDecodeError:
                    raise ValueError(
                        "Could not read PBT policy file: {}.".format(
                            policy_file)) from None
                # Fail here with a clear message rather than with an
                # opaque unpacking error in the loop below.
                if not isinstance(parsed_row, list) or len(parsed_row) != 6:
                    raise ValueError(
                        "Could not read PBT policy file: {}.".format(
                            policy_file))
                raw_policy.append(tuple(parsed_row))

        # Loop through policy from end to start to obtain changepoints
        policy = []
        last_new_tag = None
        last_old_conf = None
        for (old_tag, new_tag, old_step, new_step, old_conf,
             new_conf) in reversed(raw_policy):
            if last_new_tag and old_tag != last_new_tag:
                # Tag chain ended. This means that previous changes were
                # overwritten by the last change and should be ignored.
                break
            last_new_tag = new_tag
            last_old_conf = old_conf

            policy.append((new_step, new_conf))

        return last_old_conf, list(reversed(policy))

    def on_trial_add(self, trial_runner, trial):
        """Register the single replay trial and assign its starting config.

        Raises:
            ValueError: If a trial was already added, or if there is
                neither a replay policy nor a trial config to train with.
        """
        if self._trial:
            raise ValueError(
                "More than one trial added to PBT replay run. This "
                "means the same schedule will be trained multiple "
                "times. Do you want to set `num_samples=1`?")
        self._trial = trial
        if self._trial.config and self._policy:
            logger.warning(
                "Trial was initialized with a config, which was overwritten. "
                "Did you start the PBT replay with a `config` parameter?")
        elif self._trial.config and not self._policy:
            # Only train with initial policy
            self.config = self._trial.config
        elif not self._trial.config and not self._policy:
            raise ValueError(
                "No replay policy found and trial initialized without a "
                "valid config. Either pass a `config` argument to "
                "`tune.run()` or consider not using PBT replay for this "
                "run.")
        self._trial.config = self.config

    def on_trial_result(self, trial_runner, trial, result):
        """Apply the next scheduled config once its step has been reached.

        The trial is checkpointed to memory and then either reset in place
        with the new config or, if the reset fails, stopped and restarted
        from that checkpoint.

        Returns:
            ``TrialScheduler.CONTINUE`` in all cases.
        """
        if TRAINING_ITERATION not in result:
            # No time reported
            return TrialScheduler.CONTINUE

        step = result[TRAINING_ITERATION]
        # Record progress before any early return so that debug_string()
        # keeps advancing after the last scheduled change was applied.
        self._current_step = step

        if self._next_policy is None:
            # No more changes in the config
            return TrialScheduler.CONTINUE

        change_at, new_config = self._next_policy

        if step < change_at:
            # Don't change the policy just yet
            return TrialScheduler.CONTINUE

        logger.info(
            "Population Based Training replay is now at step %s. "
            "Configuration will be changed to %s.", step, new_config)

        trial_executor = trial_runner.trial_executor
        checkpoint = trial_executor.save(
            trial, Checkpoint.MEMORY, result=result)

        new_tag = make_experiment_tag(self.experiment_tag, new_config,
                                      new_config)

        reset_successful = trial_executor.reset_trial(trial, new_config,
                                                      new_tag)

        if reset_successful:
            trial_executor.restore(trial, checkpoint, block=True)
        else:
            # In-place reset failed: restart the trial from the checkpoint
            # with the new config instead.
            trial_executor.stop_trial(trial, stop_logger=False)
            trial.config = new_config
            trial.experiment_tag = new_tag
            trial_executor.start_trial(trial, checkpoint, train=False)

        self.current_config = new_config
        self._num_perturbations += 1
        self._next_policy = next(self._policy_iter, None)

        return TrialScheduler.CONTINUE

    def debug_string(self):
        """Return a one-line summary of the replay progress."""
        return "PopulationBasedTraining replay: Step {}, perturb {}".format(
            self._current_step, self._num_perturbations)
|
||||
|
||||
@@ -10,12 +10,13 @@ from unittest.mock import MagicMock
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
TrialScheduler, HyperBandForBOHB)
|
||||
|
||||
from ray.tune.schedulers.pbt import explore
|
||||
from ray.tune.schedulers.pbt import explore, PopulationBasedTrainingReplay
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.resources import Resources
|
||||
@@ -710,6 +711,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def basicSetup(self,
|
||||
num_trials=5,
|
||||
resample_prob=0.0,
|
||||
explore=None,
|
||||
perturbation_interval=10,
|
||||
@@ -731,7 +733,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
custom_explore_fn=explore,
|
||||
log_config=log_config)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(5):
|
||||
for i in range(num_trials):
|
||||
trial_hyperparams = hyperparams or {
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
@@ -1072,6 +1074,144 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
check_policy(json.loads(line))
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testReplay(self):
    """Check that replaying a logged PBT policy reproduces the configs.

    Simulates a four-trial PBT run with config logging enabled while
    recording, per trial, the config that was active at every step. Each
    trial's policy file is then replayed with
    ``PopulationBasedTrainingReplay`` and the configs seen during replay
    are compared against the recorded history.
    """
    pbt, runner = self.basicSetup(
        num_trials=4,
        perturbation_interval=5,
        log_config=True,
        step_once=False)
    trials = runner.get_trials()
    tmpdir = tempfile.mkdtemp()
    # Register cleanup up front so the directory is removed even if an
    # assertion or tune.run() below fails.
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

    # Internal trial state to collect the real PBT history
    class _TrialState:
        def __init__(self, config):
            self.step = 0
            self.config = config
            self.history = []

        def forward(self, t):
            # Record the current config once per step until step ``t``.
            while self.step < t:
                self.history.append(self.config)
                self.step += 1

    trial_state = []
    for trial in trials:
        trial.local_dir = tmpdir
        trial.last_result = {TRAINING_ITERATION: 0}
        trial_state.append(_TrialState(trial.config))

    # Helper function to simulate stepping trial k a number of steps,
    # and reporting a score at the end
    def trial_step(k, steps, score):
        res = result(trial_state[k].step + steps, score)

        trials[k].last_result = res
        trial_state[k].forward(res[TRAINING_ITERATION])

        old_config = trials[k].config
        pbt.on_trial_result(runner, trials[k], res)
        new_config = trials[k].config
        trial_state[k].config = new_config.copy()

        if old_config != new_config:
            # Copy history from source trial
            source = -1
            for m, cand in enumerate(trials):
                if cand.trainable_name == trials[k].restored_checkpoint:
                    source = m
                    break
            assert source >= 0
            trial_state[k].history = trial_state[source].history.copy()
            trial_state[k].step = trial_state[source].step

    # Initial steps
    trial_step(0, 10, 0)
    trial_step(1, 11, 10)
    trial_step(2, 12, 0)
    trial_step(3, 13, 0)

    # Next block
    trial_step(0, 10, -10)  # 0 <-- 1, new_t=11
    trial_step(2, 8, -20)  # 2 <-- 1, new_t=11
    trial_step(3, 9, 0)
    trial_step(1, 7, 0)

    # Next block
    trial_step(1, 12, 0)
    trial_step(2, 13, 0)
    trial_step(3, 14, 10)
    trial_step(0, 11, 0)  # 0 <-- 3, new_t=13+9+14=36

    # Next block
    trial_step(0, 6, 20)
    trial_step(3, 9, -40)  # 3 <-- 0, new_t=42
    trial_step(2, 8, -50)  # 2 <-- 0, new_t=42
    trial_step(1, 7, 30)
    trial_step(2, 8, -60)  # 2 <-- 1, new_t=37

    # Next block
    trial_step(0, 10, 0)
    trial_step(1, 10, 0)
    trial_step(2, 10, 0)
    trial_step(3, 10, 0)

    # Playback trainable to collect configs at each step
    class Playback(Trainable):
        def setup(self, config):
            self.config = config
            self.replayed = []
            self.iter = 0

        def step(self):
            self.iter += 1
            self.replayed.append(self.config)
            return {
                "reward": 0,
                "done": False,
                "replayed": self.replayed,
                TRAINING_ITERATION: self.iter
            }

        def reset_config(self, new_config):
            self.config = new_config
            return True

        def save_checkpoint(self, tmp_checkpoint_dir):
            return tmp_checkpoint_dir

        def load_checkpoint(self, checkpoint):
            pass

    # Loop through all trials and check if PBT history is the
    # same as the playback history
    for trial, state in zip(trials, trial_state):
        if trial.trial_id == "1":  # Did not exploit anything
            continue

        replay = PopulationBasedTrainingReplay(
            os.path.join(tmpdir,
                         "pbt_policy_{}.txt".format(trial.trial_id)))
        analysis = tune.run(
            Playback,
            scheduler=replay,
            stop={TRAINING_ITERATION: state.step})

        replayed = analysis.trials[0].last_result["replayed"]
        self.assertSequenceEqual(state.history, replayed)

    # Trial 1 did not exploit anything and should raise an error
    with self.assertRaises(ValueError):
        replay = PopulationBasedTrainingReplay(
            os.path.join(tmpdir,
                         "pbt_policy_{}.txt".format(trials[1].trial_id)))
        tune.run(
            Playback,
            scheduler=replay,
            stop={TRAINING_ITERATION: trial_state[1].step})
|
||||
|
||||
def testPostprocessingHook(self):
|
||||
def explore(new_config):
|
||||
new_config["id_factor"] = 42
|
||||
|
||||
Reference in New Issue
Block a user