[tune] Allow trials to remain paused in BOHB (#10531)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke
2020-09-04 23:29:33 +01:00
committed by GitHub
parent 30911960c8
commit 6b6780a108
3 changed files with 38 additions and 4 deletions
+12 -3
View File
@@ -93,7 +93,7 @@ class HyperBandForBOHB(HyperBandScheduler):
trial_runner.trial_executor.unpause_trial(trial)
trial_runner._search_alg.searcher.on_unpause(trial.trial_id)
def choose_trial_to_run(self, trial_runner):
def choose_trial_to_run(self, trial_runner, allow_recurse=True):
"""Fair scheduling within iteration by completion percentage.
List of trials not used since all trials are tracked as state
@@ -117,8 +117,17 @@ class HyperBandForBOHB(HyperBandScheduler):
for bracket in hyperband:
if bracket and any(trial.status == Trial.PAUSED
for trial in bracket.current_trials()):
# This will change the trial state and let the
# trial runner retry.
# This will change the trial state
self._process_bracket(trial_runner, bracket)
# If there are pending trials now, suggest one.
# This is because there might be both PENDING and
# PAUSED trials now, and PAUSED trials will raise
# an error before the trial runner tries again.
if allow_recurse and any(
trial.status == Trial.PENDING
for trial in bracket.current_trials()):
return self.choose_trial_to_run(
trial_runner, allow_recurse=False)
# MAIN CHANGE HERE!
return None
+2 -1
View File
@@ -282,7 +282,8 @@ class HyperBandScheduler(FIFOScheduler):
for i, band in enumerate(self._hyperbands):
out += "\nRound #{}:".format(i)
for bracket in band:
out += "\n {}".format(bracket)
if bracket:
out += "\n {}".format(bracket)
return out
def state(self):
@@ -689,6 +689,30 @@ class BOHBSuite(unittest.TestCase):
self.assertTrue("hyperband_info" in spy_result)
self.assertEquals(spy_result["hyperband_info"]["budget"], 1)
def testPauseResumeChooseTrial(self):
    """Check which trial ``choose_trial_to_run`` hands back after pausing.

    Three trials are launched, each reports one result, and each report
    is expected to yield a ``PAUSE`` decision. After all three have been
    paused on the runner, ``choose_trial_to_run`` must return the second
    trial; that trial must then be ``PENDING`` while the first and third
    are still ``PAUSED``.
    """
    scheduler = HyperBandForBOHB(max_t=10, reduction_factor=3, mode="min")
    runner = _MockTrialRunner(scheduler)
    # The scheduler reaches into the runner's search algorithm, so both
    # levels are replaced with mocks.
    runner._search_alg = MagicMock()
    runner._search_alg.searcher = MagicMock()

    trials = []
    for _ in range(3):
        new_trial = Trial("__fake")
        runner.add_trial(new_trial)
        runner._launch_trial(new_trial)
        trials.append(new_trial)

    # (score, training iteration) reported by each trial, in trial order.
    reports = ((1, 5), (2, 1), (3, 5))
    for trial, (score, ts) in zip(trials, reports):
        trial_result = {"episode_reward_mean": score, TRAINING_ITERATION: ts}
        decision = scheduler.on_trial_result(runner, trial, trial_result)
        self.assertEqual(decision, TrialScheduler.PAUSE)
        runner._pause_trial(trial)

    chosen = scheduler.choose_trial_to_run(runner)
    self.assertEqual(chosen, trials[1])

    statuses = [t.status for t in trials]
    expected_statuses = [Trial.PAUSED, Trial.PENDING, Trial.PAUSED]
    self.assertSequenceEqual(statuses, expected_statuses)
class _MockTrial(Trial):
def __init__(self, i, config):