From 96d7938fc4b24fb809cd615ed28e6de295f82810 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sat, 3 Mar 2018 13:00:55 -0800 Subject: [PATCH] [tune] Hyperband Max Iter Fix (#1620) * nits * cumul r * docs * min --- python/ray/tune/hyperband.py | 22 +++++++++++++------- python/ray/tune/test/trial_scheduler_test.py | 8 +++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/ray/tune/hyperband.py b/python/ray/tune/hyperband.py index f0efe99cb..992512461 100644 --- a/python/ray/tune/hyperband.py +++ b/python/ray/tune/hyperband.py @@ -61,9 +61,9 @@ class HyperBandScheduler(FIFOScheduler): procedures will use this attribute. max_t (int): max time units per trial. Trials will be stopped after max_t time units (determined by time_attr) have passed. - The HyperBand scheduler automatically tries to determine a - reasonable number of brackets based on this. The scheduler will - terminate trials after this time has passed. + The scheduler will terminate trials after this time has passed. + Note that this is different from the semantics of `max_t` as + mentioned in the original HyperBand paper. """ def __init__( @@ -73,6 +73,7 @@ class HyperBandScheduler(FIFOScheduler): FIFOScheduler.__init__(self) self._eta = 3 self._s_max_1 = 5 + self._max_t_attr = max_t # bracket max trials self._get_n0 = lambda s: int( np.ceil(self._s_max_1/(s+1) * self._eta**s)) @@ -117,7 +118,7 @@ class HyperBandScheduler(FIFOScheduler): retry = False cur_bracket = Bracket( self._time_attr, self._get_n0(s), self._get_r0(s), - self._eta, s) + self._max_t_attr, self._eta, s) cur_band.append(cur_bracket) self._state["bracket"] = cur_bracket @@ -257,13 +258,14 @@ class Bracket(): Also keeps track of progress to ensure good scheduling. """ - def __init__(self, time_attr, max_trials, init_t_attr, eta, s): + def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s): self._live_trials = {} # maps trial -> current result self._all_trials = [] self._time_attr = time_attr # attribute to self._n = self._n0 = max_trials self._r = self._r0 = init_t_attr + self._max_t_attr = max_t_attr self._cumul_r = self._r0 self._eta = eta @@ -314,8 +316,9 @@ class Bracket(): self._halves -= 1 self._n /= self._eta self._n = int(np.ceil(self._n)) + self._r *= self._eta - self._r = int((self._r)) + self._r = int(min(self._r, self._max_t_attr - self._cumul_r)) self._cumul_r += self._r sorted_trials = sorted( self._live_trials, @@ -364,6 +367,8 @@ class Bracket(): This will not be always finish with 100 since dead trials are dropped.""" + if self.finished(): + return 1.0 return self._completed_progress / self._total_work def _get_result_time(self, result): @@ -373,18 +378,19 @@ class Bracket(): def _calculate_total_work(self, n, r, s): work = 0 + cumulative_r = r for i in range(s+1): work += int(n) * int(r) n /= self._eta n = int(np.ceil(n)) r *= self._eta - r = int(r) + r = int(min(r, self._max_t_attr - cumulative_r)) return work def __repr__(self): status = ", ".join([ "Max Size (n)={}".format(self._n), - "Milestone (r)={}".format(self._r), + "Milestone (r)={}".format(self._cumul_r), "completed={:.1%}".format(self.completion_percentage()) ]) counts = collections.Counter([t.status for t in self._all_trials]) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index 9ada2b04d..bdc8d1a2d 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -191,10 +191,10 @@ class HyperbandSuite(unittest.TestCase): Bracketing is placed as follows: (5, 81); - (8, 27) -> (3, 81); - (15, 9) -> (5, 27) -> (2, 81); - (34, 3) -> (12, 9) -> (4, 27) -> (2, 81); - (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);""" + (8, 27) -> (3, 54); + (15, 9) -> (5, 27) -> (2, 45); + (34, 3) -> (12, 9) -> (4, 27) -> (2, 42); + (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);""" sched = HyperBandScheduler() for i in range(num_trials): t = Trial("__fake")