diff --git a/python/ray/tune/schedulers/hb_bohb.py b/python/ray/tune/schedulers/hb_bohb.py index 826a32b8d..7204e71e3 100644 --- a/python/ray/tune/schedulers/hb_bohb.py +++ b/python/ray/tune/schedulers/hb_bohb.py @@ -93,7 +93,7 @@ class HyperBandForBOHB(HyperBandScheduler): trial_runner.trial_executor.unpause_trial(trial) trial_runner._search_alg.searcher.on_unpause(trial.trial_id) - def choose_trial_to_run(self, trial_runner): + def choose_trial_to_run(self, trial_runner, allow_recurse=True): """Fair scheduling within iteration by completion percentage. List of trials not used since all trials are tracked as state @@ -117,8 +117,17 @@ class HyperBandForBOHB(HyperBandScheduler): for bracket in hyperband: if bracket and any(trial.status == Trial.PAUSED for trial in bracket.current_trials()): - # This will change the trial state and let the - # trial runner retry. + # This will change the trial state self._process_bracket(trial_runner, bracket) + + # If there are pending trials now, suggest one. + # This is because there might be both PENDING and + # PAUSED trials now, and PAUSED trials will raise + # an error before the trial runner tries again. + if allow_recurse and any( + trial.status == Trial.PENDING + for trial in bracket.current_trials()): + return self.choose_trial_to_run( + trial_runner, allow_recurse=False) # MAIN CHANGE HERE! return None diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 94897b71d..a2fe3ad91 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -282,7 +282,8 @@ class HyperBandScheduler(FIFOScheduler): for i, band in enumerate(self._hyperbands): out += "\nRound #{}:".format(i) for bracket in band: - out += "\n {}".format(bracket) + if bracket: + out += "\n {}".format(bracket) return out def state(self): diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index ec7e96c5b..320e76af3 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -689,6 +689,30 @@ class BOHBSuite(unittest.TestCase): self.assertTrue("hyperband_info" in spy_result) self.assertEquals(spy_result["hyperband_info"]["budget"], 1) + def testPauseResumeChooseTrial(self): + def result(score, ts): + return {"episode_reward_mean": score, TRAINING_ITERATION: ts} + + sched = HyperBandForBOHB(max_t=10, reduction_factor=3, mode="min") + runner = _MockTrialRunner(sched) + runner._search_alg = MagicMock() + runner._search_alg.searcher = MagicMock() + trials = [Trial("__fake") for i in range(3)] + for t in trials: + runner.add_trial(t) + runner._launch_trial(t) + + all_results = [result(1, 5), result(2, 1), result(3, 5)] + for trial, trial_result in zip(trials, all_results): + decision = sched.on_trial_result(runner, trial, trial_result) + self.assertEqual(decision, TrialScheduler.PAUSE) + runner._pause_trial(trial) + + run_trial = sched.choose_trial_to_run(runner) + self.assertEqual(run_trial, trials[1]) + self.assertSequenceEqual([t.status for t in trials], + [Trial.PAUSED, Trial.PENDING, Trial.PAUSED]) + class _MockTrial(Trial): def __init__(self, i, config):