[tune] HyperBand Fixes (#1586)

This commit is contained in:
Richard Liaw
2018-02-25 13:26:58 -08:00
committed by GitHub
parent 2026c147ec
commit 31fefa20b7
2 changed files with 28 additions and 17 deletions
+25 -9
View File
@@ -38,7 +38,7 @@ class HyperBandScheduler(FIFOScheduler):
algorithm. It divides trials into brackets of varying sizes, and
periodically early stops low-performing trials within each bracket.
To use this implementation of HyperBand with Ray.tune, all you need
To use this implementation of HyperBand with Ray Tune, all you need
to do is specify the max length of time a trial can run `max_t`, the time
units `time_attr`, and the name of the reported objective value
`reward_attr`. We automatically determine reasonable values for the other
@@ -164,7 +164,7 @@ class HyperBandScheduler(FIFOScheduler):
if bracket.cur_iter_done():
if bracket.finished():
bracket.cleanup_full(trial_runner)
return TrialScheduler.CONTINUE
return TrialScheduler.STOP
good, bad = bracket.successive_halving(self._reward_attr)
# kill bad trials
@@ -225,6 +225,22 @@ class HyperBandScheduler(FIFOScheduler):
return None
def debug_string(self):
"""This provides a progress notification for the algorithm.
For each bracket, the algorithm will output a string as follows:
Bracket(Max Size (n)=5, Milestone (r)=33, completed=14.6%):
{PENDING: 2, RUNNING: 3, TERMINATED: 2}
"Max Size" indicates the max number of pending/running experiments
set according to the Hyperband algorithm.
"Milestone" indicates the iterations a trial will run for before
the next halving will occur.
"Completed" indicates an approximate progress metric. Some brackets,
like ones that are unfilled, will not reach 100%.
"""
out = "Using HyperBand: "
out += "num_stopped={} total_brackets={}".format(
self._num_stopped, sum(len(band) for band in self._hyperbands))
@@ -367,11 +383,11 @@ class Bracket():
def __repr__(self):
status = ", ".join([
"n={}".format(self._n),
"r={}".format(self._r),
"completed={}%".format(int(100 * self.completion_percentage()))
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._r),
"completed={:.1%}".format(self.completion_percentage())
])
counts = collections.Counter()
for t in self._all_trials:
counts[t.status] += 1
return "Bracket({}): {}".format(status, dict(counts))
counts = collections.Counter([t.status for t in self._all_trials])
trial_statuses = ", ".join(sorted(
["{}: {}".format(k, v) for k, v in counts.items()]))
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
+3 -8
View File
@@ -327,9 +327,9 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(action, TrialScheduler.STOP)
def testContinueLastOne(self):
def testStopsLastOne(self):
stats = self.default_statistics()
num_trials = stats[str(0)]["n"]
num_trials = stats[str(0)]["n"] # setup one bracket
sched, mock_runner = self.schedulerSetup(num_trials)
big_bracket = sched._state["bracket"]
for trl in big_bracket.current_trials():
@@ -342,12 +342,7 @@ class HyperbandSuite(unittest.TestCase):
mock_runner, trl, result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
for x in range(100):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units + x, 10))
self.assertEqual(action, TrialScheduler.CONTINUE)
self.assertEqual(action, TrialScheduler.STOP)
def testTrialErrored(self):
"""If a trial errored, make sure successive halving still happens"""