mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:01:11 +08:00
[tune] Make HyperBand Usable (#1215)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
num_trials: 20
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 180
|
||||
|
||||
@@ -8,13 +8,11 @@ from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
def calculate_bracket_count(max_iter, eta):
|
||||
return int(np.log(max_iter)/np.log(eta)) + 1
|
||||
|
||||
|
||||
class HyperBandScheduler(FIFOScheduler):
|
||||
"""Implements HyperBand.
|
||||
|
||||
Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html
|
||||
|
||||
This implementation contains 3 logical levels.
|
||||
Each HyperBand iteration is a "band". There can be multiple
|
||||
bands running at once, and there can be 1 band that is incomplete.
|
||||
@@ -30,26 +28,39 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
|
||||
Trials added will be inserted into the most recent bracket
|
||||
and band and will spill over to new brackets/bands accordingly.
|
||||
|
||||
This maintains the bracket size and max trial count per band
|
||||
to 5 and 117 respectively, which correspond to that of
|
||||
`max_attr=81, eta=3` from the blog post. Trials will fill up
|
||||
from smallest bracket to largest, with largest
|
||||
having the most rounds of successive halving.
|
||||
|
||||
Args:
|
||||
time_attr (str): The TrainingResult attr to use for comparing time.
|
||||
Note that you can pass in something non-temporal such as
|
||||
`training_iteration` as a measure of progress, the only requirement
|
||||
is that the attribute should increase monotonically.
|
||||
reward_attr (str): The TrainingResult objective value attribute. As
|
||||
with `time_attr`, this may refer to any objective value. Stopping
|
||||
procedures will use this attribute.
|
||||
max_t (int): max time units per trial. Trials will be stopped after
|
||||
max_t time units (determined by time_attr) have passed.
|
||||
The HyperBand scheduler automatically tries to determine a
|
||||
reasonable number of brackets based on this and eta.
|
||||
"""
|
||||
|
||||
def __init__(self, max_iter=200, eta=3):
|
||||
"""
|
||||
args:
|
||||
max_iter (int): maximum iterations per configuration
|
||||
eta (int): # defines downsampling rate (default=3)
|
||||
"""
|
||||
assert max_iter > 0, "Max Iterations not valid!"
|
||||
assert eta > 1, "Downsampling rate (eta) not valid!"
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=81):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
FIFOScheduler.__init__(self)
|
||||
self._eta = eta
|
||||
self._s_max_1 = s_max_1 = calculate_bracket_count(max_iter, eta)
|
||||
# total number of iterations per execution of Succesive Halving (n,r)
|
||||
B = s_max_1 * max_iter
|
||||
# bracket trial count total
|
||||
self._get_n0 = lambda s: int(np.ceil(B/max_iter/(s+1)*eta**s))
|
||||
self._eta = 3
|
||||
self._s_max_1 = 5
|
||||
# bracket max trials
|
||||
self._get_n0 = lambda s: int(
|
||||
np.ceil(self._s_max_1/(s+1) * self._eta**s))
|
||||
# bracket initial iterations
|
||||
self._get_r0 = lambda s: int(max_iter*eta**(-s))
|
||||
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
|
||||
self._hyperbands = [[]] # list of hyperband iterations
|
||||
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
|
||||
|
||||
@@ -57,6 +68,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
self._state = {"bracket": None,
|
||||
"band_idx": 0}
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
"""On a new trial add, if current bracket is not filled,
|
||||
@@ -67,22 +80,27 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
cur_bracket = self._state["bracket"]
|
||||
cur_band = self._hyperbands[self._state["band_idx"]]
|
||||
if cur_bracket is None or cur_bracket.filled():
|
||||
retry = True
|
||||
while retry:
|
||||
# if current iteration is filled, create new iteration
|
||||
if self._cur_band_filled():
|
||||
cur_band = []
|
||||
self._hyperbands.append(cur_band)
|
||||
self._state["band_idx"] += 1
|
||||
|
||||
# if current iteration is filled, create new iteration
|
||||
if self._cur_band_filled():
|
||||
cur_band = []
|
||||
self._hyperbands.append(cur_band)
|
||||
self._state["band_idx"] += 1
|
||||
|
||||
# cur_band will always be less than s_max_1 or else filled
|
||||
s = len(cur_band)
|
||||
assert s < self._s_max_1, "Current band is filled!"
|
||||
|
||||
# create new bracket
|
||||
cur_bracket = Bracket(self._get_n0(s),
|
||||
self._get_r0(s), self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
# cur_band will always be less than s_max_1 or else filled
|
||||
s = len(cur_band)
|
||||
assert s < self._s_max_1, "Current band is filled!"
|
||||
if self._get_r0(s) == 0:
|
||||
print("Bracket too small - Retrying...")
|
||||
cur_bracket = None
|
||||
else:
|
||||
retry = False
|
||||
cur_bracket = Bracket(
|
||||
self._time_attr, self._get_n0(s), self._get_r0(s),
|
||||
self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
|
||||
self._state["bracket"].add_trial(trial)
|
||||
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
|
||||
@@ -128,9 +146,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
if bracket.cur_iter_done():
|
||||
if bracket.finished():
|
||||
self._cleanup_bracket(trial_runner, bracket)
|
||||
return TrialScheduler.STOP
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
good, bad = bracket.successive_halving()
|
||||
good, bad = bracket.successive_halving(self._reward_attr)
|
||||
# kill bad trials
|
||||
for t in bad:
|
||||
if t.status == Trial.PAUSED:
|
||||
@@ -141,14 +159,15 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
|
||||
# ready the good trials
|
||||
# ready the good trials - if trial is too far ahead, don't continue
|
||||
for t in good:
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
else:
|
||||
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
if bracket.continue_trial(t):
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
return action
|
||||
|
||||
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
|
||||
@@ -162,11 +181,14 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
bracket.cleanup_trial(t)
|
||||
|
||||
def _cleanup_bracket(self, trial_runner, bracket):
|
||||
"""Cleans up bracket after bracket is completely finished."""
|
||||
"""Cleans up bracket after bracket is completely finished.
|
||||
Lets the last trial continue to run until termination condition
|
||||
kicks in."""
|
||||
for trial in bracket.current_trials():
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=(trial.status == Trial.PAUSED))
|
||||
if (trial.status == Trial.PAUSED):
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=True)
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
@@ -219,12 +241,15 @@ class Bracket():
|
||||
|
||||
Also keeps track of progress to ensure good scheduling.
|
||||
"""
|
||||
def __init__(self, max_trials, init_iters, eta, s):
|
||||
self._live_trials = {} # stores (result, itrs left before halving)
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
|
||||
self._live_trials = {} # maps trial -> current result
|
||||
self._all_trials = []
|
||||
self._time_attr = time_attr # attribute to
|
||||
|
||||
self._n = self._n0 = max_trials
|
||||
self._r = self._r0 = init_iters
|
||||
self._r = self._r0 = init_t_attr
|
||||
self._cumul_r = self._r0
|
||||
|
||||
self._eta = eta
|
||||
self._halves = s
|
||||
|
||||
@@ -237,15 +262,15 @@ class Bracket():
|
||||
At a later iteration, a newly added trial will be given equal
|
||||
opportunity to catch up."""
|
||||
assert not self.filled(), "Cannot add trial to filled bracket!"
|
||||
self._live_trials[trial] = (None, self._cumul_r)
|
||||
self._live_trials[trial] = None
|
||||
self._all_trials.append(trial)
|
||||
|
||||
def cur_iter_done(self):
|
||||
"""Checks if all iterations have completed.
|
||||
|
||||
TODO(rliaw): also check that `t.iterations == self._r`"""
|
||||
all_done = all(itr == 0 for _, itr in self._live_trials.values())
|
||||
return all_done
|
||||
return all(self._get_result_time(result) >= self._cumul_r
|
||||
for result in self._live_trials.values())
|
||||
|
||||
def finished(self):
|
||||
return self._halves == 0 and self.cur_iter_done()
|
||||
@@ -254,8 +279,8 @@ class Bracket():
|
||||
return list(self._live_trials)
|
||||
|
||||
def continue_trial(self, trial):
|
||||
_, itr = self._live_trials[trial]
|
||||
if itr > 0:
|
||||
result = self._live_trials[trial]
|
||||
if self._get_result_time(result) < self._cumul_r:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -265,24 +290,19 @@ class Bracket():
|
||||
minimizing the need to backtrack and bookkeep previous medians"""
|
||||
return len(self._live_trials) == self._n
|
||||
|
||||
def successive_halving(self):
|
||||
def successive_halving(self, reward_attr):
|
||||
assert self._halves > 0
|
||||
self._halves -= 1
|
||||
self._n /= self._eta
|
||||
self._n = int(np.ceil(self._n))
|
||||
self._r *= self._eta
|
||||
self._r = int(np.ceil(self._r))
|
||||
self._r = int((self._r))
|
||||
self._cumul_r += self._r
|
||||
sorted_trials = sorted(
|
||||
self._live_trials,
|
||||
key=lambda t: self._live_trials[t][0].episode_reward_mean)
|
||||
key=lambda t: getattr(self._live_trials[t], reward_attr))
|
||||
|
||||
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
|
||||
|
||||
# reset good trials to track updated iterations
|
||||
for t in good:
|
||||
res, old_itr = self._live_trials[t]
|
||||
self._live_trials[t] = (res, self._r)
|
||||
return good, bad
|
||||
|
||||
def update_trial_stats(self, trial, result):
|
||||
@@ -293,10 +313,13 @@ class Bracket():
|
||||
in and make sure they're not set as pending later."""
|
||||
|
||||
assert trial in self._live_trials
|
||||
_, itr = self._live_trials[trial]
|
||||
assert itr > 0
|
||||
self._live_trials[trial] = (result, itr - 1)
|
||||
self._completed_progress += 1
|
||||
assert self._get_result_time(result) >= 0
|
||||
|
||||
delta = self._get_result_time(result) - \
|
||||
self._get_result_time(self._live_trials[trial])
|
||||
assert delta >= 0
|
||||
self._completed_progress += delta
|
||||
self._live_trials[trial] = result
|
||||
|
||||
def cleanup_trial(self, trial):
|
||||
"""Clean up statistics tracking for terminated trials (either by force
|
||||
@@ -315,6 +338,11 @@ class Bracket():
|
||||
are dropped."""
|
||||
return self._completed_progress / self._total_work
|
||||
|
||||
def _get_result_time(self, result):
|
||||
if result is None:
|
||||
return 0
|
||||
return getattr(result, self._time_attr)
|
||||
|
||||
def _calculate_total_work(self, n, r, s):
|
||||
work = 0
|
||||
for i in range(s+1):
|
||||
@@ -322,6 +350,7 @@ class Bracket():
|
||||
n /= self._eta
|
||||
n = int(np.ceil(n))
|
||||
r *= self._eta
|
||||
r = int(r)
|
||||
return work
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
Reference in New Issue
Block a user