mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:54:21 +08:00
[tune] Allow trials to remain paused in BOHB (#10531)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -93,7 +93,7 @@ class HyperBandForBOHB(HyperBandScheduler):
|
||||
trial_runner.trial_executor.unpause_trial(trial)
|
||||
trial_runner._search_alg.searcher.on_unpause(trial.trial_id)
|
||||
|
||||
-def choose_trial_to_run(self, trial_runner):
+def choose_trial_to_run(self, trial_runner, allow_recurse=True):
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
|
||||
List of trials not used since all trials are tracked as state
|
||||
@@ -117,8 +117,17 @@ class HyperBandForBOHB(HyperBandScheduler):
|
||||
for bracket in hyperband:
|
||||
if bracket and any(trial.status == Trial.PAUSED
|
||||
for trial in bracket.current_trials()):
|
||||
-# This will change the trial state and let the
-# trial runner retry.
+# This will change the trial state
 self._process_bracket(trial_runner, bracket)
+
+# If there are pending trials now, suggest one.
+# This is because there might be both PENDING and
+# PAUSED trials now, and PAUSED trials will raise
+# an error before the trial runner tries again.
+if allow_recurse and any(
+trial.status == Trial.PENDING
+for trial in bracket.current_trials()):
+return self.choose_trial_to_run(
+trial_runner, allow_recurse=False)
|
||||
# MAIN CHANGE HERE!
|
||||
return None
|
||||
|
||||
@@ -282,7 +282,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
for i, band in enumerate(self._hyperbands):
|
||||
out += "\nRound #{}:".format(i)
|
||||
for bracket in band:
|
||||
-out += "\n {}".format(bracket)
+if bracket:
+out += "\n {}".format(bracket)
|
||||
return out
|
||||
|
||||
def state(self):
|
||||
|
||||
@@ -689,6 +689,30 @@ class BOHBSuite(unittest.TestCase):
|
||||
self.assertTrue("hyperband_info" in spy_result)
|
||||
self.assertEquals(spy_result["hyperband_info"]["budget"], 1)
|
||||
|
||||
def testPauseResumeChooseTrial(self):
|
||||
def result(score, ts):
|
||||
return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
|
||||
|
||||
sched = HyperBandForBOHB(max_t=10, reduction_factor=3, mode="min")
|
||||
runner = _MockTrialRunner(sched)
|
||||
runner._search_alg = MagicMock()
|
||||
runner._search_alg.searcher = MagicMock()
|
||||
trials = [Trial("__fake") for i in range(3)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
runner._launch_trial(t)
|
||||
|
||||
all_results = [result(1, 5), result(2, 1), result(3, 5)]
|
||||
for trial, trial_result in zip(trials, all_results):
|
||||
decision = sched.on_trial_result(runner, trial, trial_result)
|
||||
self.assertEqual(decision, TrialScheduler.PAUSE)
|
||||
runner._pause_trial(trial)
|
||||
|
||||
run_trial = sched.choose_trial_to_run(runner)
|
||||
self.assertEqual(run_trial, trials[1])
|
||||
self.assertSequenceEqual([t.status for t in trials],
|
||||
[Trial.PAUSED, Trial.PENDING, Trial.PAUSED])
|
||||
|
||||
|
||||
class _MockTrial(Trial):
|
||||
def __init__(self, i, config):
|
||||
|
||||
Reference in New Issue
Block a user