[tune] Implement BOHB (#5382)

This commit is contained in:
Lisa Dunlap
2019-08-13 14:32:07 -05:00
committed by Richard Liaw
parent 79949fb8a0
commit b7d0733362
12 changed files with 519 additions and 30 deletions
+82
View File
@@ -0,0 +1,82 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import numpy as np
import ray
from ray.tune import Trainable, run
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
# Command-line options for the example. Unknown arguments are tolerated
# (parse_known_args) and discarded.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test",
    help="Finish quickly for testing",
    action="store_true")
parser.add_argument(
    "--ray-redis-address",
    help="Address of Ray cluster for seamless distributed execution.")
args = parser.parse_known_args()[0]
class MyTrainableClass(Trainable):
    """Toy trainable whose reward follows a scaled ``tanh`` curve.

    The dummy hyperparameters "width" and "height" (both default to 1 when
    absent from the config) control how quickly the curve rises and the
    value it saturates at, respectively.
    """

    def _setup(self, config):
        # Number of training steps taken so far.
        self.timestep = 0

    def _train(self):
        """Advance one step and report the current reward."""
        self.timestep += 1
        width = self.config.get("width", 1)
        height = self.config.get("height", 1)
        reward = np.tanh(float(self.timestep) / width)
        reward *= height

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy.
        return {"episode_reward_mean": reward}

    def _save(self, checkpoint_dir):
        """Write the step counter as JSON and return the checkpoint path."""
        checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        with open(checkpoint_file, "w") as out:
            json.dump({"timestep": self.timestep}, out)
        return checkpoint_file

    def _restore(self, checkpoint_path):
        """Reload the step counter from a checkpoint written by `_save`."""
        with open(checkpoint_path) as src:
            self.timestep = json.load(src)["timestep"]
if __name__ == "__main__":
    # Imported here rather than at module top, as in the original script.
    import ConfigSpace as CS

    ray.init(redis_address=args.ray_redis_address)

    # BOHB uses ConfigSpace for their hyperparameter search space
    config_space = CS.ConfigurationSpace()
    config_space.add_hyperparameter(
        CS.UniformFloatHyperparameter("height", lower=10, upper=100))
    config_space.add_hyperparameter(
        CS.UniformFloatHyperparameter("width", lower=0, upper=100))

    # The trainable reports a reward, so the objective has to be maximized.
    # The same metric/mode pair is handed to both the scheduler and the
    # search algorithm so that they rank trials consistently.
    experiment_metrics = dict(metric="episode_reward_mean", mode="max")

    bohb_hyperband = HyperBandForBOHB(
        time_attr="training_iteration",
        max_t=100,
        reduction_factor=4,
        **experiment_metrics)
    bohb_search = TuneBOHB(
        config_space, max_concurrent=4, **experiment_metrics)

    # --smoke-test shortens every trial to 10 iterations instead of 100.
    run(MyTrainableClass,
        name="bohb_test",
        scheduler=bohb_hyperband,
        search_alg=bohb_search,
        num_samples=10,
        stop={"training_iteration": 10 if args.smoke_test else 100})
+2 -1
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler
from ray.tune.schedulers.hyperband import HyperBandScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler,
ASHAScheduler)
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
@@ -12,5 +13,5 @@ from ray.tune.schedulers.pbt import PopulationBasedTraining
__all__ = [
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
"PopulationBasedTraining"
"PopulationBasedTraining", "HyperBandForBOHB"
]
+128
View File
@@ -0,0 +1,128 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
from ray.tune.schedulers.trial_scheduler import TrialScheduler
from ray.tune.schedulers.hyperband import HyperBandScheduler, Bracket
from ray.tune.trial import Trial
logger = logging.getLogger(__name__)
class HyperBandForBOHB(HyperBandScheduler):
    """HyperBand early-stopping scheduler adapted for use with BOHB.

    Compared with ``HyperBandScheduler`` this variant drops the pipelining
    of brackets and differs in two key ways:

    1. New trials are placed so that the largest bracket of a band is
       filled first.
    2. A trial is paused even when its bracket has not been filled yet,
       which lets BOHB insert new trials into the training.

    See ray.tune.schedulers.HyperBandScheduler for parameter docstring.
    """

    def on_trial_add(self, trial_runner, trial):
        """Registers a new trial with a bracket.

        The trial goes into the current bracket when that bracket still has
        room. Otherwise a new bracket is opened in the current band, or, if
        the band itself is filled, in a freshly created band.
        """
        bracket = self._state["bracket"]
        band = self._hyperbands[self._state["band_idx"]]

        needs_new_bracket = bracket is None or bracket.filled()
        while needs_new_bracket:
            # A filled band means a new iteration (band) has to be started.
            if self._cur_band_filled():
                band = []
                self._hyperbands.append(band)
                self._state["band_idx"] += 1

            # MAIN CHANGE HERE - largest bracket first!
            # len(band) is always below s_max_1 here, otherwise the band
            # would have counted as filled above.
            s = self._s_max_1 - len(band) - 1
            assert s >= 0, "Current band is filled!"

            if self._get_r0(s) == 0:
                # Record a None placeholder for this slot and try the next.
                logger.debug("BOHB: Bracket too small - Retrying...")
                bracket = None
            else:
                needs_new_bracket = False
                bracket = Bracket(self._time_attr, self._get_n0(s),
                                  self._get_r0(s), self._max_t_attr,
                                  self._eta, s)
            band.append(bracket)
            self._state["bracket"] = bracket

        self._state["bracket"].add_trial(trial)
        self._trial_info[trial] = bracket, self._state["band_idx"]

    def on_trial_result(self, trial_runner, trial, result):
        """Decides what happens to `trial` after it reported `result`.

        While the bracket lets the trial continue, CONTINUE is returned.
        Once the trial has used up its current allotment, the bracket's
        budget is recorded under ``result["hyperband_info"]["budget"]``.
        The trial is then paused (and the search algorithm notified) if the
        bracket is not filled or another live trial of the bracket is not
        paused yet; otherwise the bracket is processed and the resulting
        action is returned.
        """
        result["hyperband_info"] = {}
        bracket, _ = self._trial_info[trial]
        bracket.update_trial_stats(trial, result)

        if bracket.continue_trial(trial):
            return TrialScheduler.CONTINUE

        result["hyperband_info"]["budget"] = bracket._cumul_r

        # MAIN CHANGE HERE!
        others_not_paused = any(
            other.status != Trial.PAUSED for other in bracket._live_trials
            if other is not trial)
        if not bracket.filled() or others_not_paused:
            trial_runner._search_alg.on_pause(trial.trial_id)
            return TrialScheduler.PAUSE

        return self._process_bracket(trial_runner, bracket)

    def _unpause_trial(self, trial_runner, trial):
        """Unpauses `trial` and tells the search algorithm about it."""
        trial_runner.trial_executor.unpause_trial(trial)
        trial_runner._search_alg.on_unpause(trial.trial_id)

    def choose_trial_to_run(self, trial_runner):
        """Returns the next trial to launch, or None.

        Bands and their brackets are scanned in order and the first PENDING
        trial whose resources are available is returned. If there is none
        and no trial is currently RUNNING, every bracket that still holds a
        PAUSED trial is processed before None is returned.
        """
        for band in self._hyperbands:
            for bracket in band:
                # A band holds None entries for brackets that were
                # allocated no resources.
                if bracket is None:
                    continue
                for candidate in bracket.current_trials():
                    if (candidate.status == Trial.PENDING
                            and trial_runner.has_resources(
                                candidate.resources)):
                        return candidate

        # MAIN CHANGE HERE!
        if any(t.status == Trial.RUNNING for t in trial_runner.get_trials()):
            return None

        for band in self._hyperbands:
            for bracket in band:
                if bracket and any(t.status == Trial.PAUSED
                                   for t in bracket.current_trials()):
                    # This will change the trial state and let the
                    # trial runner retry.
                    self._process_bracket(trial_runner, bracket)
        # MAIN CHANGE HERE!
        return None
+24 -10
View File
@@ -8,6 +8,7 @@ import logging
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.trial import Trial
from ray.tune.error import TuneError
logger = logging.getLogger(__name__)
@@ -72,6 +73,8 @@ class HyperBandScheduler(FIFOScheduler):
The scheduler will terminate trials after this time has passed.
Note that this is different from the semantics of `max_t` as
mentioned in the original HyperBand paper.
reduction_factor (float): Same as `eta`. Determines how sharp
the difference is between bracket space-time allocation ratios.
"""
def __init__(self,
@@ -79,7 +82,8 @@ class HyperBandScheduler(FIFOScheduler):
reward_attr=None,
metric="episode_reward_mean",
mode="max",
max_t=81):
max_t=81,
reduction_factor=3):
assert max_t > 0, "Max (time_attr) not valid!"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
@@ -92,8 +96,9 @@ class HyperBandScheduler(FIFOScheduler):
"Setting `metric={}` and `mode=max`.".format(reward_attr))
FIFOScheduler.__init__(self)
self._eta = 3
self._s_max_1 = 5
self._eta = reduction_factor
self._s_max_1 = int(
np.round(np.log(max_t) / np.log(reduction_factor))) + 1
self._max_t_attr = max_t
# bracket max trials
self._get_n0 = lambda s: int(
@@ -173,10 +178,10 @@ class HyperBandScheduler(FIFOScheduler):
if bracket.continue_trial(trial):
return TrialScheduler.CONTINUE
action = self._process_bracket(trial_runner, bracket, trial)
action = self._process_bracket(trial_runner, bracket)
return action
def _process_bracket(self, trial_runner, bracket, trial):
def _process_bracket(self, trial_runner, bracket):
"""This is called whenever a trial makes progress.
When all live trials in the bracket have no more iterations left,
@@ -202,15 +207,15 @@ class HyperBandScheduler(FIFOScheduler):
bracket.cleanup_trial(t)
action = TrialScheduler.STOP
else:
raise Exception("Trial with unexpected status encountered")
raise TuneError("Trial with unexpected status encountered")
# ready the good trials - if trial is too far ahead, don't continue
for t in good:
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
raise Exception("Trial with unexpected status encountered")
raise TuneError("Trial with unexpected status encountered")
if bracket.continue_trial(t):
if t.status == Trial.PAUSED:
trial_runner.trial_executor.unpause_trial(t)
self._unpause_trial(trial_runner, t)
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
return action
@@ -223,7 +228,7 @@ class HyperBandScheduler(FIFOScheduler):
bracket, _ = self._trial_info[trial]
bracket.cleanup_trial(trial)
if not bracket.finished():
self._process_bracket(trial_runner, bracket, trial)
self._process_bracket(trial_runner, bracket)
def on_trial_complete(self, trial_runner, trial, result):
"""Cleans up trial info from bracket if trial completed early."""
@@ -279,6 +284,15 @@ class HyperBandScheduler(FIFOScheduler):
out += "\n {}".format(bracket)
return out
def state(self):
    """Summarize the scheduler: bracket slots over all bands, stop count."""
    bracket_total = sum(map(len, self._hyperbands))
    summary = {"num_brackets": bracket_total}
    summary["num_stopped"] = self._num_stopped
    return summary
def _unpause_trial(self, trial_runner, trial):
    """Unpause `trial` through the trial runner's executor.

    Kept as a separate method so that subclasses can override it to do
    additional work when a trial is unpaused.
    """
    trial_runner.trial_executor.unpause_trial(trial)
class Bracket():
"""Logical object for tracking Hyperband bracket progress. Keeps track
@@ -349,7 +363,7 @@ class Bracket():
self._r *= self._eta
self._r = int(min(self._r, self._max_t_attr - self._cumul_r))
self._cumul_r += self._r
self._cumul_r = self._r
sorted_trials = sorted(
self._live_trials,
key=lambda t: metric_op * self._live_trials[t][metric])
+2 -1
View File
@@ -2,10 +2,11 @@ from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.variant_generator import grid_search
from ray.tune.suggest.bohb import TuneBOHB
__all__ = [
"SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm",
"grid_search"
"grid_search", "TuneBOHB"
]
+128
View File
@@ -0,0 +1,128 @@
"""BOHB (Bayesian Optimization with HyperBand)"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
from ray.tune.suggest import SuggestionAlgorithm
logger = logging.getLogger(__name__)
class _BOHBJobWrapper():
    """Mock job object for HpBandSter to process.

    Exposes the attributes set below: ``result`` (holding the loss),
    ``kwargs`` (holding the budget and a shallow copy of the config) and
    ``exception`` (always None).
    """

    def __init__(self, loss, budget, config):
        self.result = dict(loss=loss)
        # Copy so later changes to the caller's dict do not leak in.
        config_snapshot = config.copy()
        self.kwargs = dict(budget=budget, config=config_snapshot)
        self.exception = None
class TuneBOHB(SuggestionAlgorithm):
    """BOHB suggestion component.

    Requires HpBandSter and ConfigSpace to be installed. You can install
    HpBandSter and ConfigSpace with: `pip install hpbandster ConfigSpace`.

    This should be used in conjunction with HyperBandForBOHB.

    Args:
        space (ConfigurationSpace): Continuous ConfigSpace search space.
            Parameters will be sampled from this space which will be used
            to run trials.
        bohb_config (dict): configuration for HpBandSter BOHB algorithm
        max_concurrent (int): Number of maximum concurrent trials. Defaults
            to 10.
        metric (str): The training result objective value attribute.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.

    Example:
        >>> import ConfigSpace as CS
        >>> config_space = CS.ConfigurationSpace()
        >>> config_space.add_hyperparameter(
                CS.UniformFloatHyperparameter('width', lower=0, upper=20))
        >>> config_space.add_hyperparameter(
                CS.UniformFloatHyperparameter('height', lower=-100, upper=100))
        >>> config_space.add_hyperparameter(
                CS.CategoricalHyperparameter(
                    name='activation', choices=['relu', 'tanh']))
        >>> algo = TuneBOHB(
                config_space, max_concurrent=4, metric='mean_loss', mode='min')
        >>> bohb = HyperBandForBOHB(
                time_attr='training_iteration',
                metric='mean_loss',
                mode='min',
                max_t=100)
        >>> run(MyTrainableClass, scheduler=bohb, search_alg=algo)
    """

    def __init__(self,
                 space,
                 bohb_config=None,
                 max_concurrent=10,
                 metric="neg_mean_loss",
                 mode="max"):
        from hpbandster.optimizers.config_generators.bohb import BOHB
        assert BOHB is not None, "HpBandSter must be installed!"
        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"

        self._max_concurrent = max_concurrent
        # trial_id -> configuration that was suggested for that trial.
        self.trial_to_params = {}
        # Trial ids currently counted as running / paused.
        self.running = set()
        self.paused = set()
        self.metric = metric

        # The reported metric is multiplied by this factor to obtain the
        # loss handed to BOHB, so a maximized metric gets its sign flipped.
        sign_by_mode = {"max": -1., "min": 1.}
        if mode in sign_by_mode:
            self._metric_op = sign_by_mode[mode]

        self.bohber = BOHB(space, **(bohb_config or {}))
        super(TuneBOHB, self).__init__()

    def _suggest(self, trial_id):
        """Returns a new config, or None when at the concurrency limit."""
        if len(self.running) >= self._max_concurrent:
            return None

        # This parameter is not used in hpbandster implementation.
        config, _ = self.bohber.get_config(None)
        self.trial_to_params[trial_id] = copy.deepcopy(config)
        self.running.add(trial_id)
        return config

    def on_trial_result(self, trial_id, result):
        """Forwards a result to BOHB once it carries a HyperBand budget."""
        if trial_id not in self.paused:
            self.running.add(trial_id)

        if "hyperband_info" not in result:
            logger.warning("BOHB Info not detected in result. Are you using "
                           "HyperBandForBOHB as a scheduler?")
            return

        if "budget" in result["hyperband_info"]:
            self.bohber.new_result(self.to_wrapper(trial_id, result))

    def on_trial_complete(self,
                          trial_id,
                          result=None,
                          error=False,
                          early_terminated=False):
        """Drops all bookkeeping kept for a finished trial."""
        del self.trial_to_params[trial_id]
        self.paused.discard(trial_id)
        self.running.discard(trial_id)

    def to_wrapper(self, trial_id, result):
        """Packs a result into the job object passed to BOHB."""
        loss = self._metric_op * result[self.metric]
        budget = result["hyperband_info"]["budget"]
        return _BOHBJobWrapper(loss, budget, self.trial_to_params[trial_id])

    def on_pause(self, trial_id):
        """Moves a trial from the running set to the paused set."""
        self.paused.add(trial_id)
        self.running.remove(trial_id)

    def on_unpause(self, trial_id):
        """Moves a trial from the paused set back to the running set."""
        self.paused.remove(trial_id)
        self.running.add(trial_id)
+82 -11
View File
@@ -15,7 +15,7 @@ import ray
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
TrialScheduler)
TrialScheduler, HyperBandForBOHB)
from ray.tune.schedulers.pbt import explore
from ray.tune.trial import Trial, Checkpoint
@@ -236,7 +236,7 @@ class _MockTrialRunner():
class HyperbandSuite(unittest.TestCase):
def setUp(self):
ray.init()
ray.init(object_store_memory=int(1e8))
def tearDown(self):
ray.shutdown()
@@ -319,17 +319,19 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
sched = HyperBandScheduler(max_t=810)
reduction_factor = 10
sched = HyperBandScheduler(
max_t=1000, reduction_factor=reduction_factor)
i = 0
while not sched._cur_band_filled():
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertEqual(sched._hyperbands[0][0]._n, 5)
self.assertEqual(sched._hyperbands[0][0]._r, 810)
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
self.assertEqual(sched._hyperbands[0][-1]._r, 10)
self.assertEqual(len(sched._hyperbands[0]), 4)
self.assertEqual(sched._hyperbands[0][0]._n, 4)
self.assertEqual(sched._hyperbands[0][0]._r, 1000)
self.assertEqual(sched._hyperbands[0][-1]._n, 1000)
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
def testConfigSameEtaSmall(self):
sched = HyperBandScheduler(max_t=1)
@@ -338,8 +340,7 @@ class HyperbandSuite(unittest.TestCase):
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertTrue(all(v is None for v in sched._hyperbands[0][1:]))
self.assertEqual(len(sched._hyperbands[0]), 1)
def testSuccessiveHalving(self):
"""Setup full band, then iterate through last bracket (n=81)
@@ -367,7 +368,7 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(action, TrialScheduler.CONTINUE)
new_length = len(big_bracket.current_trials())
self.assertEqual(new_length, self.downscale(current_length, sched))
cur_units += int(cur_units * sched._eta)
cur_units = int(cur_units * sched._eta)
self.assertEqual(len(big_bracket.current_trials()), 1)
def testHalvingStop(self):
@@ -603,6 +604,76 @@ class HyperbandSuite(unittest.TestCase):
self.assertIsNotNone(trial)
class BOHBSuite(unittest.TestCase):
    """Tests for the HyperBandForBOHB scheduler."""

    def setUp(self):
        ray.init(object_store_memory=int(1e8))

    def tearDown(self):
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _report_three_trials(self, **sched_kwargs):
        """Runs three trials through one scheduler step each.

        Creates a ``HyperBandForBOHB(max_t=3, reduction_factor=3)`` with any
        extra keyword arguments, adds and launches three trials, reports a
        result for the first two (asserting each one is paused) and finally
        reports a result with score 0 for the last trial.

        Returns:
            Tuple of (scheduler, runner, decision for the last trial,
            result dict that was reported for the last trial).
        """

        def result(score, ts):
            return {"episode_reward_mean": score, TRAINING_ITERATION: ts}

        sched = HyperBandForBOHB(max_t=3, reduction_factor=3, **sched_kwargs)
        runner = _MockTrialRunner(sched)
        runner._search_alg = MagicMock()
        trials = [Trial("__fake") for _ in range(3)]
        for t in trials:
            runner.add_trial(t)
            runner._launch_trial(t)

        for trial, trial_result in zip(trials, [result(1, 1), result(2, 1)]):
            decision = sched.on_trial_result(runner, trial, trial_result)
            self.assertEqual(decision, TrialScheduler.PAUSE)
            runner._pause_trial(trial)

        spy_result = result(0, 1)
        decision = sched.on_trial_result(runner, trials[-1], spy_result)
        return sched, runner, decision, spy_result

    def testLargestBracketFirst(self):
        sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
        runner = _MockTrialRunner(sched)
        for _ in range(3):
            t = Trial("__fake")
            sched.on_trial_add(runner, t)
            runner._launch_trial(t)

        # Three trials fit into the first bracket; a fourth opens a second.
        self.assertEqual(sched.state()["num_brackets"], 1)
        sched.on_trial_add(runner, Trial("__fake"))
        self.assertEqual(sched.state()["num_brackets"], 2)

    def testCheckTrialInfoUpdate(self):
        sched, runner, decision, spy_result = self._report_three_trials()

        # The last trial reported the lowest score, so it is stopped.
        self.assertEqual(decision, TrialScheduler.STOP)
        sched.choose_trial_to_run(runner)
        self.assertEqual(runner._search_alg.on_pause.call_count, 2)
        self.assertEqual(runner._search_alg.on_unpause.call_count, 1)
        self.assertIn("hyperband_info", spy_result)
        self.assertEqual(spy_result["hyperband_info"]["budget"], 1)

    def testCheckTrialInfoUpdateMin(self):
        sched, runner, decision, spy_result = self._report_three_trials(
            mode="min")

        # With mode="min" the lowest score is the best, so it continues.
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        sched.choose_trial_to_run(runner)
        self.assertEqual(runner._search_alg.on_pause.call_count, 2)
        self.assertIn("hyperband_info", spy_result)
        self.assertEqual(spy_result["hyperband_info"]["budget"], 1)
class _MockTrial(Trial):
def __init__(self, i, config):
self.trainable_name = "trial_{}".format(i)