mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 05:52:36 +08:00
[tune] Fixing up Hyperband (#1207)
* Fixing up Hyperband * nit * cleanup * Timing test Added * added_exception_back * fixup_tests * reverse placement * fixes_and_tests * fix * fix * fixlint * cleanup_timing * lint * Update hyperband.py
This commit is contained in:
@@ -60,11 +60,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
"""On a new trial add, if current bracket is not filled,
|
||||
add to current bracket. Else, if current hp iteration is not filled,
|
||||
add to current bracket. Else, if current band is not filled,
|
||||
create new bracket, add to current bracket.
|
||||
Else, create new iteration, create new bracket, add to bracket.
|
||||
|
||||
TODO(rliaw): This is messy."""
|
||||
Else, create new iteration, create new bracket, add to bracket."""
|
||||
|
||||
cur_bracket = self._state["bracket"]
|
||||
cur_band = self._hyperbands[self._state["band_idx"]]
|
||||
@@ -76,9 +74,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
self._hyperbands.append(cur_band)
|
||||
self._state["band_idx"] += 1
|
||||
|
||||
# cur_band will always be less than s_max or else filled
|
||||
s = self._s_max_1 - len(cur_band) - 1
|
||||
assert s >= 0, "Current band is filled but adding bracket!"
|
||||
# cur_band will always be less than s_max_1 or else filled
|
||||
s = len(cur_band)
|
||||
assert s < self._s_max_1, "Current band is filled!"
|
||||
|
||||
# create new bracket
|
||||
cur_bracket = Bracket(self._get_n0(s),
|
||||
@@ -102,34 +100,44 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
|
||||
If a given trial finishes and bracket iteration is not done,
|
||||
the trial will be paused and resources will be given up.
|
||||
When bracket iteration is done, Trials will be successively halved,
|
||||
and during each halving phase, bad trials will be stopped while good
|
||||
trials will return to "PENDING". This scheduler will not start trials
|
||||
but will stop trials. The current running trial will not be handled,
|
||||
|
||||
This scheduler will not start trials but will stop trials.
|
||||
The current running trial will not be handled,
|
||||
as the trialrunner will be given control to handle it.
|
||||
|
||||
# TODO(rliaw) should be only called if trial has not errored"""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.update_trial_stats(trial, result)
|
||||
|
||||
if bracket.continue_trial(trial):
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
signal = TrialScheduler.PAUSE
|
||||
action = self._process_bracket(trial_runner, bracket, trial)
|
||||
return action
|
||||
|
||||
def _process_bracket(self, trial_runner, bracket, trial):
|
||||
"""This is called whenever a trial makes progress.
|
||||
|
||||
When all live trials in the bracket have no more iterations left,
|
||||
Trials will be successively halved. If bracket is done, all
|
||||
non-running trials will be stopped and cleaned up,
|
||||
and during each halving phase, bad trials will be stopped while good
|
||||
trials will return to "PENDING"."""
|
||||
|
||||
action = TrialScheduler.PAUSE
|
||||
if bracket.cur_iter_done():
|
||||
if bracket.finished():
|
||||
self._cleanup_bracket(trial_runner, bracket)
|
||||
return TrialScheduler.STOP
|
||||
# what if bracket is done and trial not completed?
|
||||
|
||||
good, bad = bracket.successive_halving()
|
||||
# kill bad trials
|
||||
for t in bad:
|
||||
self._num_stopped += 1
|
||||
if t.status == Trial.PAUSED:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial_early(t)
|
||||
elif t is trial:
|
||||
signal = TrialScheduler.STOP
|
||||
self._cleanup_trial(trial_runner, t, bracket, hard=True)
|
||||
elif t.status == Trial.RUNNING:
|
||||
self._cleanup_trial(trial_runner, t, bracket, hard=False)
|
||||
action = TrialScheduler.STOP
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
|
||||
@@ -137,38 +145,42 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
for t in good:
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
elif t is trial:
|
||||
signal = TrialScheduler.CONTINUE
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
return action
|
||||
|
||||
return signal
|
||||
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
|
||||
"""Bookkeeping for trials finished. If `hard=True`, then
|
||||
this scheduler will force the trial_runner to release resources.
|
||||
|
||||
Otherwise, only clean up trial information locally."""
|
||||
self._num_stopped += 1
|
||||
if hard:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial(t)
|
||||
|
||||
def _cleanup_bracket(self, trial_runner, bracket):
|
||||
"""Cleans up bracket after bracket is completely finished.
|
||||
|
||||
Bracket information will only be cleaned up after the trialrunner has
|
||||
finished its bookkeeping."""
|
||||
for t in bracket.current_trials():
|
||||
if t.status == Trial.PAUSED:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial_early(t)
|
||||
"""Cleans up bracket after bracket is completely finished."""
|
||||
for trial in bracket.current_trials():
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=(trial.status == Trial.PAUSED))
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Cleans up trial info from bracket if trial completed early.
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
|
||||
Bracket information will only be cleaned up after the trialrunner has
|
||||
finished its bookkeeping."""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.cleanup_trial_early(trial)
|
||||
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
|
||||
def on_trial_error(self, trial_runner, trial):
|
||||
"""Cleans up trial info from bracket if trial errored early.
|
||||
"""Cleans up trial info from bracket if trial errored early."""
|
||||
|
||||
Bracket information will only be cleaned up after the trialrunner has
|
||||
finished its bookkeeping."""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.cleanup_trial_early(trial)
|
||||
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, *args):
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
@@ -177,6 +189,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
|
||||
If iteration is occupied (ie, no trials to run), then look into
|
||||
next iteration."""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
key=lambda b: b.completion_percentage()):
|
||||
@@ -187,10 +200,17 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
return None
|
||||
|
||||
def debug_string(self):
|
||||
brackets = [
|
||||
"({0}/{1})".format(
|
||||
len(bracket._live_trials), len(bracket._all_trials))
|
||||
for band in self._hyperbands for bracket in band]
|
||||
return " ".join([
|
||||
"Using HyperBand:",
|
||||
"num_stopped={}".format(self._num_stopped),
|
||||
"brackets={}".format(sum(len(band) for band in self._hyperbands))])
|
||||
"total_brackets={}".format(
|
||||
sum(len(band) for band in self._hyperbands)),
|
||||
" ".join(brackets)
|
||||
])
|
||||
|
||||
|
||||
class Bracket():
|
||||
@@ -278,8 +298,9 @@ class Bracket():
|
||||
self._live_trials[trial] = (result, itr - 1)
|
||||
self._completed_progress += 1
|
||||
|
||||
def cleanup_trial_early(self, trial):
|
||||
"""Clean up statistics tracking for trial that terminated early.
|
||||
def cleanup_trial(self, trial):
|
||||
"""Clean up statistics tracking for terminated trials (either by force
|
||||
or otherwise).
|
||||
|
||||
This may cause bad trials to continue for a long time, in the case
|
||||
where all the good trials finish early and there are only bad trials
|
||||
|
||||
Reference in New Issue
Block a user