diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 8d451e70d..7e2f8f27e 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -220,8 +220,11 @@ class HyperBandScheduler(FIFOScheduler): """ for hyperband in self._hyperbands: + # band will have None entries if no resources + # are to be allocated to that bracket. + scrubbed = [b for b in hyperband if b is not None] for bracket in sorted( - hyperband, key=lambda b: b.completion_percentage()): + scrubbed, key=lambda b: b.completion_percentage()): for trial in bracket.current_trials(): if (trial.status == Trial.PENDING and trial_runner.has_resources(trial.resources)): diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index 21aabec81..1c32f72e0 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -219,7 +219,7 @@ class HyperbandSuite(unittest.TestCase): ray.shutdown() _register_all() # re-register the evicted objects - def schedulerSetup(self, num_trials): + def schedulerSetup(self, num_trials, max_t=81): """Setup a scheduler and Runner with max Iter = 9. Bracketing is placed as follows: @@ -228,7 +228,7 @@ class HyperbandSuite(unittest.TestCase): (15, 9) -> (5, 27) -> (2, 45); (34, 3) -> (12, 9) -> (4, 27) -> (2, 42); (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);""" - sched = HyperBandScheduler() + sched = HyperBandScheduler(max_t=max_t) for i in range(num_trials): t = Trial("__fake") sched.on_trial_add(None, t) @@ -556,6 +556,18 @@ class HyperbandSuite(unittest.TestCase): sched.on_trial_remove(runner, trial) # where trial is not running self.assertFalse(trial in bracket._live_trials) + def testFilterNoneBracket(self): + sched, runner = self.schedulerSetup(100, 20) + # `sched' should contains None brackets + non_brackets = [ + b for hyperband in sched._hyperbands for b in hyperband + if b is None + ] + self.assertTrue(non_brackets) + # Make sure `choose_trial_to_run' still works + trial = sched.choose_trial_to_run(runner) + self.assertIsNotNone(trial) + class _MockTrial(Trial): def __init__(self, i, config):