def calculate_bracket_count(max_iter, eta):
    """Return the number of brackets in one band.

    This is ``floor(log_eta(max_iter)) + 1``.

    Args:
        max_iter (int): maximum iterations per configuration.
        eta (int): downsampling rate.

    Returns:
        int: number of brackets per band.
    """
    estimate = int(np.log(max_iter) / np.log(eta))
    if eta > 1:
        # The float quotient of two logs can land just below (or above) an
        # exact integer, e.g. log(243)/log(3) or log(1000)/log(10), which
        # makes the truncation off by one. Verify the estimate against
        # exact powers of eta and nudge it by at most one step.
        if eta ** (estimate + 1) <= max_iter:
            estimate += 1
        elif estimate > 0 and eta ** estimate > max_iter:
            estimate -= 1
    return estimate + 1


class HyperBandScheduler(FIFOScheduler):
    """Implements HyperBand.

    This implementation contains 3 logical levels.
    Each HyperBand iteration is a "band". There can be multiple
    bands running at once, and there can be 1 band that is incomplete.

    In each band, there are at most `s` + 1 brackets.
    `s` is a value determined by given parameters, and assigned on
    a cyclic basis.

    In each bracket, there are at most `n(s)` trials, indicating that
    `n` is a function of `s`. These trials go through a series of
    halving procedures, dropping lowest performers. Multiple
    brackets are running at once.

    Trials added will be inserted into the most recent bracket
    and band and will spill over to new brackets/bands accordingly.
    """

    def __init__(self, max_iter, eta=3):
        """
        args:
            max_iter (int): maximum iterations per configuration
            eta (int): # defines downsampling rate (default=3)
        """
        assert max_iter > 0, "Max Iterations not valid!"
        assert eta > 1, "Downsampling rate (eta) not valid!"

        FIFOScheduler.__init__(self)
        self._eta = eta
        self._max_iter = max_iter
        # number of brackets in a full band (s_max + 1)
        self._s_max_1 = calculate_bracket_count(max_iter, eta)
        self._hyperbands = [[]]  # list of hyperband iterations
        self._trial_info = {}  # Stores Trial -> Bracket, Band Iteration

        # Tracks state for new trial add
        self._state = {"bracket": None,
                       "band_idx": 0}
        self._num_stopped = 0

    def _get_n0(self, s):
        """Return the initial trial count of the bracket with `s` halvings.

        Equals ceil(s_max_1 * eta**s / (s + 1)). The numerator is computed
        before the single division so that only one rounding step happens
        ahead of the ceil.
        """
        return int(np.ceil(self._s_max_1 * self._eta ** s / float(s + 1)))

    def _get_r0(self, s):
        """Return the initial iterations per trial for bracket `s`.

        Equals floor(max_iter / eta**s). Dividing by eta**s (instead of
        multiplying by eta**(-s)) avoids a second float rounding that could
        push an exact integer result just below itself before truncation.
        """
        return int(self._max_iter / float(self._eta ** s))

    def on_trial_add(self, trial_runner, trial):
        """On a new trial add, if current bracket is not filled,
        add to current bracket. Else, if current hp iteration is not filled,
        create new bracket, add to current bracket.
        Else, create new iteration, create new bracket, add to bracket.

        TODO(rliaw): This is messy."""

        cur_bracket = self._state["bracket"]
        cur_band = self._hyperbands[self._state["band_idx"]]
        if cur_bracket is None or cur_bracket.filled():

            # if current iteration is filled, create new iteration
            if self._cur_band_filled():
                cur_band = []
                self._hyperbands.append(cur_band)
                self._state["band_idx"] += 1

            # cur_band will always be less than s_max or else filled
            s = self._s_max_1 - len(cur_band) - 1
            assert s >= 0, "Current band is filled but adding bracket!"

            # create new bracket
            cur_bracket = Bracket(self._get_n0(s),
                                  self._get_r0(s), self._eta, s)
            cur_band.append(cur_bracket)
            self._state["bracket"] = cur_bracket

        cur_bracket.add_trial(trial)
        self._trial_info[trial] = cur_bracket, self._state["band_idx"]

    def _cur_band_filled(self):
        """Checks if the current band is filled.

        The size of the current band should be equal to s_max_1"""

        cur_band = self._hyperbands[self._state["band_idx"]]
        return len(cur_band) == self._s_max_1

    def on_trial_result(self, trial_runner, trial, result):
        """If bracket is finished, all trials will be stopped.

        If a given trial finishes and bracket iteration is not done,
        the trial will be paused and resources will be given up.
        When bracket iteration is done, Trials will be successively halved,
        and during each halving phase, bad trials will be stopped while good
        trials will return to "PENDING". This scheduler will not start trials
        but will stop trials. The current running trial will not be handled,
        as the trialrunner will be given control to handle it.

        Returns one of TrialScheduler.CONTINUE / PAUSE / STOP for `trial`.

        # TODO(rliaw) should be only called if trial has not errored"""
        bracket, _ = self._trial_info[trial]
        bracket.update_trial_stats(trial, result)
        if bracket.continue_trial(trial):
            return TrialScheduler.CONTINUE

        # Default for a trial that used up its iterations while its
        # siblings in the bracket are still behind.
        signal = TrialScheduler.PAUSE

        if bracket.cur_iter_done():
            if bracket.finished():
                self._cleanup_bracket(trial_runner, bracket)
                return TrialScheduler.STOP
            # what if bracket is done and trial not completed?
            good, bad = bracket.successive_halving()
            # kill bad trials
            for t in bad:
                self._num_stopped += 1
                if t.status == Trial.PAUSED:
                    trial_runner._stop_trial(t)
                    bracket.cleanup_trial_early(t)
                elif t is trial:
                    # The reporting trial is stopped by the caller via the
                    # returned signal rather than here.
                    signal = TrialScheduler.STOP
                else:
                    raise Exception("Trial with unexpected status encountered")

            # ready the good trials
            for t in good:
                if t.status == Trial.PAUSED:
                    t.unpause()
                elif t is trial:
                    signal = TrialScheduler.CONTINUE
                else:
                    raise Exception("Trial with unexpected status encountered")

        return signal

    def _cleanup_bracket(self, trial_runner, bracket):
        """Cleans up bracket after bracket is completely finished.

        Every PAUSED trial still tracked by the bracket is stopped through
        the trial runner and dropped from the bracket. Trials in any other
        state are left untouched."""
        for t in bracket.current_trials():
            if t.status == Trial.PAUSED:
                trial_runner._stop_trial(t)
                bracket.cleanup_trial_early(t)

    def on_trial_complete(self, trial_runner, trial, result):
        """Cleans up trial info from bracket if trial completed early.

        Bracket information will only be cleaned up after the trialrunner has
        finished its bookkeeping."""
        bracket, _ = self._trial_info[trial]
        bracket.cleanup_trial_early(trial)

    def on_trial_error(self, trial_runner, trial):
        """Cleans up trial info from bracket if trial errored early.

        Bracket information will only be cleaned up after the trialrunner has
        finished its bookkeeping."""
        bracket, _ = self._trial_info[trial]
        bracket.cleanup_trial_early(trial)

    def choose_trial_to_run(self, trial_runner, *args):
        """Fair scheduling within iteration by completion percentage.
        List of trials not used since all trials are tracked as state
        of scheduler.

        If iteration is occupied (ie, no trials to run), then look into
        next iteration.

        Returns the first PENDING trial the runner has resources for, or
        None when there is none."""
        for hyperband in self._hyperbands:
            # least-progressed brackets get the first chance to run
            for bracket in sorted(hyperband,
                                  key=lambda b: b.completion_percentage()):
                for trial in bracket.current_trials():
                    if (trial.status == Trial.PENDING and
                            trial_runner.has_resources(trial.resources)):
                        return trial
        return None

    def debug_string(self):
        """Return a one-line summary: stopped-trial and bracket counts."""
        return " ".join([
            "Using HyperBand:",
            "num_stopped={}".format(self._num_stopped),
            "brackets={}".format(sum(len(band) for band in self._hyperbands))])


class Bracket(object):
    """Logical object for tracking Hyperband bracket progress. Keeps track
    of proper parameters as designated by HyperBand.

    Also keeps track of progress to ensure good scheduling.
    """
    def __init__(self, max_trials, init_iters, eta, s):
        """
        args:
            max_trials (int): number of trials the bracket starts with (n0)
            init_iters (int): iterations per trial before first halving (r0)
            eta (int): downsampling rate
            s (int): number of halvings this bracket will perform
        """
        self._live_trials = {}  # stores (result, itrs left before halving)
        self._all_trials = []
        self._n = self._n0 = max_trials
        self._r = self._r0 = init_iters
        # iterations a trial must have run in total to be level with the
        # bracket's current stage
        self._cumul_r = self._r0
        self._eta = eta
        self._halves = s

        self._total_work = self._calculate_total_work(self._n0, self._r0, s)
        self._completed_progress = 0

    def add_trial(self, trial):
        """Add trial to bracket assuming bracket is not filled.

        At a later iteration, a newly added trial will be given equal
        opportunity to catch up."""
        assert not self.filled(), "Cannot add trial to filled bracket!"
        self._live_trials[trial] = (None, self._cumul_r)
        self._all_trials.append(trial)

    def cur_iter_done(self):
        """Checks if all live trials have used up their iterations.

        TODO(rliaw): also check that `t.iterations == self._r`"""
        return all(itr == 0 for _, itr in self._live_trials.values())

    def finished(self):
        """True once no halvings remain and the current stage is done."""
        return self._halves == 0 and self.cur_iter_done()

    def current_trials(self):
        """Return a new list of the trials still tracked by the bracket."""
        return list(self._live_trials)

    def continue_trial(self, trial):
        """True if `trial` still has iterations left in the current stage."""
        _, itr = self._live_trials[trial]
        return itr > 0

    def filled(self):
        """We will only let new trials be added at current level,
        minimizing the need to backtrack and bookkeep previous medians"""
        return len(self._live_trials) == self._n

    def successive_halving(self):
        """Advance the bracket one stage and split trials into good and bad.

        The trial budget is divided by eta (rounded up), the per-trial
        iteration budget is multiplied by eta, and live trials are ranked
        by `episode_reward_mean` of their last result.

        Returns:
            (good, bad): the top `n` trials, whose remaining-iteration
            counters are reset to the new budget, and the rest, which are
            left untouched for the caller to stop.
        """
        assert self._halves > 0
        self._halves -= 1
        self._n = int(np.ceil(self._n / float(self._eta)))
        self._r = int(np.ceil(self._r * self._eta))
        self._cumul_r += self._r
        sorted_trials = sorted(
            self._live_trials,
            key=lambda t: self._live_trials[t][0].episode_reward_mean)

        # ascending order, so the best `n` trials are at the tail
        good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]

        # reset good trials to track updated iterations
        for t in good:
            res, _ = self._live_trials[t]
            self._live_trials[t] = (res, self._r)
        return good, bad

    def update_trial_stats(self, trial, result):
        """Update result for trial. Called after trial has finished
        an iteration - will decrement iteration count.

        TODO(rliaw): The other alternative is to keep the trials
        in and make sure they're not set as pending later."""

        assert trial in self._live_trials
        _, itr = self._live_trials[trial]
        assert itr > 0
        self._live_trials[trial] = (result, itr - 1)
        self._completed_progress += 1

    def cleanup_trial_early(self, trial):
        """Clean up statistics tracking for trial that terminated early.

        This may cause bad trials to continue for a long time, in the case
        where all the good trials finish early and there are only bad trials
        left in a bracket with a large max-iteration."""
        assert trial in self._live_trials
        del self._live_trials[trial]

    def completion_percentage(self):
        """Returns a progress metric (completed iterations / planned work).

        This will not always reach 1 since dead trials are dropped."""
        return float(self._completed_progress) / self._total_work

    def _calculate_total_work(self, n, r, s):
        """Sum n * r over the `s` + 1 stages of a bracket starting at (n, r).

        Between stages n is divided by eta (rounded up) and r is multiplied
        by eta."""
        work = 0
        for _ in range(s + 1):
            work += int(n) * int(r)
            n = int(np.ceil(n / float(self._eta)))
            r *= self._eta
        return work

    def __repr__(self):
        status = ", ".join([
            "n={}".format(self._n),
            "r={}".format(self._r),
            "progress={}".format(self.completion_percentage())
        ])
        trials = ", ".join([t.status for t in self._live_trials])
        return "Bracket({})[{}]".format(status, trials)
This is a blocking call.""" assert self.status == Trial.PAUSED, self.status self.start() diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 0e4f14aff..d8cc83442 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -67,8 +67,8 @@ class TrialRunner(object): ("Insufficient cluster resources to launch trial", (trial.resources, self._avail_resources)) elif trial.status == Trial.PAUSED: - assert False, "There are paused trials, but no more \ - pending trials with sufficient resources." + assert False, "There are paused trials, but no more "\ + "pending trials with sufficient resources." assert False, "Called step when all trials finished?" def get_trials(self): @@ -84,7 +84,7 @@ class TrialRunner(object): Trials may be added at any time. """ - + self._scheduler_alg.on_trial_add(self, trial) self._trials.append(trial) def debug_string(self): @@ -168,6 +168,7 @@ class TrialRunner(object): except Exception: print("Error processing event:", traceback.format_exc()) if trial.status == Trial.RUNNING: + self._scheduler_alg.on_trial_error(self, trial) self._stop_trial(trial, error=True) def _get_runnable(self): @@ -186,12 +187,18 @@ class TrialRunner(object): assert self._committed_resources.gpu >= 0 def _stop_trial(self, trial, error=False): + """Only returns resources if resources allocated.""" + prior_status = trial.status trial.stop(error=error) - self._return_resources(trial.resources) + if prior_status == Trial.RUNNING: + self._return_resources(trial.resources) def _pause_trial(self, trial): + """Only returns resources if resources allocated.""" + prior_status = trial.status trial.pause() - self._return_resources(trial.resources) + if prior_status == Trial.RUNNING: + self._return_resources(trial.resources) def _update_avail_resources(self): clients = ray.global_state.client_table() diff --git a/python/ray/tune/trial_scheduler.py b/python/ray/tune/trial_scheduler.py index e8401440c..d8d932ea1 100644 --- 
class _MockTrialRunner(object):
    """Minimal stand-in for the trial runner used by the scheduler tests.

    It only flips trial status; no resources are tracked and
    `has_resources` always answers True.
    """

    def _stop_trial(self, trial):
        trial.stop()

    def has_resources(self, resources):
        return True

    def _pause_trial(self, trial):
        trial.status = Trial.PAUSED

    def _launch_trial(self, trial):
        trial.status = Trial.RUNNING


class HyperbandSuite(unittest.TestCase):
    """Tests for HyperBandScheduler driven through a mock trial runner."""

    def basicSetup(self):
        """Build a scheduler with exactly one completely filled band.

        max_iter=9, eta=3 gives s_max_1 = 3, so a band holds three brackets
        with initial (n, r) of (9, 1), (5, 3) and (3, 9): 17 trials total.
        """

        sched = HyperBandScheduler(9, eta=3)
        for i in range(17):
            t = Trial("t%d" % i, "__fake")
            sched.on_trial_add(None, t)

        self.assertEqual(len(sched._hyperbands), 1)
        self.assertEqual(sched._cur_band_filled(), True)

        filled_band = sched._hyperbands[0]
        for bracket in filled_band:
            self.assertEqual(bracket.filled(), True)
        return sched

    def advancedSetup(self):
        """Extend basicSetup with 3 trials that spill into a second band."""
        sched = self.basicSetup()
        for i in range(3):
            t = Trial("t%d" % (i + 20), "__fake")
            sched.on_trial_add(None, t)

        self.assertEqual(sched._cur_band_filled(), False)

        unfilled_band = sched._hyperbands[1]
        self.assertEqual(len(unfilled_band), 1)
        self.assertEqual(len(sched._hyperbands[1]), 1)
        bracket = unfilled_band[0]
        self.assertEqual(bracket.filled(), False)
        self.assertEqual(len(bracket.current_trials()), 3)

        return sched

    def testBasicHalving(self):
        """All but the last trial pause; the last one triggers the halving
        and, having the best reward, is told to continue."""
        sched = self.advancedSetup()
        mock_runner = _MockTrialRunner()
        filled_band = sched._hyperbands[0]
        big_bracket = filled_band[0]
        bracket_trials = big_bracket.current_trials()

        for t in bracket_trials:
            mock_runner._launch_trial(t)

        for i, t in enumerate(bracket_trials):
            if i == len(bracket_trials) - 1:
                break
            self.assertEqual(
                TrialScheduler.PAUSE,
                sched.on_trial_result(mock_runner, t, result(i, 10)))
            mock_runner._pause_trial(t)
        self.assertEqual(
            TrialScheduler.CONTINUE,
            sched.on_trial_result(
                mock_runner, bracket_trials[-1], result(7, 12)))

    def testSuccessiveHalving(self):
        """Run every trial of the first bracket until the scheduler pauses
        or stops it."""
        sched = HyperBandScheduler(9, eta=3)
        for i in range(9):
            t = Trial("t%d" % i, "__fake")
            sched.on_trial_add(None, t)
        filled_band = sched._hyperbands[0]
        big_bracket = filled_band[0]
        mock_runner = _MockTrialRunner()

        current_length = len(big_bracket.current_trials())
        for i in range(current_length):
            trl = sched.choose_trial_to_run(mock_runner)
            mock_runner._launch_trial(trl)
            while True:
                status = sched.on_trial_result(mock_runner, trl, result(1, 10))
                if status == TrialScheduler.CONTINUE:
                    continue
                elif status == TrialScheduler.PAUSE:
                    mock_runner._pause_trial(trl)
                    break
                else:
                    # STOP: all rewards tie here, so the reporting trial can
                    # end up on the losing side of the halving. Feeding it
                    # another result would trip the bracket's assertions.
                    break

    def testBasicRun(self):
        """Drive the scheduler until it has no more trials to hand out."""
        sched = self.advancedSetup()
        mock_runner = _MockTrialRunner()
        trl = sched.choose_trial_to_run(mock_runner)
        while trl:
            # a trial from a later band is only handed out once the first
            # band has nothing pending
            if sched._trial_info[trl][1] > 0:
                first_band = sched._hyperbands[0]
                trials = [t for b in first_band for t in b._live_trials]
                self.assertEqual(
                    all(t.status == Trial.RUNNING for t in trials),
                    True)
            mock_runner._launch_trial(trl)
            res = sched.on_trial_result(mock_runner, trl, result(1, 10))
            # compare by value; identity of string constants is incidental
            if res == TrialScheduler.PAUSE:
                mock_runner._pause_trial(trl)
            trl = sched.choose_trial_to_run(mock_runner)

        self.assertEqual(
            all(t.status == Trial.RUNNING for t in trials), True)

    def testTrialErrored(self):
        """An errored trial is dropped; the survivor keeps running."""
        sched = HyperBandScheduler(9, eta=3)
        t1 = Trial("t1", "__fake")
        t2 = Trial("t2", "__fake")
        sched.on_trial_add(None, t1)
        sched.on_trial_add(None, t2)
        mock_runner = _MockTrialRunner()
        filled_band = sched._hyperbands[0]
        big_bracket = filled_band[0]
        bracket_trials = big_bracket.current_trials()

        for t in bracket_trials:
            mock_runner._launch_trial(t)

        sched.on_trial_error(mock_runner, t2)
        self.assertEqual(
            TrialScheduler.CONTINUE,
            sched.on_trial_result(mock_runner, t1, result(3, 10)))

    def testTrialEndedEarly(self):
        """A trial completing early is dropped; the survivor keeps running."""
        sched = HyperBandScheduler(9, eta=3)
        t1 = Trial("t1", "__fake")
        t2 = Trial("t2", "__fake")
        sched.on_trial_add(None, t1)
        sched.on_trial_add(None, t2)
        mock_runner = _MockTrialRunner()
        filled_band = sched._hyperbands[0]
        big_bracket = filled_band[0]
        bracket_trials = big_bracket.current_trials()

        for t in bracket_trials:
            mock_runner._launch_trial(t)

        sched.on_trial_complete(mock_runner, t2, result(5, 10))
        self.assertEqual(
            TrialScheduler.CONTINUE,
            sched.on_trial_result(mock_runner, t1, result(3, 12)))

    def testAddAfterHalf(self):
        """A trial added after one halving starts with the cumulative
        iteration budget (1 + 3 = 4)."""
        sched = HyperBandScheduler(9, eta=3)
        for i in range(2):
            t = Trial("t%d" % i, "__fake")
            sched.on_trial_add(None, t)
        mock_runner = _MockTrialRunner()
        filled_band = sched._hyperbands[0]
        big_bracket = filled_band[0]
        bracket_trials = big_bracket.current_trials()

        for t in bracket_trials:
            mock_runner._launch_trial(t)

        for i, t in enumerate(bracket_trials):
            if i == len(bracket_trials) - 1:
                break
            self.assertEqual(
                TrialScheduler.PAUSE,
                sched.on_trial_result(mock_runner, t, result(i, 10)))
            mock_runner._pause_trial(t)
        self.assertEqual(
            TrialScheduler.CONTINUE,
            sched.on_trial_result(
                mock_runner, bracket_trials[-1], result(7, 12)))
        t = Trial("t%d" % 5, "__fake")
        sched.on_trial_add(None, t)
        self.assertEqual(4, big_bracket._live_trials[t][1])

    def testDone(self):
        """When the last bracket stage finishes, the reporting trial gets
        STOP and the paused sibling is cleaned out of the bracket."""
        sched = HyperBandScheduler(3, eta=3)
        mock_runner = _MockTrialRunner()
        trials = [Trial("t%d" % i, "__fake") for i in range(5)]
        for t in trials:
            sched.on_trial_add(None, t)

        filled_band = sched._hyperbands[0]
        brack = filled_band[1]
        bracket_trials = brack.current_trials()
        for t in bracket_trials:
            mock_runner._launch_trial(t)
        for i in range(3):
            res = sched.on_trial_result(
                mock_runner, bracket_trials[-1], result(i, 10))
        self.assertEqual(res, TrialScheduler.PAUSE)
        mock_runner._pause_trial(bracket_trials[-1])
        for i in range(3):
            res = sched.on_trial_result(
                mock_runner, bracket_trials[-2], result(i, 10))
        self.assertEqual(res, TrialScheduler.STOP)
        self.assertEqual(len(brack.current_trials()), 1)


if __name__ == "__main__":
    unittest.main(verbosity=2)