mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 21:43:13 +08:00
[tune] HyperBand Fixes (#1586)
This commit is contained in:
@@ -38,7 +38,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
algorithm. It divides trials into brackets of varying sizes, and
|
||||
periodically early stops low-performing trials within each bracket.
|
||||
|
||||
To use this implementation of HyperBand with Ray.tune, all you need
|
||||
To use this implementation of HyperBand with Ray Tune, all you need
|
||||
to do is specify the max length of time a trial can run `max_t`, the time
|
||||
units `time_attr`, and the name of the reported objective value
|
||||
`reward_attr`. We automatically determine reasonable values for the other
|
||||
@@ -164,7 +164,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
if bracket.cur_iter_done():
|
||||
if bracket.finished():
|
||||
bracket.cleanup_full(trial_runner)
|
||||
return TrialScheduler.CONTINUE
|
||||
return TrialScheduler.STOP
|
||||
|
||||
good, bad = bracket.successive_halving(self._reward_attr)
|
||||
# kill bad trials
|
||||
@@ -225,6 +225,22 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
return None
|
||||
|
||||
def debug_string(self):
|
||||
"""This provides a progress notification for the algorithm.
|
||||
|
||||
For each bracket, the algorithm will output a string as follows:
|
||||
|
||||
Bracket(Max Size (n)=5, Milestone (r)=33, completed=14.6%):
|
||||
{PENDING: 2, RUNNING: 3, TERMINATED: 2}
|
||||
|
||||
"Max Size" indicates the max number of pending/running experiments
|
||||
set according to the Hyperband algorithm.
|
||||
|
||||
"Milestone" indicates the iterations a trial will run for before
|
||||
the next halving will occur.
|
||||
|
||||
"Completed" indicates an approximate progress metric. Some brackets,
|
||||
like ones that are unfilled, will not reach 100%.
|
||||
"""
|
||||
out = "Using HyperBand: "
|
||||
out += "num_stopped={} total_brackets={}".format(
|
||||
self._num_stopped, sum(len(band) for band in self._hyperbands))
|
||||
@@ -367,11 +383,11 @@ class Bracket():
|
||||
|
||||
def __repr__(self):
|
||||
status = ", ".join([
|
||||
"n={}".format(self._n),
|
||||
"r={}".format(self._r),
|
||||
"completed={}%".format(int(100 * self.completion_percentage()))
|
||||
"Max Size (n)={}".format(self._n),
|
||||
"Milestone (r)={}".format(self._r),
|
||||
"completed={:.1%}".format(self.completion_percentage())
|
||||
])
|
||||
counts = collections.Counter()
|
||||
for t in self._all_trials:
|
||||
counts[t.status] += 1
|
||||
return "Bracket({}): {}".format(status, dict(counts))
|
||||
counts = collections.Counter([t.status for t in self._all_trials])
|
||||
trial_statuses = ", ".join(sorted(
|
||||
["{}: {}".format(k, v) for k, v in counts.items()]))
|
||||
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
|
||||
|
||||
@@ -327,9 +327,9 @@ class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
||||
def testContinueLastOne(self):
|
||||
def testStopsLastOne(self):
|
||||
stats = self.default_statistics()
|
||||
num_trials = stats[str(0)]["n"]
|
||||
num_trials = stats[str(0)]["n"] # setup one bracket
|
||||
sched, mock_runner = self.schedulerSetup(num_trials)
|
||||
big_bracket = sched._state["bracket"]
|
||||
for trl in big_bracket.current_trials():
|
||||
@@ -342,12 +342,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
|
||||
for x in range(100):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units + x, 10))
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
||||
def testTrialErrored(self):
|
||||
"""If a trial errored, make sure successive halving still happens"""
|
||||
|
||||
Reference in New Issue
Block a user