[tune] Hyperband Max Iter Fix (#1620)

* nits

* cumul r

* docs

* min
This commit is contained in:
Richard Liaw
2018-03-03 13:00:55 -08:00
committed by Eric Liang
parent 6685d4c446
commit 96d7938fc4
2 changed files with 18 additions and 12 deletions
+14 -8
View File
@@ -61,9 +61,9 @@ class HyperBandScheduler(FIFOScheduler):
procedures will use this attribute.
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The HyperBand scheduler automatically tries to determine a
reasonable number of brackets based on this. The scheduler will
terminate trials after this time has passed.
The scheduler will terminate trials after this time has passed.
Note that this is different from the semantics of `max_t` as
mentioned in the original HyperBand paper.
"""
def __init__(
@@ -73,6 +73,7 @@ class HyperBandScheduler(FIFOScheduler):
FIFOScheduler.__init__(self)
self._eta = 3
self._s_max_1 = 5
self._max_t_attr = max_t
# bracket max trials
self._get_n0 = lambda s: int(
np.ceil(self._s_max_1/(s+1) * self._eta**s))
@@ -117,7 +118,7 @@ class HyperBandScheduler(FIFOScheduler):
retry = False
cur_bracket = Bracket(
self._time_attr, self._get_n0(s), self._get_r0(s),
self._eta, s)
self._max_t_attr, self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
@@ -257,13 +258,14 @@ class Bracket():
Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
self._time_attr = time_attr # attribute to
self._n = self._n0 = max_trials
self._r = self._r0 = init_t_attr
self._max_t_attr = max_t_attr
self._cumul_r = self._r0
self._eta = eta
@@ -314,8 +316,9 @@ class Bracket():
self._halves -= 1
self._n /= self._eta
self._n = int(np.ceil(self._n))
self._r *= self._eta
self._r = int((self._r))
self._r = int(min(self._r, self._max_t_attr - self._cumul_r))
self._cumul_r += self._r
sorted_trials = sorted(
self._live_trials,
@@ -364,6 +367,8 @@ class Bracket():
This will not be always finish with 100 since dead trials
are dropped."""
if self.finished():
return 1.0
return self._completed_progress / self._total_work
def _get_result_time(self, result):
@@ -373,18 +378,19 @@ class Bracket():
def _calculate_total_work(self, n, r, s):
work = 0
cumulative_r = r
for i in range(s+1):
work += int(n) * int(r)
n /= self._eta
n = int(np.ceil(n))
r *= self._eta
r = int(r)
r = int(min(r, self._max_t_attr - cumulative_r))
return work
def __repr__(self):
status = ", ".join([
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._r),
"Milestone (r)={}".format(self._cumul_r),
"completed={:.1%}".format(self.completion_percentage())
])
counts = collections.Counter([t.status for t in self._all_trials])
+4 -4
View File
@@ -191,10 +191,10 @@ class HyperbandSuite(unittest.TestCase):
Bracketing is placed as follows:
(5, 81);
(8, 27) -> (3, 81);
(15, 9) -> (5, 27) -> (2, 81);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
(8, 27) -> (3, 54);
(15, 9) -> (5, 27) -> (2, 45);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 42);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);"""
sched = HyperBandScheduler()
for i in range(num_trials):
t = Trial("__fake")