mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:30:45 +08:00
[Tune] Synchronous Mode for PBT (#10283)
This commit is contained in:
@@ -27,10 +27,12 @@ class PBTTrialState:
|
||||
self.last_score = None
|
||||
self.last_checkpoint = None
|
||||
self.last_perturbation_time = 0
|
||||
self.last_train_time = 0 # Used for synchronous mode.
|
||||
self.last_result = None # Used for synchronous mode.
|
||||
|
||||
def __repr__(self):
|
||||
return str((self.last_score, self.last_checkpoint,
|
||||
self.last_perturbation_time))
|
||||
self.last_train_time, self.last_perturbation_time))
|
||||
|
||||
|
||||
def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
@@ -174,6 +176,13 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
require_attrs (bool): Whether to require time_attr and metric to appear
|
||||
in result for every iteration. If True, error will be raised
|
||||
if these values are not present in trial result.
|
||||
synch (bool): If False, will use asynchronous implementation of
|
||||
PBT. Trial perturbations occur every perturbation_interval for each
|
||||
trial independently. If True, will use synchronous implementation
|
||||
of PBT. Perturbations will occur only after all trials are
|
||||
synced at the same time_attr every perturbation_interval.
|
||||
Defaults to False. See Appendix A.1 here
|
||||
https://arxiv.org/pdf/1711.09846.pdf.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -215,7 +224,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
resample_probability=0.25,
|
||||
custom_explore_fn=None,
|
||||
log_config=True,
|
||||
require_attrs=True):
|
||||
require_attrs=True,
|
||||
synch=False):
|
||||
for value in hyperparam_mutations.values():
|
||||
if not (isinstance(value,
|
||||
(list, dict, sample_from)) or callable(value)):
|
||||
@@ -234,10 +244,15 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"`custom_explore_fn` to use PBT.")
|
||||
|
||||
if quantile_fraction > 0.5 or quantile_fraction < 0:
|
||||
raise TuneError(
|
||||
raise ValueError(
|
||||
"You must set `quantile_fraction` to a value between 0 and"
|
||||
"0.5. Current value: '{}'".format(quantile_fraction))
|
||||
|
||||
if perturbation_interval <= 0:
|
||||
raise ValueError(
|
||||
"perturbation_interval must be a positive number greater "
|
||||
"than 0. Current value: '{}'".format(perturbation_interval))
|
||||
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
|
||||
if reward_attr is not None:
|
||||
@@ -263,6 +278,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._custom_explore_fn = custom_explore_fn
|
||||
self._log_config = log_config
|
||||
self._require_attrs = require_attrs
|
||||
self._synch = synch
|
||||
self._next_perturbation_sync = self._perturbation_interval
|
||||
|
||||
# Metrics
|
||||
self._num_checkpoints = 0
|
||||
@@ -320,34 +337,90 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
time = result[self._time_attr]
|
||||
state = self._trial_state[trial]
|
||||
|
||||
# Continue training if perturbation interval has not been reached yet.
|
||||
if time - state.last_perturbation_time < self._perturbation_interval:
|
||||
return TrialScheduler.CONTINUE # avoid checkpoint overhead
|
||||
|
||||
# This trial has reached its perturbation interval
|
||||
score = self._metric_op * result[self._metric]
|
||||
state.last_score = score
|
||||
state.last_perturbation_time = time
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
state.last_train_time = time
|
||||
state.last_result = result
|
||||
|
||||
if not self._synch:
|
||||
state.last_perturbation_time = time
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
self._perturb_trial(trial, trial_runner, upper_quantile,
|
||||
lower_quantile)
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
return TrialScheduler.PAUSE # yield time to other trials
|
||||
|
||||
return TrialScheduler.CONTINUE
|
||||
else:
|
||||
# Synchronous mode.
|
||||
if any(self._trial_state[t].last_train_time <
|
||||
self._next_perturbation_sync and t != trial
|
||||
for t in trial_runner.get_trials()):
|
||||
logger.debug("Pausing trial {}".format(trial))
|
||||
else:
|
||||
# All trials are synced at the same timestep.
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
all_trials = trial_runner.get_trials()
|
||||
not_in_quantile = []
|
||||
for t in all_trials:
|
||||
if t not in lower_quantile and t not in upper_quantile:
|
||||
not_in_quantile.append(t)
|
||||
# Move upper quantile trials to beginning and lower quantile
|
||||
# to end. This ensures that checkpointing of strong trials
|
||||
# occurs before exploiting of weaker ones.
|
||||
all_trials = upper_quantile + not_in_quantile + lower_quantile
|
||||
for t in all_trials:
|
||||
logger.debug("Perturbing Trial {}".format(t))
|
||||
self._trial_state[t].last_perturbation_time = time
|
||||
self._perturb_trial(t, trial_runner, upper_quantile,
|
||||
lower_quantile)
|
||||
|
||||
all_train_times = [
|
||||
self._trial_state[trial].last_train_time
|
||||
for trial in trial_runner.get_trials()
|
||||
]
|
||||
max_last_train_time = max(all_train_times)
|
||||
self._next_perturbation_sync = max(
|
||||
self._next_perturbation_sync + self._perturbation_interval,
|
||||
max_last_train_time)
|
||||
# In sync mode we should pause all trials once result comes in.
|
||||
# Once a perturbation step happens for all trials, they should
|
||||
# still all be paused.
|
||||
# choose_trial_to_run will then pick the next trial to run out of
|
||||
# the paused trials.
|
||||
return TrialScheduler.PAUSE
|
||||
|
||||
def _perturb_trial(self, trial, trial_runner, upper_quantile,
|
||||
lower_quantile):
|
||||
"""Checkpoint if in upper quantile, exploits if in lower."""
|
||||
state = self._trial_state[trial]
|
||||
if trial in upper_quantile:
|
||||
# The trial's last result is only updated after the scheduler
|
||||
# callback. So, we override with the current result.
|
||||
state.last_checkpoint = trial_runner.trial_executor.save(
|
||||
trial, Checkpoint.MEMORY, result=result)
|
||||
logger.debug("Trial {} is in upper quantile".format(trial))
|
||||
logger.debug("Checkpointing {}".format(trial))
|
||||
if trial.status == Trial.PAUSED:
|
||||
# Paused trial will always have an in-memory checkpoint.
|
||||
state.last_checkpoint = trial.checkpoint
|
||||
else:
|
||||
state.last_checkpoint = trial_runner.trial_executor.save(
|
||||
trial, Checkpoint.MEMORY, result=state.last_result)
|
||||
self._num_checkpoints += 1
|
||||
else:
|
||||
state.last_checkpoint = None # not a top trial
|
||||
|
||||
if trial in lower_quantile:
|
||||
logger.debug("Trial {} is in lower quantile".format(trial))
|
||||
trial_to_clone = random.choice(upper_quantile)
|
||||
assert trial is not trial_to_clone
|
||||
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
|
||||
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
return TrialScheduler.PAUSE # yield time to other trials
|
||||
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def _log_config_on_step(self, trial_state, new_state, trial,
|
||||
trial_to_clone, new_config):
|
||||
"""Logs transition during exploit/exploit step.
|
||||
@@ -417,26 +490,40 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
|
||||
self._hyperparam_mutations)
|
||||
reset_successful = trial_executor.reset_trial(trial, new_config,
|
||||
new_tag)
|
||||
|
||||
# TODO(ujvl): Refactor Scheduler abstraction to abstract
|
||||
# mechanism for trial restart away. We block on restore
|
||||
# and suppress train on start as a stop-gap fix to
|
||||
# https://github.com/ray-project/ray/issues/7258.
|
||||
if reset_successful:
|
||||
trial_executor.restore(
|
||||
trial, new_state.last_checkpoint, block=True)
|
||||
else:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
if trial.status == Trial.PAUSED:
|
||||
# If trial is paused we update it with a new checkpoint.
|
||||
# When the trial is started again, the new checkpoint is used.
|
||||
if not self._synch:
|
||||
raise TuneError("Trials should be paused here only if in "
|
||||
"synchronous mode. If you encounter this error"
|
||||
" please raise an issue on Ray Github.")
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial_executor.start_trial(
|
||||
trial, new_state.last_checkpoint, train=False)
|
||||
trial.on_checkpoint(new_state.last_checkpoint)
|
||||
else:
|
||||
# If trial is running, we first try to reset it.
|
||||
# If that is unsuccessful, then we have to stop it and start it
|
||||
# again with a new checkpoint.
|
||||
reset_successful = trial_executor.reset_trial(
|
||||
trial, new_config, new_tag)
|
||||
# TODO(ujvl): Refactor Scheduler abstraction to abstract
|
||||
# mechanism for trial restart away. We block on restore
|
||||
# and suppress train on start as a stop-gap fix to
|
||||
# https://github.com/ray-project/ray/issues/7258.
|
||||
if reset_successful:
|
||||
trial_executor.restore(
|
||||
trial, new_state.last_checkpoint, block=True)
|
||||
else:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial_executor.start_trial(
|
||||
trial, new_state.last_checkpoint, train=False)
|
||||
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||
trial_state.last_train_time = new_state.last_train_time
|
||||
|
||||
def _quantiles(self):
|
||||
"""Returns trials in the lower and upper `quantile` of the population.
|
||||
@@ -445,6 +532,9 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"""
|
||||
trials = []
|
||||
for trial, state in self._trial_state.items():
|
||||
logger.debug("Trial {}, state {}".format(trial, state))
|
||||
if trial.is_finished():
|
||||
logger.debug("Trial {} is finished".format(trial))
|
||||
if state.last_score is not None and not trial.is_finished():
|
||||
trials.append(trial)
|
||||
trials.sort(key=lambda t: self._trial_state[t].last_score)
|
||||
@@ -469,9 +559,13 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
|
||||
trial_runner.has_resources(trial.resources):
|
||||
candidates.append(trial)
|
||||
if not self._synch:
|
||||
candidates.append(trial)
|
||||
elif self._trial_state[trial].last_train_time < \
|
||||
self._next_perturbation_sync:
|
||||
candidates.append(trial)
|
||||
candidates.sort(
|
||||
key=lambda trial: self._trial_state[trial].last_perturbation_time)
|
||||
key=lambda trial: self._trial_state[trial].last_train_time)
|
||||
return candidates[0] if candidates else None
|
||||
|
||||
def reset_stats(self):
|
||||
|
||||
@@ -241,6 +241,7 @@ class _MockTrialRunner():
|
||||
return True
|
||||
|
||||
def _pause_trial(self, trial):
|
||||
self.trial_executor.save(trial, Checkpoint.MEMORY, None)
|
||||
trial.status = Trial.PAUSED
|
||||
|
||||
def _launch_trial(self, trial):
|
||||
@@ -702,6 +703,13 @@ class _MockTrial(Trial):
|
||||
self.custom_trial_name = None
|
||||
self.custom_dirname = None
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
self.restored_checkpoint = checkpoint.value
|
||||
|
||||
@property
|
||||
def checkpoint(self):
|
||||
return Checkpoint(Checkpoint.MEMORY, self.trainable_name, None)
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -720,7 +728,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
require_attrs=True,
|
||||
hyperparams=None,
|
||||
hyperparam_mutations=None,
|
||||
step_once=True):
|
||||
step_once=True,
|
||||
synch=False):
|
||||
hyperparam_mutations = hyperparam_mutations or {
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
@@ -734,7 +743,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
hyperparam_mutations=hyperparam_mutations,
|
||||
custom_explore_fn=explore,
|
||||
log_config=log_config,
|
||||
require_attrs=require_attrs)
|
||||
synch=synch,
|
||||
require_attrs=require_attrs,
|
||||
)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(num_trials):
|
||||
trial_hyperparams = hyperparams or {
|
||||
@@ -746,10 +757,17 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
trial = _MockTrial(i, trial_hyperparams)
|
||||
runner.add_trial(trial)
|
||||
trial.status = Trial.RUNNING
|
||||
for i in range(num_trials):
|
||||
trial = runner.trials[i]
|
||||
if step_once:
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
|
||||
TrialScheduler.CONTINUE)
|
||||
if synch:
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
|
||||
TrialScheduler.PAUSE)
|
||||
else:
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
|
||||
TrialScheduler.CONTINUE)
|
||||
pbt.reset_stats()
|
||||
return pbt, runner
|
||||
|
||||
@@ -822,6 +840,32 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(pbt._num_checkpoints, 2)
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
|
||||
def testCheckpointMostPromisingTrialsSynch(self):
|
||||
pbt, runner = self.basicSetup(synch=True)
|
||||
trials = runner.get_trials()
|
||||
|
||||
# no checkpoint: haven't hit next perturbation interval yet
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 0)
|
||||
|
||||
# trials should be paused until all trials are synced.
|
||||
for i in range(len(trials) - 1):
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[i], result(20, 200 + i)),
|
||||
TrialScheduler.PAUSE)
|
||||
|
||||
self.assertEqual(pbt.last_scores(trials), [200, 201, 202, 203, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 0)
|
||||
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[-1], result(20, 204)),
|
||||
TrialScheduler.PAUSE)
|
||||
self.assertEqual(pbt._num_checkpoints, 2)
|
||||
|
||||
def testPerturbsLowPerformingTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
@@ -852,6 +896,35 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertTrue("@perturbed" in trials[2].experiment_tag)
|
||||
|
||||
def testPerturbsLowPerformingTrialsSynch(self):
|
||||
pbt, runner = self.basicSetup(synch=True)
|
||||
trials = runner.get_trials()
|
||||
|
||||
# no perturbation: haven't hit next perturbation interval
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[-1], result(15, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
|
||||
# Don't perturb until all trials are synched.
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[-1], result(20, -100)),
|
||||
TrialScheduler.PAUSE)
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, -100])
|
||||
self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
|
||||
|
||||
# Synch all trials.
|
||||
for i in range(len(trials) - 1):
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[i], result(20, -10 * i)),
|
||||
TrialScheduler.PAUSE)
|
||||
self.assertEqual(pbt.last_scores(trials), [0, -10, -20, -30, -100])
|
||||
self.assertIn(trials[-1].restored_checkpoint, ["trial_0", "trial_1"])
|
||||
self.assertIn(trials[-2].restored_checkpoint, ["trial_0", "trial_1"])
|
||||
self.assertEqual(pbt._num_perturbations, 2)
|
||||
|
||||
def testPerturbWithoutResample(self):
|
||||
pbt, runner = self.basicSetup(resample_prob=0.0)
|
||||
trials = runner.get_trials()
|
||||
@@ -1088,6 +1161,27 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
trials[i].status = Trial.PENDING
|
||||
self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])
|
||||
|
||||
def testSchedulesMostBehindTrialToRunSynch(self):
|
||||
pbt, runner = self.basicSetup(synch=True)
|
||||
trials = runner.get_trials()
|
||||
runner.process_action(
|
||||
trials[0], pbt.on_trial_result(runner, trials[0], result(
|
||||
800, 1000)))
|
||||
runner.process_action(
|
||||
trials[1], pbt.on_trial_result(runner, trials[1], result(
|
||||
700, 1001)))
|
||||
runner.process_action(
|
||||
trials[2], pbt.on_trial_result(runner, trials[2], result(
|
||||
600, 1002)))
|
||||
runner.process_action(
|
||||
trials[3], pbt.on_trial_result(runner, trials[3], result(
|
||||
500, 1003)))
|
||||
runner.process_action(
|
||||
trials[4], pbt.on_trial_result(runner, trials[4], result(
|
||||
700, 1004)))
|
||||
self.assertIn(
|
||||
pbt.choose_trial_to_run(runner), [trials[0], trials[1], trials[3]])
|
||||
|
||||
def testPerturbationResetsLastPerturbTime(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
@@ -1141,6 +1235,44 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
check_policy(json.loads(line))
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testLogConfigSynch(self):
|
||||
def check_policy(policy):
|
||||
self.assertIsInstance(policy[2], int)
|
||||
self.assertIsInstance(policy[3], int)
|
||||
self.assertIn(policy[0], ["0tag", "1tag"])
|
||||
self.assertIn(policy[1], ["3tag", "4tag"])
|
||||
self.assertIn(policy[2], [0, 1])
|
||||
self.assertIn(policy[3], [3, 4])
|
||||
for i in [4, 5]:
|
||||
self.assertIsInstance(policy[i], dict)
|
||||
for key in [
|
||||
"const_factor", "int_factor", "float_factor",
|
||||
"id_factor"
|
||||
]:
|
||||
self.assertIn(key, policy[i])
|
||||
self.assertIsInstance(policy[i]["float_factor"], float)
|
||||
self.assertIsInstance(policy[i]["int_factor"], int)
|
||||
self.assertIn(policy[i]["const_factor"], [3])
|
||||
self.assertIn(policy[i]["int_factor"], [8, 10, 12])
|
||||
self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6])
|
||||
self.assertIn(policy[i]["id_factor"], [3, 4, 100])
|
||||
|
||||
pbt, runner = self.basicSetup(
|
||||
log_config=True, synch=True, step_once=False)
|
||||
trials = runner.get_trials()
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
for i, trial in enumerate(trials):
|
||||
trial.local_dir = tmpdir
|
||||
trial.last_result = {TRAINING_ITERATION: i}
|
||||
pbt.on_trial_result(runner, trials[i], result(10, i))
|
||||
log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_1.txt"]
|
||||
for log_file in log_files:
|
||||
self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
|
||||
raw_policy = open(os.path.join(tmpdir, log_file), "r").readlines()
|
||||
for line in raw_policy:
|
||||
check_policy(json.loads(line))
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testReplay(self):
|
||||
# Returns unique increasing parameter mutations
|
||||
class _Counter:
|
||||
@@ -1156,6 +1288,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
perturbation_interval=5,
|
||||
log_config=True,
|
||||
step_once=False,
|
||||
synch=False,
|
||||
hyperparam_mutations={
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": _Counter(1000)
|
||||
@@ -1292,6 +1425,176 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testReplaySynch(self):
|
||||
# Returns unique increasing parameter mutations
|
||||
class _Counter:
|
||||
def __init__(self, start=0):
|
||||
self.count = start - 1
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.count += 1
|
||||
return self.count
|
||||
|
||||
pbt, runner = self.basicSetup(
|
||||
num_trials=4,
|
||||
perturbation_interval=5,
|
||||
log_config=True,
|
||||
step_once=False,
|
||||
synch=True,
|
||||
hyperparam_mutations={
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": _Counter(1000)
|
||||
})
|
||||
trials = runner.get_trials()
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
# Internal trial state to collect the real PBT history
|
||||
class _TrialState:
|
||||
def __init__(self, config):
|
||||
self.step = 0
|
||||
self.config = config
|
||||
self.history = []
|
||||
|
||||
def forward(self, t):
|
||||
while self.step < t:
|
||||
self.history.append(self.config)
|
||||
self.step += 1
|
||||
|
||||
trial_state = []
|
||||
for i, trial in enumerate(trials):
|
||||
trial.local_dir = tmpdir
|
||||
trial.last_result = {TRAINING_ITERATION: 0}
|
||||
trial_state.append(_TrialState(trial.config))
|
||||
|
||||
# Helper function to simulate stepping trial k a number of steps,
|
||||
# and reporting a score at the end
|
||||
def trial_step(k, steps, score, synced=False):
|
||||
res = result(trial_state[k].step + steps, score)
|
||||
|
||||
trials[k].last_result = res
|
||||
trial_state[k].forward(res[TRAINING_ITERATION])
|
||||
|
||||
trials[k].status = Trial.RUNNING
|
||||
if not synced:
|
||||
action = pbt.on_trial_result(runner, trials[k], res)
|
||||
runner.process_action(trials[k], action)
|
||||
return
|
||||
else:
|
||||
# Reached synchronization point
|
||||
old_configs = [trial.config for trial in trials]
|
||||
action = pbt.on_trial_result(runner, trials[k], res)
|
||||
runner.process_action(trials[k], action)
|
||||
new_configs = [trial.config for trial in trials]
|
||||
|
||||
for i in range(len(trials)):
|
||||
old_config = old_configs[i]
|
||||
new_config = new_configs[i]
|
||||
if old_config != new_config:
|
||||
# Copy history from source trial
|
||||
source = -1
|
||||
for m, cand in enumerate(trials):
|
||||
if cand.trainable_name == trials[
|
||||
i].restored_checkpoint:
|
||||
source = m
|
||||
break
|
||||
assert source >= 0
|
||||
trial_state[i].history = trial_state[
|
||||
source].history.copy()
|
||||
trial_state[i].step = trial_state[source].step
|
||||
trial_state[i].config = new_config.copy()
|
||||
|
||||
# Initial steps
|
||||
trial_step(0, 10, 0)
|
||||
trial_step(1, 11, 10)
|
||||
trial_step(2, 12, 0)
|
||||
trial_step(3, 13, -1, synced=True)
|
||||
|
||||
# 3 <-- 1, new_t 11
|
||||
# next_perturb_sync = 13
|
||||
|
||||
# Next block
|
||||
trial_step(0, 17, -10) # 20
|
||||
trial_step(2, 15, -20) # 20
|
||||
trial_step(3, 16, 0) # 20
|
||||
trial_step(1, 7, 1, synced=True) # 18
|
||||
|
||||
# 2 <-- 1, new_t=11+7=18
|
||||
# next_perturb_sync = 20
|
||||
|
||||
# Next block
|
||||
trial_step(2, 13, 0) # 31
|
||||
trial_step(3, 14, 10) # 34
|
||||
trial_step(0, 11, -1) # 31
|
||||
trial_step(1, 12, 0, synced=True) # 30
|
||||
|
||||
# 0 <-- 3, new_t=11+9+14=34
|
||||
# next_perturb_sync = 34
|
||||
|
||||
# Next block
|
||||
trial_step(0, 6, 20) # 40
|
||||
trial_step(3, 9, -40) # 43
|
||||
trial_step(2, 8, -50) # 39
|
||||
trial_step(1, 7, 30, synced=True) # 37
|
||||
|
||||
# 2 <-- 1, new_t=18+13+8=37
|
||||
# next_perturb_sync = 43
|
||||
|
||||
# Playback trainable to collect configs at each step
|
||||
class Playback(Trainable):
|
||||
def setup(self, config):
|
||||
self.config = config
|
||||
self.replayed = []
|
||||
self.iter = 0
|
||||
|
||||
def step(self):
|
||||
self.iter += 1
|
||||
self.replayed.append(self.config)
|
||||
return {
|
||||
"reward": 0,
|
||||
"done": False,
|
||||
"replayed": self.replayed,
|
||||
TRAINING_ITERATION: self.iter
|
||||
}
|
||||
|
||||
def reset_config(self, new_config):
|
||||
self.config = new_config
|
||||
return True
|
||||
|
||||
def save_checkpoint(self, tmp_checkpoint_dir):
|
||||
return tmp_checkpoint_dir
|
||||
|
||||
def load_checkpoint(self, checkpoint):
|
||||
pass
|
||||
|
||||
# Loop through all trials and check if PBT history is the
|
||||
# same as the playback history
|
||||
for i, trial in enumerate(trials):
|
||||
if trial.trial_id in ["1"]: # Did not exploit anything
|
||||
continue
|
||||
|
||||
replay = PopulationBasedTrainingReplay(
|
||||
os.path.join(tmpdir,
|
||||
"pbt_policy_{}.txt".format(trial.trial_id)))
|
||||
analysis = tune.run(
|
||||
Playback,
|
||||
scheduler=replay,
|
||||
stop={TRAINING_ITERATION: trial_state[i].step})
|
||||
|
||||
replayed = analysis.trials[0].last_result["replayed"]
|
||||
self.assertSequenceEqual(trial_state[i].history, replayed)
|
||||
|
||||
# Trial 1 did not exploit anything and should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
replay = PopulationBasedTrainingReplay(
|
||||
os.path.join(tmpdir,
|
||||
"pbt_policy_{}.txt".format(trials[1].trial_id)))
|
||||
tune.run(
|
||||
Playback,
|
||||
scheduler=replay,
|
||||
stop={TRAINING_ITERATION: trial_state[1].step})
|
||||
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testPostprocessingHook(self):
|
||||
def explore(new_config):
|
||||
new_config["id_factor"] = 42
|
||||
|
||||
@@ -4,64 +4,13 @@ import pickle
|
||||
import random
|
||||
import unittest
|
||||
import sys
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
|
||||
class MockTrainable(tune.Trainable):
|
||||
def setup(self, config):
|
||||
self.iter = 0
|
||||
self.a = config["a"]
|
||||
self.b = config["b"]
|
||||
self.c = config["c"]
|
||||
|
||||
def step(self):
|
||||
self.iter += 1
|
||||
return {"mean_accuracy": (self.a - self.iter) * self.b}
|
||||
|
||||
def save_checkpoint(self, tmp_checkpoint_dir):
|
||||
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "wb") as fp:
|
||||
pickle.dump((self.a, self.b, self.iter), fp)
|
||||
return tmp_checkpoint_dir
|
||||
|
||||
def load_checkpoint(self, tmp_checkpoint_dir):
|
||||
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "rb") as fp:
|
||||
self.a, self.b, self.iter = pickle.load(fp)
|
||||
|
||||
|
||||
def MockTrainingFunc(config, checkpoint_dir=None):
|
||||
iter = 0
|
||||
a = config["a"]
|
||||
b = config["b"]
|
||||
|
||||
if checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "rb") as fp:
|
||||
a, b, iter = pickle.load(fp)
|
||||
|
||||
while True:
|
||||
iter += 1
|
||||
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "wb") as fp:
|
||||
pickle.dump((a, b, iter), fp)
|
||||
tune.report(mean_accuracy=(a - iter) * b)
|
||||
|
||||
|
||||
def MockTrainingFunc2(config):
|
||||
a = config["a"]
|
||||
b = config["b"]
|
||||
c1 = config["c"]["c1"]
|
||||
c2 = config["c"]["c2"]
|
||||
|
||||
while True:
|
||||
tune.report(mean_accuracy=a * b * (c1 + c2))
|
||||
|
||||
|
||||
class MockParam(object):
|
||||
def __init__(self, params):
|
||||
self._params = params
|
||||
@@ -73,6 +22,77 @@ class MockParam(object):
|
||||
return val
|
||||
|
||||
|
||||
class PopulationBasedTrainingSynchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def MockTrainingFuncSync(config, checkpoint_dir=None):
|
||||
iter = 0
|
||||
|
||||
if checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(checkpoint_path, "rb") as fp:
|
||||
a, iter = pickle.load(fp)
|
||||
|
||||
a = config["a"] # Use the new hyperparameter if perturbed.
|
||||
|
||||
while True:
|
||||
iter += 1
|
||||
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint")
|
||||
with open(checkpoint_path, "wb") as fp:
|
||||
pickle.dump((a, iter), fp)
|
||||
# Score gets better every iteration.
|
||||
time.sleep(1)
|
||||
tune.report(mean_accuracy=iter + a, a=a)
|
||||
|
||||
self.MockTrainingFuncSync = MockTrainingFuncSync
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def synchSetup(self, synch, param=[10, 20, 30]):
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
mode="max",
|
||||
perturbation_interval=1,
|
||||
log_config=True,
|
||||
hyperparam_mutations={"c": lambda: 1},
|
||||
synch=synch)
|
||||
|
||||
param_a = MockParam(param)
|
||||
|
||||
random.seed(100)
|
||||
np.random.seed(100)
|
||||
analysis = tune.run(
|
||||
self.MockTrainingFuncSync,
|
||||
config={
|
||||
"a": tune.sample_from(lambda _: param_a()),
|
||||
"c": 1
|
||||
},
|
||||
fail_fast=True,
|
||||
num_samples=3,
|
||||
scheduler=scheduler,
|
||||
name="testPBTSync",
|
||||
stop={"training_iteration": 3},
|
||||
)
|
||||
return analysis
|
||||
|
||||
def testAsynchFail(self):
|
||||
analysis = self.synchSetup(False)
|
||||
self.assertTrue(any(analysis.dataframe()["mean_accuracy"] != 33))
|
||||
|
||||
def testSynchPass(self):
|
||||
analysis = self.synchSetup(True)
|
||||
self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
|
||||
|
||||
def testSynchPassLast(self):
|
||||
analysis = self.synchSetup(True, param=[30, 20, 10])
|
||||
self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
|
||||
|
||||
|
||||
class PopulationBasedTrainingConfigTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
@@ -81,6 +101,15 @@ class PopulationBasedTrainingConfigTest(unittest.TestCase):
|
||||
ray.shutdown()
|
||||
|
||||
def testNoConfig(self):
|
||||
def MockTrainingFunc(config):
|
||||
a = config["a"]
|
||||
b = config["b"]
|
||||
c1 = config["c"]["c1"]
|
||||
c2 = config["c"]["c2"]
|
||||
|
||||
while True:
|
||||
tune.report(mean_accuracy=a * b * (c1 + c2))
|
||||
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
@@ -97,7 +126,7 @@ class PopulationBasedTrainingConfigTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
tune.run(
|
||||
MockTrainingFunc2,
|
||||
MockTrainingFunc,
|
||||
fail_fast=True,
|
||||
num_samples=4,
|
||||
scheduler=scheduler,
|
||||
@@ -120,6 +149,31 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
fix was not applied.
|
||||
See issue #9036
|
||||
"""
|
||||
|
||||
class MockTrainable(tune.Trainable):
|
||||
def setup(self, config):
|
||||
self.iter = 0
|
||||
self.a = config["a"]
|
||||
self.b = config["b"]
|
||||
self.c = config["c"]
|
||||
|
||||
def step(self):
|
||||
self.iter += 1
|
||||
return {"mean_accuracy": (self.a - self.iter) * self.b}
|
||||
|
||||
def save_checkpoint(self, tmp_checkpoint_dir):
|
||||
checkpoint_path = os.path.join(tmp_checkpoint_dir,
|
||||
"model.mock")
|
||||
with open(checkpoint_path, "wb") as fp:
|
||||
pickle.dump((self.a, self.b, self.iter), fp)
|
||||
return tmp_checkpoint_dir
|
||||
|
||||
def load_checkpoint(self, tmp_checkpoint_dir):
|
||||
checkpoint_path = os.path.join(tmp_checkpoint_dir,
|
||||
"model.mock")
|
||||
with open(checkpoint_path, "rb") as fp:
|
||||
self.a, self.b, self.iter = pickle.load(fp)
|
||||
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
@@ -151,6 +205,25 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
stop={"training_iteration": 3})
|
||||
|
||||
def testPermutationContinuationFunc(self):
|
||||
def MockTrainingFunc(config, checkpoint_dir=None):
|
||||
iter = 0
|
||||
a = config["a"]
|
||||
b = config["b"]
|
||||
|
||||
if checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "rb") as fp:
|
||||
a, b, iter = pickle.load(fp)
|
||||
|
||||
while True:
|
||||
iter += 1
|
||||
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"model.mock")
|
||||
with open(checkpoint_path, "wb") as fp:
|
||||
pickle.dump((a, b, iter), fp)
|
||||
tune.report(mean_accuracy=(a - iter) * b)
|
||||
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
|
||||
@@ -460,6 +460,7 @@ class TrialRunner:
|
||||
self._update_trial_queue(blocking=wait_for_trial)
|
||||
with warn_if_slow("choose_trial_to_run"):
|
||||
trial = self._scheduler_alg.choose_trial_to_run(self)
|
||||
logger.debug("Running trial {}".format(trial))
|
||||
return trial
|
||||
|
||||
def _process_events(self):
|
||||
|
||||
Reference in New Issue
Block a user