[tune] Make HyperBand Usable (#1215)

This commit is contained in:
Richard Liaw
2017-11-16 10:31:42 -08:00
committed by GitHub
parent 3a0206a1f4
commit eadb998643
3 changed files with 379 additions and 193 deletions
@@ -1,7 +1,7 @@
cartpole-ppo:
env: CartPole-v0
alg: PPO
num_trials: 20
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 180
+96 -67
View File
@@ -8,13 +8,11 @@ from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.trial import Trial
def calculate_bracket_count(max_iter, eta):
    """Return the bracket count, ``floor(log_eta(max_iter)) + 1``.

    The logarithm is evaluated by repeated multiplication rather than as a
    ratio of floating-point logs. The float ratio can land just below an
    exact integer (e.g. ``log(243) / log(3)`` is ``4.999...``), which made
    the truncating ``int()`` under-count by one for exact powers of ``eta``.

    Args:
        max_iter (int): maximum iterations per configuration; must be > 0.
        eta (int): downsampling rate; must be > 1.

    Returns:
        int: the largest ``k`` with ``eta ** k <= max_iter``, plus one.

    Raises:
        ValueError: if ``max_iter <= 0`` or ``eta <= 1``.
    """
    if max_iter <= 0:
        raise ValueError("max_iter must be positive, got %r" % (max_iter,))
    if eta <= 1:
        # eta <= 1 would never exceed max_iter and the loop would not end.
        raise ValueError("eta must be greater than 1, got %r" % (eta,))

    exponent = 0
    power = eta
    while power <= max_iter:
        exponent += 1
        power *= eta
    return exponent + 1
class HyperBandScheduler(FIFOScheduler):
"""Implements HyperBand.
Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html
This implementation contains 3 logical levels.
Each HyperBand iteration is a "band". There can be multiple
bands running at once, and there can be 1 band that is incomplete.
@@ -30,26 +28,39 @@ class HyperBandScheduler(FIFOScheduler):
Trials added will be inserted into the most recent bracket
and band and will spill over to new brackets/bands accordingly.
This maintains the bracket size and max trial count per band
at 5 and 117 respectively, which correspond to those of
`max_attr=81, eta=3` from the blog post. Trials will fill up
from the smallest bracket to the largest, with the largest
having the most rounds of successive halving.
Args:
time_attr (str): The TrainingResult attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
reward_attr (str): The TrainingResult objective value attribute. As
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The HyperBand scheduler automatically tries to determine a
reasonable number of brackets based on this and eta.
"""
def __init__(self, max_iter=200, eta=3):
"""
args:
max_iter (int): maximum iterations per configuration
eta (int): defines the downsampling rate (default=3)
"""
assert max_iter > 0, "Max Iterations not valid!"
assert eta > 1, "Downsampling rate (eta) not valid!"
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean', max_t=81):
assert max_t > 0, "Max (time_attr) not valid!"
FIFOScheduler.__init__(self)
self._eta = eta
self._s_max_1 = s_max_1 = calculate_bracket_count(max_iter, eta)
# total number of iterations per execution of Successive Halving (n,r)
B = s_max_1 * max_iter
# bracket trial count total
self._get_n0 = lambda s: int(np.ceil(B/max_iter/(s+1)*eta**s))
self._eta = 3
self._s_max_1 = 5
# bracket max trials
self._get_n0 = lambda s: int(
np.ceil(self._s_max_1/(s+1) * self._eta**s))
# bracket initial iterations
self._get_r0 = lambda s: int(max_iter*eta**(-s))
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
self._hyperbands = [[]] # list of hyperband iterations
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
@@ -57,6 +68,8 @@ class HyperBandScheduler(FIFOScheduler):
self._state = {"bracket": None,
"band_idx": 0}
self._num_stopped = 0
self._reward_attr = reward_attr
self._time_attr = time_attr
def on_trial_add(self, trial_runner, trial):
"""On a new trial add, if current bracket is not filled,
@@ -67,22 +80,27 @@ class HyperBandScheduler(FIFOScheduler):
cur_bracket = self._state["bracket"]
cur_band = self._hyperbands[self._state["band_idx"]]
if cur_bracket is None or cur_bracket.filled():
retry = True
while retry:
# if current iteration is filled, create new iteration
if self._cur_band_filled():
cur_band = []
self._hyperbands.append(cur_band)
self._state["band_idx"] += 1
# if current iteration is filled, create new iteration
if self._cur_band_filled():
cur_band = []
self._hyperbands.append(cur_band)
self._state["band_idx"] += 1
# cur_band will always be less than s_max_1 or else filled
s = len(cur_band)
assert s < self._s_max_1, "Current band is filled!"
# create new bracket
cur_bracket = Bracket(self._get_n0(s),
self._get_r0(s), self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
# cur_band will always be less than s_max_1 or else filled
s = len(cur_band)
assert s < self._s_max_1, "Current band is filled!"
if self._get_r0(s) == 0:
print("Bracket too small - Retrying...")
cur_bracket = None
else:
retry = False
cur_bracket = Bracket(
self._time_attr, self._get_n0(s), self._get_r0(s),
self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
self._state["bracket"].add_trial(trial)
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
@@ -128,9 +146,9 @@ class HyperBandScheduler(FIFOScheduler):
if bracket.cur_iter_done():
if bracket.finished():
self._cleanup_bracket(trial_runner, bracket)
return TrialScheduler.STOP
return TrialScheduler.CONTINUE
good, bad = bracket.successive_halving()
good, bad = bracket.successive_halving(self._reward_attr)
# kill bad trials
for t in bad:
if t.status == Trial.PAUSED:
@@ -141,14 +159,15 @@ class HyperBandScheduler(FIFOScheduler):
else:
raise Exception("Trial with unexpected status encountered")
# ready the good trials
# ready the good trials - if trial is too far ahead, don't continue
for t in good:
if t.status == Trial.PAUSED:
t.unpause()
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
else:
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
raise Exception("Trial with unexpected status encountered")
if bracket.continue_trial(t):
if t.status == Trial.PAUSED:
t.unpause()
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
return action
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
@@ -162,11 +181,14 @@ class HyperBandScheduler(FIFOScheduler):
bracket.cleanup_trial(t)
def _cleanup_bracket(self, trial_runner, bracket):
"""Cleans up bracket after bracket is completely finished."""
"""Cleans up bracket after bracket is completely finished.
Lets the last trial continue to run until termination condition
kicks in."""
for trial in bracket.current_trials():
self._cleanup_trial(
trial_runner, trial, bracket,
hard=(trial.status == Trial.PAUSED))
if (trial.status == Trial.PAUSED):
self._cleanup_trial(
trial_runner, trial, bracket,
hard=True)
def on_trial_complete(self, trial_runner, trial, result):
"""Cleans up trial info from bracket if trial completed early."""
@@ -219,12 +241,15 @@ class Bracket():
Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, max_trials, init_iters, eta, s):
self._live_trials = {} # stores (result, itrs left before halving)
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
self._time_attr = time_attr # result attribute used to measure time
self._n = self._n0 = max_trials
self._r = self._r0 = init_iters
self._r = self._r0 = init_t_attr
self._cumul_r = self._r0
self._eta = eta
self._halves = s
@@ -237,15 +262,15 @@ class Bracket():
At a later iteration, a newly added trial will be given equal
opportunity to catch up."""
assert not self.filled(), "Cannot add trial to filled bracket!"
self._live_trials[trial] = (None, self._cumul_r)
self._live_trials[trial] = None
self._all_trials.append(trial)
def cur_iter_done(self):
"""Checks if all iterations have completed.
TODO(rliaw): also check that `t.iterations == self._r`"""
all_done = all(itr == 0 for _, itr in self._live_trials.values())
return all_done
return all(self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
def finished(self):
return self._halves == 0 and self.cur_iter_done()
@@ -254,8 +279,8 @@ class Bracket():
return list(self._live_trials)
def continue_trial(self, trial):
_, itr = self._live_trials[trial]
if itr > 0:
result = self._live_trials[trial]
if self._get_result_time(result) < self._cumul_r:
return True
else:
return False
@@ -265,24 +290,19 @@ class Bracket():
minimizing the need to backtrack and bookkeep previous medians"""
return len(self._live_trials) == self._n
def successive_halving(self):
def successive_halving(self, reward_attr):
assert self._halves > 0
self._halves -= 1
self._n /= self._eta
self._n = int(np.ceil(self._n))
self._r *= self._eta
self._r = int(np.ceil(self._r))
self._r = int((self._r))
self._cumul_r += self._r
sorted_trials = sorted(
self._live_trials,
key=lambda t: self._live_trials[t][0].episode_reward_mean)
key=lambda t: getattr(self._live_trials[t], reward_attr))
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
# reset good trials to track updated iterations
for t in good:
res, old_itr = self._live_trials[t]
self._live_trials[t] = (res, self._r)
return good, bad
def update_trial_stats(self, trial, result):
@@ -293,10 +313,13 @@ class Bracket():
in and make sure they're not set as pending later."""
assert trial in self._live_trials
_, itr = self._live_trials[trial]
assert itr > 0
self._live_trials[trial] = (result, itr - 1)
self._completed_progress += 1
assert self._get_result_time(result) >= 0
delta = self._get_result_time(result) - \
self._get_result_time(self._live_trials[trial])
assert delta >= 0
self._completed_progress += delta
self._live_trials[trial] = result
def cleanup_trial(self, trial):
"""Clean up statistics tracking for terminated trials (either by force
@@ -315,6 +338,11 @@ class Bracket():
are dropped."""
return self._completed_progress / self._total_work
def _get_result_time(self, result):
if result is None:
return 0
return getattr(result, self._time_attr)
def _calculate_total_work(self, n, r, s):
work = 0
for i in range(s+1):
@@ -322,6 +350,7 @@ class Bracket():
n /= self._eta
n = int(np.ceil(n))
r *= self._eta
r = int(r)
return work
def __repr__(self):