mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
@@ -61,9 +61,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
procedures will use this attribute.
|
||||
max_t (int): max time units per trial. Trials will be stopped after
|
||||
max_t time units (determined by time_attr) have passed.
|
||||
The HyperBand scheduler automatically tries to determine a
|
||||
reasonable number of brackets based on this. The scheduler will
|
||||
terminate trials after this time has passed.
|
||||
The scheduler will terminate trials after this time has passed.
|
||||
Note that this is different from the semantics of `max_t` as
|
||||
mentioned in the original HyperBand paper.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -73,6 +73,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
FIFOScheduler.__init__(self)
|
||||
self._eta = 3
|
||||
self._s_max_1 = 5
|
||||
self._max_t_attr = max_t
|
||||
# bracket max trials
|
||||
self._get_n0 = lambda s: int(
|
||||
np.ceil(self._s_max_1/(s+1) * self._eta**s))
|
||||
@@ -117,7 +118,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
retry = False
|
||||
cur_bracket = Bracket(
|
||||
self._time_attr, self._get_n0(s), self._get_r0(s),
|
||||
self._eta, s)
|
||||
self._max_t_attr, self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
|
||||
@@ -257,13 +258,14 @@ class Bracket():
|
||||
|
||||
Also keeps track of progress to ensure good scheduling.
|
||||
"""
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
|
||||
self._live_trials = {} # maps trial -> current result
|
||||
self._all_trials = []
|
||||
self._time_attr = time_attr # attribute to
|
||||
|
||||
self._n = self._n0 = max_trials
|
||||
self._r = self._r0 = init_t_attr
|
||||
self._max_t_attr = max_t_attr
|
||||
self._cumul_r = self._r0
|
||||
|
||||
self._eta = eta
|
||||
@@ -314,8 +316,9 @@ class Bracket():
|
||||
self._halves -= 1
|
||||
self._n /= self._eta
|
||||
self._n = int(np.ceil(self._n))
|
||||
|
||||
self._r *= self._eta
|
||||
self._r = int((self._r))
|
||||
self._r = int(min(self._r, self._max_t_attr - self._cumul_r))
|
||||
self._cumul_r += self._r
|
||||
sorted_trials = sorted(
|
||||
self._live_trials,
|
||||
@@ -364,6 +367,8 @@ class Bracket():
|
||||
|
||||
This will not be always finish with 100 since dead trials
|
||||
are dropped."""
|
||||
if self.finished():
|
||||
return 1.0
|
||||
return self._completed_progress / self._total_work
|
||||
|
||||
def _get_result_time(self, result):
|
||||
@@ -373,18 +378,19 @@ class Bracket():
|
||||
|
||||
def _calculate_total_work(self, n, r, s):
|
||||
work = 0
|
||||
cumulative_r = r
|
||||
for i in range(s+1):
|
||||
work += int(n) * int(r)
|
||||
n /= self._eta
|
||||
n = int(np.ceil(n))
|
||||
r *= self._eta
|
||||
r = int(r)
|
||||
r = int(min(r, self._max_t_attr - cumulative_r))
|
||||
return work
|
||||
|
||||
def __repr__(self):
|
||||
status = ", ".join([
|
||||
"Max Size (n)={}".format(self._n),
|
||||
"Milestone (r)={}".format(self._r),
|
||||
"Milestone (r)={}".format(self._cumul_r),
|
||||
"completed={:.1%}".format(self.completion_percentage())
|
||||
])
|
||||
counts = collections.Counter([t.status for t in self._all_trials])
|
||||
|
||||
@@ -191,10 +191,10 @@ class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
Bracketing is placed as follows:
|
||||
(5, 81);
|
||||
(8, 27) -> (3, 81);
|
||||
(15, 9) -> (5, 27) -> (2, 81);
|
||||
(34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
|
||||
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
|
||||
(8, 27) -> (3, 54);
|
||||
(15, 9) -> (5, 27) -> (2, 45);
|
||||
(34, 3) -> (12, 9) -> (4, 27) -> (2, 42);
|
||||
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);"""
|
||||
sched = HyperBandScheduler()
|
||||
for i in range(num_trials):
|
||||
t = Trial("__fake")
|
||||
|
||||
Reference in New Issue
Block a user