From eadb9986435d588dcdfb95624e0761b673961bec Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 16 Nov 2017 10:31:42 -0800 Subject: [PATCH] [tune] Make HyperBand Usable (#1215) --- .../tuned_examples/hyperband-cartpole.yaml | 2 +- python/ray/tune/hyperband.py | 163 ++++--- test/trial_scheduler_test.py | 407 ++++++++++++------ 3 files changed, 379 insertions(+), 193 deletions(-) diff --git a/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml b/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml index 2f7bda5e8..fa2d168c4 100644 --- a/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml +++ b/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml @@ -1,7 +1,7 @@ cartpole-ppo: env: CartPole-v0 alg: PPO - num_trials: 20 + repeat: 3 stop: episode_reward_mean: 200 time_total_s: 180 diff --git a/python/ray/tune/hyperband.py b/python/ray/tune/hyperband.py index b127bd4b2..9dda79877 100644 --- a/python/ray/tune/hyperband.py +++ b/python/ray/tune/hyperband.py @@ -8,13 +8,11 @@ from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler from ray.tune.trial import Trial -def calculate_bracket_count(max_iter, eta): - return int(np.log(max_iter)/np.log(eta)) + 1 - - class HyperBandScheduler(FIFOScheduler): """Implements HyperBand. + Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html + This implementation contains 3 logical levels. Each HyperBand iteration is a "band". There can be multiple bands running at once, and there can be 1 band that is incomplete. @@ -30,26 +28,39 @@ class HyperBandScheduler(FIFOScheduler): Trials added will be inserted into the most recent bracket and band and will spill over to new brackets/bands accordingly. + + This maintains the bracket size and max trial count per band + to 5 and 117 respectively, which correspond to that of + `max_attr=81, eta=3` from the blog post. Trials will fill up + from smallest bracket to largest, with largest + having the most rounds of successive halving. + + Args: + time_attr (str): The TrainingResult attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + reward_attr (str): The TrainingResult objective value attribute. As + with `time_attr`, this may refer to any objective value. Stopping + procedures will use this attribute. + max_t (int): max time units per trial. Trials will be stopped after + max_t time units (determined by time_attr) have passed. + The HyperBand scheduler automatically tries to determine a + reasonable number of brackets based on this and eta. """ - def __init__(self, max_iter=200, eta=3): - """ - args: - max_iter (int): maximum iterations per configuration - eta (int): # defines downsampling rate (default=3) - """ - assert max_iter > 0, "Max Iterations not valid!" - assert eta > 1, "Downsampling rate (eta) not valid!" - + def __init__( + self, time_attr='training_iteration', + reward_attr='episode_reward_mean', max_t=81): + assert max_t > 0, "Max (time_attr) not valid!" FIFOScheduler.__init__(self) - self._eta = eta - self._s_max_1 = s_max_1 = calculate_bracket_count(max_iter, eta) - # total number of iterations per execution of Succesive Halving (n,r) - B = s_max_1 * max_iter - # bracket trial count total - self._get_n0 = lambda s: int(np.ceil(B/max_iter/(s+1)*eta**s)) + self._eta = 3 + self._s_max_1 = 5 + # bracket max trials + self._get_n0 = lambda s: int( + np.ceil(self._s_max_1/(s+1) * self._eta**s)) # bracket initial iterations - self._get_r0 = lambda s: int(max_iter*eta**(-s)) + self._get_r0 = lambda s: int((max_t*self._eta**(-s))) self._hyperbands = [[]] # list of hyperband iterations self._trial_info = {} # Stores Trial -> Bracket, Band Iteration @@ -57,6 +68,8 @@ class HyperBandScheduler(FIFOScheduler): self._state = {"bracket": None, "band_idx": 0} self._num_stopped = 0 + self._reward_attr = reward_attr + self._time_attr = time_attr def on_trial_add(self, trial_runner, trial): """On a new trial add, if current bracket is not filled, @@ -67,22 +80,27 @@ class HyperBandScheduler(FIFOScheduler): cur_bracket = self._state["bracket"] cur_band = self._hyperbands[self._state["band_idx"]] if cur_bracket is None or cur_bracket.filled(): + retry = True + while retry: + # if current iteration is filled, create new iteration + if self._cur_band_filled(): + cur_band = [] + self._hyperbands.append(cur_band) + self._state["band_idx"] += 1 - # if current iteration is filled, create new iteration - if self._cur_band_filled(): - cur_band = [] - self._hyperbands.append(cur_band) - self._state["band_idx"] += 1 - - # cur_band will always be less than s_max_1 or else filled - s = len(cur_band) - assert s < self._s_max_1, "Current band is filled!" - - # create new bracket - cur_bracket = Bracket(self._get_n0(s), - self._get_r0(s), self._eta, s) - cur_band.append(cur_bracket) - self._state["bracket"] = cur_bracket + # cur_band will always be less than s_max_1 or else filled + s = len(cur_band) + assert s < self._s_max_1, "Current band is filled!" + if self._get_r0(s) == 0: + print("Bracket too small - Retrying...") + cur_bracket = None + else: + retry = False + cur_bracket = Bracket( + self._time_attr, self._get_n0(s), self._get_r0(s), + self._eta, s) + cur_band.append(cur_bracket) + self._state["bracket"] = cur_bracket self._state["bracket"].add_trial(trial) self._trial_info[trial] = cur_bracket, self._state["band_idx"] @@ -128,9 +146,9 @@ class HyperBandScheduler(FIFOScheduler): if bracket.cur_iter_done(): if bracket.finished(): self._cleanup_bracket(trial_runner, bracket) - return TrialScheduler.STOP + return TrialScheduler.CONTINUE - good, bad = bracket.successive_halving() + good, bad = bracket.successive_halving(self._reward_attr) # kill bad trials for t in bad: if t.status == Trial.PAUSED: @@ -141,14 +159,15 @@ class HyperBandScheduler(FIFOScheduler): else: raise Exception("Trial with unexpected status encountered") - # ready the good trials + # ready the good trials - if trial is too far ahead, don't continue for t in good: - if t.status == Trial.PAUSED: - t.unpause() - elif t.status == Trial.RUNNING: - action = TrialScheduler.CONTINUE - else: + if t.status not in [Trial.PAUSED, Trial.RUNNING]: raise Exception("Trial with unexpected status encountered") + if bracket.continue_trial(t): + if t.status == Trial.PAUSED: + t.unpause() + elif t.status == Trial.RUNNING: + action = TrialScheduler.CONTINUE return action def _cleanup_trial(self, trial_runner, t, bracket, hard=False): @@ -162,11 +181,14 @@ class HyperBandScheduler(FIFOScheduler): bracket.cleanup_trial(t) def _cleanup_bracket(self, trial_runner, bracket): - """Cleans up bracket after bracket is completely finished.""" + """Cleans up bracket after bracket is completely finished. + Lets the last trial continue to run until termination condition + kicks in.""" for trial in bracket.current_trials(): - self._cleanup_trial( - trial_runner, trial, bracket, - hard=(trial.status == Trial.PAUSED)) + if (trial.status == Trial.PAUSED): + self._cleanup_trial( + trial_runner, trial, bracket, + hard=True) def on_trial_complete(self, trial_runner, trial, result): """Cleans up trial info from bracket if trial completed early.""" @@ -219,12 +241,15 @@ class Bracket(): Also keeps track of progress to ensure good scheduling. """ - def __init__(self, max_trials, init_iters, eta, s): - self._live_trials = {} # stores (result, itrs left before halving) + def __init__(self, time_attr, max_trials, init_t_attr, eta, s): + self._live_trials = {} # maps trial -> current result self._all_trials = [] + self._time_attr = time_attr # attribute to + self._n = self._n0 = max_trials - self._r = self._r0 = init_iters + self._r = self._r0 = init_t_attr self._cumul_r = self._r0 + self._eta = eta self._halves = s @@ -237,15 +262,15 @@ class Bracket(): At a later iteration, a newly added trial will be given equal opportunity to catch up.""" assert not self.filled(), "Cannot add trial to filled bracket!" - self._live_trials[trial] = (None, self._cumul_r) + self._live_trials[trial] = None self._all_trials.append(trial) def cur_iter_done(self): """Checks if all iterations have completed. TODO(rliaw): also check that `t.iterations == self._r`""" - all_done = all(itr == 0 for _, itr in self._live_trials.values()) - return all_done + return all(self._get_result_time(result) >= self._cumul_r + for result in self._live_trials.values()) def finished(self): return self._halves == 0 and self.cur_iter_done() @@ -254,8 +279,8 @@ class Bracket(): return list(self._live_trials) def continue_trial(self, trial): - _, itr = self._live_trials[trial] - if itr > 0: + result = self._live_trials[trial] + if self._get_result_time(result) < self._cumul_r: return True else: return False @@ -265,24 +290,19 @@ class Bracket(): minimizing the need to backtrack and bookkeep previous medians""" return len(self._live_trials) == self._n - def successive_halving(self): + def successive_halving(self, reward_attr): assert self._halves > 0 self._halves -= 1 self._n /= self._eta self._n = int(np.ceil(self._n)) self._r *= self._eta - self._r = int(np.ceil(self._r)) + self._r = int((self._r)) self._cumul_r += self._r sorted_trials = sorted( self._live_trials, - key=lambda t: self._live_trials[t][0].episode_reward_mean) + key=lambda t: getattr(self._live_trials[t], reward_attr)) good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n] - - # reset good trials to track updated iterations - for t in good: - res, old_itr = self._live_trials[t] - self._live_trials[t] = (res, self._r) return good, bad def update_trial_stats(self, trial, result): @@ -293,10 +313,13 @@ class Bracket(): in and make sure they're not set as pending later.""" assert trial in self._live_trials - _, itr = self._live_trials[trial] - assert itr > 0 - self._live_trials[trial] = (result, itr - 1) - self._completed_progress += 1 + assert self._get_result_time(result) >= 0 + + delta = self._get_result_time(result) - \ + self._get_result_time(self._live_trials[trial]) + assert delta >= 0 + self._completed_progress += delta + self._live_trials[trial] = result def cleanup_trial(self, trial): """Clean up statistics tracking for terminated trials (either by force @@ -315,6 +338,11 @@ class Bracket(): are dropped.""" return self._completed_progress / self._total_work + def _get_result_time(self, result): + if result is None: + return 0 + return getattr(result, self._time_attr) + def _calculate_total_work(self, n, r, s): work = 0 for i in range(s+1): @@ -322,6 +350,7 @@ class Bracket(): n /= self._eta n = int(np.ceil(n)) r *= self._eta + r = int(r) return work def __repr__(self): diff --git a/test/trial_scheduler_test.py b/test/trial_scheduler_test.py index af2f43d49..568d93f20 100644 --- a/test/trial_scheduler_test.py +++ b/test/trial_scheduler_test.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import unittest +import numpy as np from ray.tune.hyperband import HyperBandScheduler from ray.tune.median_stopping_rule import MedianStoppingRule @@ -12,7 +13,9 @@ from ray.tune.trial_scheduler import TrialScheduler def result(t, rew): - return TrainingResult(time_total_s=t, episode_reward_mean=rew) + return TrainingResult(time_total_s=t, + episode_reward_mean=rew, + training_iteration=int(t)) class EarlyStoppingSuite(unittest.TestCase): @@ -156,21 +159,46 @@ class HyperbandSuite(unittest.TestCase): """Setup a scheduler and Runner with max Iter = 9 Bracketing is placed as follows: - (3, 9); - (5, 3) -> (2, 9); - (9, 1) -> (3, 3) -> (1, 9); """ - sched = HyperBandScheduler(9, eta=3) + (5, 81); + (8, 27) -> (3, 81); + (15, 9) -> (5, 27) -> (2, 81); + (34, 3) -> (12, 9) -> (4, 27) -> (2, 81); + (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);""" + sched = HyperBandScheduler() for i in range(num_trials): t = Trial("t%d" % i, "__fake") sched.on_trial_add(None, t) runner = _MockTrialRunner() return sched, runner + def default_statistics(self): + """Default statistics for HyperBand""" + sched = HyperBandScheduler() + res = { + str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)} + for s in range(sched._s_max_1) + } + res["max_trials"] = sum(v["n"] for v in res.values()) + res["brack_count"] = sched._s_max_1 + res["s_max"] = sched._s_max_1 - 1 + return res + + def downscale(self, n, sched): + return int(np.ceil(n / sched._eta)) + + def process(self, trl, mock_runner, action): + if action == TrialScheduler.CONTINUE: + pass + elif action == TrialScheduler.PAUSE: + mock_runner._pause_trial(trl) + elif action == TrialScheduler.STOP: + self.stopTrial(trl, mock_runner) + def basicSetup(self): """Setup and verify full band. """ - - sched, _ = self.schedulerSetup(17) + stats = self.default_statistics() + sched, _ = self.schedulerSetup(stats["max_trials"]) self.assertEqual(len(sched._hyperbands), 1) self.assertEqual(sched._cur_band_filled(), True) @@ -192,7 +220,7 @@ class HyperbandSuite(unittest.TestCase): self.assertEqual(len(unfilled_band), 2) bracket = unfilled_band[-1] self.assertEqual(bracket.filled(), False) - self.assertEqual(len(bracket.current_trials()), 1) + self.assertEqual(len(bracket.current_trials()), 7) return sched @@ -200,19 +228,254 @@ class HyperbandSuite(unittest.TestCase): self.assertNotEqual(trial.status, Trial.TERMINATED) mock_runner._stop_trial(trial) - def testSuccessiveHalving(self): - """Setup full band, then iterate through last bracket (n=9) - to make sure successive halving is correct.""" + def testConfigSameEta(self): + sched = HyperBandScheduler() + i = 0 + while not sched._cur_band_filled(): + t = Trial("t%d" % (i), "__fake") + sched.on_trial_add(None, t) + i += 1 + self.assertEqual(len(sched._hyperbands[0]), 5) + self.assertEqual(sched._hyperbands[0][0]._n, 5) + self.assertEqual(sched._hyperbands[0][0]._r, 81) + self.assertEqual(sched._hyperbands[0][-1]._n, 81) + self.assertEqual(sched._hyperbands[0][-1]._r, 1) - sched, mock_runner = self.schedulerSetup(17) - filled_band = sched._hyperbands[0][-1] - big_bracket = filled_band + sched = HyperBandScheduler(max_t=810) + i = 0 + while not sched._cur_band_filled(): + t = Trial("t%d" % (i), "__fake") + sched.on_trial_add(None, t) + i += 1 + self.assertEqual(len(sched._hyperbands[0]), 5) + self.assertEqual(sched._hyperbands[0][0]._n, 5) + self.assertEqual(sched._hyperbands[0][0]._r, 810) + self.assertEqual(sched._hyperbands[0][-1]._n, 81) + self.assertEqual(sched._hyperbands[0][-1]._r, 10) + + def testConfigSameEtaSmall(self): + sched = HyperBandScheduler(max_t=1) + i = 0 + while len(sched._hyperbands) < 2: + t = Trial("t%d" % (i), "__fake") + sched.on_trial_add(None, t) + i += 1 + self.assertEqual(len(sched._hyperbands[0]), 5) + self.assertTrue(all(v is None for v in sched._hyperbands[0][1:])) + + def testSuccessiveHalving(self): + """Setup full band, then iterate through last bracket (n=81) + to make sure successive halving is correct.""" + stats = self.default_statistics() + sched, mock_runner = self.schedulerSetup(stats["max_trials"]) + big_bracket = sched._state["bracket"] + cur_units = stats[str(stats["s_max"])]["r"] + # The last bracket will downscale 4 times + for x in range(stats["brack_count"] - 1): + trials = big_bracket.current_trials() + current_length = len(trials) + for trl in trials: + mock_runner._launch_trial(trl) + + # Provides results from 0 to 8 in order, keeping last one running + for i, trl in enumerate(trials): + action = sched.on_trial_result( + mock_runner, trl, result(cur_units, i)) + if i < current_length - 1: + self.assertEqual(action, TrialScheduler.PAUSE) + self.process(trl, mock_runner, action) + + self.assertEqual(action, TrialScheduler.CONTINUE) + new_length = len(big_bracket.current_trials()) + self.assertEqual(new_length, self.downscale(current_length, sched)) + cur_units += int(cur_units * sched._eta) + self.assertEqual(len(big_bracket.current_trials()), 1) + + def testHalvingStop(self): + stats = self.default_statistics() + num_trials = stats[str(0)]["n"] + stats[str(1)]["n"] + sched, mock_runner = self.schedulerSetup(num_trials) + big_bracket = sched._state["bracket"] + for trl in big_bracket.current_trials(): + mock_runner._launch_trial(trl) + + # # Provides result in reverse order, killing the last one + cur_units = stats[str(1)]["r"] + for i, trl in reversed(list(enumerate(big_bracket.current_trials()))): + action = sched.on_trial_result( + mock_runner, trl, result(cur_units, i)) + self.process(trl, mock_runner, action) + + self.assertEqual(action, TrialScheduler.STOP) + + def testContinueLastOne(self): + stats = self.default_statistics() + num_trials = stats[str(0)]["n"] + sched, mock_runner = self.schedulerSetup(num_trials) + big_bracket = sched._state["bracket"] + for trl in big_bracket.current_trials(): + mock_runner._launch_trial(trl) + + # # Provides result in reverse order, killing the last one + cur_units = stats[str(0)]["r"] + for i, trl in enumerate(big_bracket.current_trials()): + action = sched.on_trial_result( + mock_runner, trl, result(cur_units, i)) + self.process(trl, mock_runner, action) + + self.assertEqual(action, TrialScheduler.CONTINUE) + + for x in range(100): + action = sched.on_trial_result( + mock_runner, trl, result(cur_units + x, 10)) + self.assertEqual(action, TrialScheduler.CONTINUE) + + def testTrialErrored(self): + """If a trial errored, make sure successive halving still happens""" + stats = self.default_statistics() + trial_count = stats[str(0)]["n"] + 3 + sched, mock_runner = self.schedulerSetup(trial_count) + t1, t2, t3 = sched._state["bracket"].current_trials() + for t in [t1, t2, t3]: + mock_runner._launch_trial(t) + + sched.on_trial_error(mock_runner, t3) + self.assertEqual( + TrialScheduler.PAUSE, + sched.on_trial_result( + mock_runner, t1, result(stats[str(1)]["r"], 10))) + self.assertEqual( + TrialScheduler.CONTINUE, + sched.on_trial_result( + mock_runner, t2, result(stats[str(1)]["r"], 10))) + + def testTrialErrored2(self): + """Check successive halving happened even when last trial failed""" + stats = self.default_statistics() + trial_count = stats[str(0)]["n"] + stats[str(1)]["n"] + sched, mock_runner = self.schedulerSetup(trial_count) + trials = sched._state["bracket"].current_trials() + for t in trials[:-1]: + mock_runner._launch_trial(t) + sched.on_trial_result( + mock_runner, t, result(stats[str(1)]["r"], 10)) + + mock_runner._launch_trial(trials[-1]) + sched.on_trial_error(mock_runner, trials[-1]) + self.assertEqual(len(sched._state["bracket"].current_trials()), + self.downscale(stats[str(1)]["n"], sched)) + + def testTrialEndedEarly(self): + """Check successive halving happened even when one trial failed""" + stats = self.default_statistics() + trial_count = stats[str(0)]["n"] + 3 + sched, mock_runner = self.schedulerSetup(trial_count) + + t1, t2, t3 = sched._state["bracket"].current_trials() + for t in [t1, t2, t3]: + mock_runner._launch_trial(t) + + sched.on_trial_complete(mock_runner, t3, result(1, 12)) + self.assertEqual( + TrialScheduler.PAUSE, + sched.on_trial_result( + mock_runner, t1, result(stats[str(1)]["r"], 10))) + self.assertEqual( + TrialScheduler.CONTINUE, + sched.on_trial_result( + mock_runner, t2, result(stats[str(1)]["r"], 10))) + + def testTrialEndedEarly2(self): + """Check successive halving happened even when last trial failed""" + stats = self.default_statistics() + trial_count = stats[str(0)]["n"] + stats[str(1)]["n"] + sched, mock_runner = self.schedulerSetup(trial_count) + trials = sched._state["bracket"].current_trials() + for t in trials[:-1]: + mock_runner._launch_trial(t) + sched.on_trial_result( + mock_runner, t, result(stats[str(1)]["r"], 10)) + + mock_runner._launch_trial(trials[-1]) + sched.on_trial_complete(mock_runner, trials[-1], result(100, 12)) + self.assertEqual(len(sched._state["bracket"].current_trials()), + self.downscale(stats[str(1)]["n"], sched)) + + def testAddAfterHalving(self): + stats = self.default_statistics() + trial_count = stats[str(0)]["n"] + 1 + sched, mock_runner = self.schedulerSetup(trial_count) + bracket_trials = sched._state["bracket"].current_trials() + init_units = stats[str(1)]["r"] + + for t in bracket_trials: + mock_runner._launch_trial(t) + + for i, t in enumerate(bracket_trials): + status = sched.on_trial_result( + mock_runner, t, result(init_units, i)) + self.assertEqual(status, TrialScheduler.CONTINUE) + t = Trial("t%d" % 100, "__fake") + sched.on_trial_add(None, t) + mock_runner._launch_trial(t) + self.assertEqual(len(sched._state["bracket"].current_trials()), 2) + + # Make sure that newly added trial gets fair computation (not just 1) + self.assertEqual( + TrialScheduler.CONTINUE, + sched.on_trial_result(mock_runner, t, result(init_units, 12))) + new_units = init_units + int(init_units * sched._eta) + self.assertEqual( + TrialScheduler.PAUSE, + sched.on_trial_result(mock_runner, t, result(new_units, 12))) + + def testAlternateMetrics(self): + """Checking that alternate metrics will pass.""" + + def result2(t, rew): + return TrainingResult(time_total_s=t, neg_mean_loss=rew) + + sched = HyperBandScheduler( + time_attr='time_total_s', reward_attr='neg_mean_loss') + stats = self.default_statistics() + + for i in range(stats["max_trials"]): + t = Trial("t%d" % i, "__fake") + sched.on_trial_add(None, t) + runner = _MockTrialRunner() + + big_bracket = sched._hyperbands[0][-1] + + for trl in big_bracket.current_trials(): + runner._launch_trial(trl) + current_length = len(big_bracket.current_trials()) + + # Provides results from 0 to 8 in order, keeping the last one running + for i, trl in enumerate(big_bracket.current_trials()): + status = sched.on_trial_result(runner, trl, result2(1, i)) + if status == TrialScheduler.CONTINUE: + continue + elif status == TrialScheduler.PAUSE: + runner._pause_trial(trl) + elif status == TrialScheduler.STOP: + self.assertNotEqual(trl.status, Trial.TERMINATED) + self.stopTrial(trl, runner) + + new_length = len(big_bracket.current_trials()) + self.assertEqual(status, TrialScheduler.CONTINUE) + self.assertEqual(new_length, self.downscale(current_length, sched)) + + def testJumpingTime(self): + sched, mock_runner = self.schedulerSetup(81) + big_bracket = sched._hyperbands[0][-1] for trl in big_bracket.current_trials(): mock_runner._launch_trial(trl) # Provides results from 0 to 8 in order, keeping the last one running - for i, trl in enumerate(big_bracket.current_trials()): + main_trials = big_bracket.current_trials()[:-1] + jump = big_bracket.current_trials()[-1] + for i, trl in enumerate(main_trials): status = sched.on_trial_result(mock_runner, trl, result(1, i)) if status == TrialScheduler.CONTINUE: continue @@ -222,117 +485,11 @@ class HyperbandSuite(unittest.TestCase): self.assertNotEqual(trl.status, Trial.TERMINATED) self.stopTrial(trl, mock_runner) + status = sched.on_trial_result(mock_runner, jump, result(4, i)) + self.assertEqual(status, TrialScheduler.PAUSE) + current_length = len(big_bracket.current_trials()) - self.assertEqual(status, TrialScheduler.CONTINUE) - self.assertEqual(current_length, 3) - - # Techincally only need to launch 2/3, as one is already running - for trl in big_bracket.current_trials(): - mock_runner._launch_trial(trl) - - # Provides results from 2 to 0 in order, killing the last one - for i, trl in reversed(list(enumerate(big_bracket.current_trials()))): - for j in range(3): - status = sched.on_trial_result(mock_runner, trl, result(1, i)) - if status == TrialScheduler.CONTINUE: - continue - elif status == TrialScheduler.PAUSE: - mock_runner._pause_trial(trl) - elif status == TrialScheduler.STOP: - self.stopTrial(trl, mock_runner) - - self.assertEqual(status, TrialScheduler.STOP) - trl = big_bracket.current_trials()[0] - for i in range(9): - status = sched.on_trial_result(mock_runner, trl, result(1, i)) - self.assertEqual(status, TrialScheduler.STOP) - self.assertEqual(len(big_bracket.current_trials()), 0) - self.assertEqual(sched._num_stopped, 9) - - def testScheduling(self): - """Setup two bands, then make sure all trials are running""" - sched = self.advancedSetup() - mock_runner = _MockTrialRunner() - trl = sched.choose_trial_to_run(mock_runner) - while trl: - # If band iteration > 0, make sure first band is all running - if sched._trial_info[trl][1] > 0: - first_band = sched._hyperbands[0] - trials = [t for b in first_band for t in b._live_trials] - self.assertEqual( - all(t.status == Trial.RUNNING for t in trials), - True) - mock_runner._launch_trial(trl) - res = sched.on_trial_result(mock_runner, trl, result(1, 10)) - if res is TrialScheduler.PAUSE: - mock_runner._pause_trial(trl) - trl = sched.choose_trial_to_run(mock_runner) - - self.assertEqual( - all(t.status == Trial.RUNNING for t in trials), True) - - def testTrialErrored(self): - sched, mock_runner = self.schedulerSetup(10) - t1, t2 = sched._state["bracket"].current_trials() - mock_runner._launch_trial(t1) - mock_runner._launch_trial(t2) - - sched.on_trial_error(mock_runner, t2) - self.assertEqual( - TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, t1, result(1, 10))) - - def testTrialErrored2(self): - """Check successive halving happened even when last trial failed""" - sched, mock_runner = self.schedulerSetup(17) - trials = sched._state["bracket"].current_trials() - self.assertEqual(len(trials), 9) - for t in trials[:-1]: - mock_runner._launch_trial(t) - sched.on_trial_result(mock_runner, t, result(1, 10)) - - mock_runner._launch_trial(trials[-1]) - sched.on_trial_error(mock_runner, trials[-1]) - self.assertEqual(len(sched._state["bracket"].current_trials()), 3) - - def testTrialEndedEarly(self): - sched, mock_runner = self.schedulerSetup(10) - trials = sched._state["bracket"].current_trials() - for t in trials: - mock_runner._launch_trial(t) - - sched.on_trial_complete(mock_runner, trials[-1], result(1, 12)) - self.assertEqual( - TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, trials[0], result(1, 12))) - - def testTrialEndedEarly2(self): - """Check successive halving happened even when last trial finished""" - sched, mock_runner = self.schedulerSetup(17) - trials = sched._state["bracket"].current_trials() - self.assertEqual(len(trials), 9) - for t in trials[:-1]: - mock_runner._launch_trial(t) - sched.on_trial_result(mock_runner, t, result(1, 10)) - - mock_runner._launch_trial(trials[-1]) - sched.on_trial_complete(mock_runner, trials[-1], result(1, 12)) - self.assertEqual(len(sched._state["bracket"].current_trials()), 3) - - def testAddAfterHalving(self): - sched, mock_runner = self.schedulerSetup(10) - bracket_trials = sched._state["bracket"].current_trials() - - for t in bracket_trials: - mock_runner._launch_trial(t) - - for i, t in enumerate(bracket_trials): - res = sched.on_trial_result( - mock_runner, t, result(1, i)) - self.assertEqual(res, TrialScheduler.CONTINUE) - t = Trial("t%d" % 5, "__fake") - sched.on_trial_add(None, t) - self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1]) + self.assertLess(current_length, 27) if __name__ == "__main__":