mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[tune] Implement BOHB (#5382)
This commit is contained in:
committed by
Richard Liaw
parent
79949fb8a0
commit
b7d0733362
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, run
|
||||
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
|
||||
from ray.tune.suggest.bohb import TuneBOHB
|
||||
|
||||
# Command-line interface of the example script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test",
    help="Finish quickly for testing",
    action="store_true")
parser.add_argument(
    "--ray-redis-address",
    help="Address of Ray cluster for seamless distributed execution.")
# `parse_known_args` is used so that unrecognised options are ignored
# instead of aborting the script.
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
    """Toy agent whose learning curve is a scaled ``tanh`` of the timestep.

    The dummy hyperparameters "width" and "height" determine the slope and
    the maximum reward value reached, respectively. Both default to 1 when
    absent from the trial config.
    """

    def _setup(self, config):
        # Counts how many times `_train` has been called.
        self.timestep = 0

    def _train(self):
        self.timestep += 1
        width = self.config.get("width", 1)
        height = self.config.get("height", 1)

        # tanh saturates at 1, so the reward approaches `height`.
        reward = np.tanh(float(self.timestep) / width)
        reward *= height

        # The value is reported as `episode_reward_mean`; other objectives
        # such as loss or accuracy could be reported the same way.
        return {"episode_reward_mean": reward}

    def _save(self, checkpoint_dir):
        # Persist the only piece of state, the timestep, as JSON.
        checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        state = {"timestep": self.timestep}
        with open(checkpoint_file, "w") as out:
            json.dump(state, out)
        return checkpoint_file

    def _restore(self, checkpoint_path):
        # Inverse of `_save`: read the timestep back from the JSON file.
        with open(checkpoint_path) as src:
            state = json.load(src)
        self.timestep = state["timestep"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # ConfigSpace is only needed when the example is actually executed.
    import ConfigSpace as CS
    ray.init(redis_address=args.ray_redis_address)

    # BOHB uses ConfigSpace for their hyperparameter search space
    config_space = CS.ConfigurationSpace()
    config_space.add_hyperparameter(
        CS.UniformFloatHyperparameter("height", lower=10, upper=100))
    config_space.add_hyperparameter(
        CS.UniformFloatHyperparameter("width", lower=0, upper=100))

    # `episode_reward_mean` is a reward, so it has to be maximised. The
    # same metric/mode pair is handed to both the scheduler and the search
    # algorithm so that they agree on which trials are "good".
    experiment_metrics = dict(metric="episode_reward_mean", mode="max")
    bohb_hyperband = HyperBandForBOHB(
        time_attr="training_iteration",
        max_t=100,
        reduction_factor=4,
        **experiment_metrics)
    bohb_search = TuneBOHB(
        config_space, max_concurrent=4, **experiment_metrics)

    # With --smoke-test each trial is stopped after 10 iterations
    # instead of 100.
    run(MyTrainableClass,
        name="bohb_test",
        scheduler=bohb_hyperband,
        search_alg=bohb_search,
        num_samples=10,
        stop={"training_iteration": 10 if args.smoke_test else 100})
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.schedulers.hyperband import HyperBandScheduler
|
||||
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
|
||||
from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler,
|
||||
ASHAScheduler)
|
||||
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
|
||||
@@ -12,5 +13,5 @@ from ray.tune.schedulers.pbt import PopulationBasedTraining
|
||||
__all__ = [
|
||||
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
|
||||
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
|
||||
"PopulationBasedTraining"
|
||||
"PopulationBasedTraining", "HyperBandForBOHB"
|
||||
]
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
from ray.tune.schedulers.trial_scheduler import TrialScheduler
|
||||
from ray.tune.schedulers.hyperband import HyperBandScheduler, Bracket
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HyperBandForBOHB(HyperBandScheduler):
    """HyperBand early stopping adapted for use with BOHB.

    Compared to ``HyperBandScheduler`` this implementation drops the
    pipelining and makes two key changes:

    1. Trials are placed so that the bracket with the largest size is
       filled first.

    2. Trials are paused even if the bracket is not filled, which allows
       BOHB to insert new trials into the training.

    See ray.tune.schedulers.HyperBandScheduler for parameter docstring.
    """

    def on_trial_add(self, trial_runner, trial):
        """Registers a new trial with a bracket.

        If the current bracket still has room the trial joins it.
        Otherwise a new bracket is created in the current band, or, when
        the band itself is filled, in a freshly created band.
        """
        bracket = self._state["bracket"]
        band = self._hyperbands[self._state["band_idx"]]

        if bracket is None or bracket.filled():
            while True:
                # A filled band means a new iteration has to be started.
                if self._cur_band_filled():
                    band = []
                    self._hyperbands.append(band)
                    self._state["band_idx"] += 1

                # MAIN CHANGE HERE - largest bracket first!
                # The band holds fewer than s_max_1 entries here, or else
                # it would have counted as filled above.
                s = self._s_max_1 - len(band) - 1
                assert s >= 0, "Current band is filled!"

                usable = self._get_r0(s) != 0
                if usable:
                    bracket = Bracket(self._time_attr, self._get_n0(s),
                                      self._get_r0(s), self._max_t_attr,
                                      self._eta, s)
                else:
                    logger.debug("BOHB: Bracket too small - Retrying...")
                    bracket = None

                # A `None` placeholder is recorded for unusable brackets so
                # that the band length keeps advancing.
                band.append(bracket)
                self._state["bracket"] = bracket
                if usable:
                    break

        self._state["bracket"].add_trial(trial)
        self._trial_info[trial] = bracket, self._state["band_idx"]

    def on_trial_result(self, trial_runner, trial, result):
        """Decides whether the reporting trial continues, pauses or stops.

        The trial continues while it has iterations left in the current
        rung. Once it reaches the end of the rung it is paused unless the
        bracket is filled and every other live trial is already paused; in
        that case the bracket is processed and its decision is returned.

        This scheduler will not start trials but will stop trials. The
        current running trial is not handled here; the trial runner acts
        on the returned decision.
        """
        result["hyperband_info"] = {}
        bracket, _ = self._trial_info[trial]
        bracket.update_trial_stats(trial, result)

        if bracket.continue_trial(trial):
            return TrialScheduler.CONTINUE

        # Expose the budget consumed so far to the search algorithm.
        result["hyperband_info"]["budget"] = bracket._cumul_r

        # MAIN CHANGE HERE!
        peers = [t for t in bracket._live_trials if t is not trial]
        bracket_open = not bracket.filled()
        peers_done = all(t.status == Trial.PAUSED for t in peers)
        if bracket_open or not peers_done:
            trial_runner._search_alg.on_pause(trial.trial_id)
            return TrialScheduler.PAUSE

        return self._process_bracket(trial_runner, bracket)

    def _unpause_trial(self, trial_runner, trial):
        # Resume the trial and keep the search algorithm's bookkeeping
        # in sync.
        trial_runner.trial_executor.unpause_trial(trial)
        trial_runner._search_alg.on_unpause(trial.trial_id)

    def choose_trial_to_run(self, trial_runner):
        """Picks the next pending trial that fits the free resources.

        The trial list of the runner is not consulted for this choice
        since all trials are tracked as scheduler state. Bands are scanned
        in creation order. When nothing is pending and no trial is
        running, brackets that contain paused trials are processed so
        that the trial runner can retry; ``None`` is returned in that
        case.
        """
        for band in self._hyperbands:
            # Bands contain `None` entries for brackets that were given
            # no resources.
            for bracket in [b for b in band if b is not None]:
                for candidate in bracket.current_trials():
                    if (candidate.status == Trial.PENDING
                            and trial_runner.has_resources(
                                candidate.resources)):
                        return candidate

        # MAIN CHANGE HERE!
        if any(t.status == Trial.RUNNING for t in trial_runner.get_trials()):
            return None

        for band in self._hyperbands:
            for bracket in band:
                if not bracket:
                    continue
                if any(t.status == Trial.PAUSED
                       for t in bracket.current_trials()):
                    # This will change the trial state and let the
                    # trial runner retry.
                    self._process_bracket(trial_runner, bracket)
        # MAIN CHANGE HERE!
        return None
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.error import TuneError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,6 +73,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
The scheduler will terminate trials after this time has passed.
|
||||
Note that this is different from the semantics of `max_t` as
|
||||
mentioned in the original HyperBand paper.
|
||||
reduction_factor (float): Same as `eta`. Determines how sharp
|
||||
the difference is between bracket space-time allocation ratios.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -79,7 +82,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
reward_attr=None,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
max_t=81):
|
||||
max_t=81,
|
||||
reduction_factor=3):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
|
||||
@@ -92,8 +96,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
"Setting `metric={}` and `mode=max`.".format(reward_attr))
|
||||
|
||||
FIFOScheduler.__init__(self)
|
||||
self._eta = 3
|
||||
self._s_max_1 = 5
|
||||
self._eta = reduction_factor
|
||||
self._s_max_1 = int(
|
||||
np.round(np.log(max_t) / np.log(reduction_factor))) + 1
|
||||
self._max_t_attr = max_t
|
||||
# bracket max trials
|
||||
self._get_n0 = lambda s: int(
|
||||
@@ -173,10 +178,10 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
if bracket.continue_trial(trial):
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
action = self._process_bracket(trial_runner, bracket, trial)
|
||||
action = self._process_bracket(trial_runner, bracket)
|
||||
return action
|
||||
|
||||
def _process_bracket(self, trial_runner, bracket, trial):
|
||||
def _process_bracket(self, trial_runner, bracket):
|
||||
"""This is called whenever a trial makes progress.
|
||||
|
||||
When all live trials in the bracket have no more iterations left,
|
||||
@@ -202,15 +207,15 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
bracket.cleanup_trial(t)
|
||||
action = TrialScheduler.STOP
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
raise TuneError("Trial with unexpected status encountered")
|
||||
|
||||
# ready the good trials - if trial is too far ahead, don't continue
|
||||
for t in good:
|
||||
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
raise TuneError("Trial with unexpected status encountered")
|
||||
if bracket.continue_trial(t):
|
||||
if t.status == Trial.PAUSED:
|
||||
trial_runner.trial_executor.unpause_trial(t)
|
||||
self._unpause_trial(trial_runner, t)
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
return action
|
||||
@@ -223,7 +228,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.cleanup_trial(trial)
|
||||
if not bracket.finished():
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
self._process_bracket(trial_runner, bracket)
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
@@ -279,6 +284,15 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
out += "\n {}".format(bracket)
|
||||
return out
|
||||
|
||||
def state(self):
|
||||
return {
|
||||
"num_brackets": sum(len(band) for band in self._hyperbands),
|
||||
"num_stopped": self._num_stopped
|
||||
}
|
||||
|
||||
def _unpause_trial(self, trial_runner, trial):
|
||||
trial_runner.trial_executor.unpause_trial(trial)
|
||||
|
||||
|
||||
class Bracket():
|
||||
"""Logical object for tracking Hyperband bracket progress. Keeps track
|
||||
@@ -349,7 +363,7 @@ class Bracket():
|
||||
|
||||
self._r *= self._eta
|
||||
self._r = int(min(self._r, self._max_t_attr - self._cumul_r))
|
||||
self._cumul_r += self._r
|
||||
self._cumul_r = self._r
|
||||
sorted_trials = sorted(
|
||||
self._live_trials,
|
||||
key=lambda t: metric_op * self._live_trials[t][metric])
|
||||
|
||||
@@ -2,10 +2,11 @@ from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
from ray.tune.suggest.bohb import TuneBOHB
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm",
|
||||
"grid_search"
|
||||
"grid_search", "TuneBOHB"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
"""BOHB (Bayesian Optimization with HyperBand)"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from ray.tune.suggest import SuggestionAlgorithm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _BOHBJobWrapper():
    """Minimal stand-in for a job object, as consumed by HpBandSter.

    It carries the loss, the budget and a shallow copy of the config.
    """

    def __init__(self, loss, budget, config):
        # No failure is ever reported through this wrapper.
        self.exception = None
        self.result = dict(loss=loss)
        # The config is copied so the caller's dict is not shared.
        self.kwargs = dict(budget=budget, config=config.copy())
|
||||
|
||||
|
||||
class TuneBOHB(SuggestionAlgorithm):
    """Search algorithm component of BOHB.

    Requires HpBandSter and ConfigSpace to be installed. You can install
    HpBandSter and ConfigSpace with: `pip install hpbandster ConfigSpace`.

    This should be used in conjunction with HyperBandForBOHB.

    Args:
        space (ConfigurationSpace): Continuous ConfigSpace search space.
            Parameters will be sampled from this space which will be used
            to run trials.
        bohb_config (dict): configuration for HpBandSter BOHB algorithm
        max_concurrent (int): Number of maximum concurrent trials. Defaults
            to 10.
        metric (str): The training result objective value attribute.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.

    Example:
        >>> import ConfigSpace as CS
        >>> config_space = CS.ConfigurationSpace()
        >>> config_space.add_hyperparameter(
                CS.UniformFloatHyperparameter('width', lower=0, upper=20))
        >>> config_space.add_hyperparameter(
                CS.UniformFloatHyperparameter('height', lower=-100, upper=100))
        >>> config_space.add_hyperparameter(
                CS.CategoricalHyperparameter(
                    name='activation', choices=['relu', 'tanh']))
        >>> algo = TuneBOHB(
                config_space, max_concurrent=4, metric='mean_loss', mode='min')
        >>> bohb = HyperBandForBOHB(
                time_attr='training_iteration',
                metric='mean_loss',
                mode='min',
                max_t=100)
        >>> run(MyTrainableClass, scheduler=bohb, search_alg=algo)

    """

    def __init__(self,
                 space,
                 bohb_config=None,
                 max_concurrent=10,
                 metric="neg_mean_loss",
                 mode="max"):
        # Imported lazily so that HpBandSter is only required when this
        # class is instantiated.
        from hpbandster.optimizers.config_generators.bohb import BOHB
        assert BOHB is not None, "HpBandSter must be installed!"
        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"

        self.metric = metric
        self._max_concurrent = max_concurrent
        # trial_id -> deep copy of the config suggested for that trial.
        self.trial_to_params = {}
        # Trial ids currently counted as running / paused.
        self.running = set()
        self.paused = set()

        # The reported metric is multiplied by this factor to obtain the
        # "loss" handed to BOHB, so a maximised metric is negated.
        if mode == "min":
            self._metric_op = 1.
        elif mode == "max":
            self._metric_op = -1.

        self.bohber = BOHB(space, **(bohb_config or {}))
        super(TuneBOHB, self).__init__()

    def _suggest(self, trial_id):
        # Refuse to hand out a config while the concurrency limit is hit.
        if len(self.running) >= self._max_concurrent:
            return None

        # This parameter is not used in hpbandster implementation.
        suggestion, _ = self.bohber.get_config(None)
        self.trial_to_params[trial_id] = copy.deepcopy(suggestion)
        self.running.add(trial_id)
        return suggestion

    def on_trial_result(self, trial_id, result):
        # A reporting trial counts as running unless it was paused.
        if trial_id not in self.paused:
            self.running.add(trial_id)

        if "hyperband_info" not in result:
            logger.warning("BOHB Info not detected in result. Are you using "
                           "HyperBandForBOHB as a scheduler?")
            return

        # Only results that carry a budget are forwarded to BOHB.
        if "budget" in result["hyperband_info"]:
            job = self.to_wrapper(trial_id, result)
            self.bohber.new_result(job)

    def on_trial_complete(self,
                          trial_id,
                          result=None,
                          error=False,
                          early_terminated=False):
        # Forget everything tracked for the finished trial.
        self.trial_to_params.pop(trial_id)
        self.paused.discard(trial_id)
        self.running.discard(trial_id)

    def to_wrapper(self, trial_id, result):
        # Package the result as the job object passed to BOHB.
        loss = self._metric_op * result[self.metric]
        budget = result["hyperband_info"]["budget"]
        params = self.trial_to_params[trial_id]
        return _BOHBJobWrapper(loss, budget, params)

    def on_pause(self, trial_id):
        # Move the trial from the running set to the paused set.
        self.paused.add(trial_id)
        self.running.remove(trial_id)

    def on_unpause(self, trial_id):
        # Move the trial from the paused set back to the running set.
        self.paused.remove(trial_id)
        self.running.add(trial_id)
|
||||
@@ -15,7 +15,7 @@ import ray
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
TrialScheduler)
|
||||
TrialScheduler, HyperBandForBOHB)
|
||||
|
||||
from ray.tune.schedulers.pbt import explore
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
@@ -236,7 +236,7 @@ class _MockTrialRunner():
|
||||
|
||||
class HyperbandSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
ray.init(object_store_memory=int(1e8))
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
@@ -319,17 +319,19 @@ class HyperbandSuite(unittest.TestCase):
|
||||
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
|
||||
|
||||
sched = HyperBandScheduler(max_t=810)
|
||||
reduction_factor = 10
|
||||
sched = HyperBandScheduler(
|
||||
max_t=1000, reduction_factor=reduction_factor)
|
||||
i = 0
|
||||
while not sched._cur_band_filled():
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
i += 1
|
||||
self.assertEqual(len(sched._hyperbands[0]), 5)
|
||||
self.assertEqual(sched._hyperbands[0][0]._n, 5)
|
||||
self.assertEqual(sched._hyperbands[0][0]._r, 810)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._r, 10)
|
||||
self.assertEqual(len(sched._hyperbands[0]), 4)
|
||||
self.assertEqual(sched._hyperbands[0][0]._n, 4)
|
||||
self.assertEqual(sched._hyperbands[0][0]._r, 1000)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._n, 1000)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
|
||||
|
||||
def testConfigSameEtaSmall(self):
|
||||
sched = HyperBandScheduler(max_t=1)
|
||||
@@ -338,8 +340,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
i += 1
|
||||
self.assertEqual(len(sched._hyperbands[0]), 5)
|
||||
self.assertTrue(all(v is None for v in sched._hyperbands[0][1:]))
|
||||
self.assertEqual(len(sched._hyperbands[0]), 1)
|
||||
|
||||
def testSuccessiveHalving(self):
|
||||
"""Setup full band, then iterate through last bracket (n=81)
|
||||
@@ -367,7 +368,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
new_length = len(big_bracket.current_trials())
|
||||
self.assertEqual(new_length, self.downscale(current_length, sched))
|
||||
cur_units += int(cur_units * sched._eta)
|
||||
cur_units = int(cur_units * sched._eta)
|
||||
self.assertEqual(len(big_bracket.current_trials()), 1)
|
||||
|
||||
def testHalvingStop(self):
|
||||
@@ -603,6 +604,76 @@ class HyperbandSuite(unittest.TestCase):
|
||||
self.assertIsNotNone(trial)
|
||||
|
||||
|
||||
class BOHBSuite(unittest.TestCase):
    """Tests for the ``HyperBandForBOHB`` scheduler."""

    def setUp(self):
        ray.init(object_store_memory=int(1e8))

    def tearDown(self):
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def testLargestBracketFirst(self):
        """The first bracket takes three trials; a fourth opens another."""
        sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
        runner = _MockTrialRunner(sched)
        for _ in range(3):
            t = Trial("__fake")
            sched.on_trial_add(runner, t)
            runner._launch_trial(t)

        self.assertEqual(sched.state()["num_brackets"], 1)
        sched.on_trial_add(runner, Trial("__fake"))
        self.assertEqual(sched.state()["num_brackets"], 2)

    def testCheckTrialInfoUpdate(self):
        """In the default mode the lowest-scoring last trial is stopped."""

        def result(score, ts):
            return {"episode_reward_mean": score, TRAINING_ITERATION: ts}

        sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
        runner = _MockTrialRunner(sched)
        runner._search_alg = MagicMock()
        trials = [Trial("__fake") for i in range(3)]
        for t in trials:
            runner.add_trial(t)
            runner._launch_trial(t)

        # The first two trials reach the end of the rung and are paused.
        for trial, trial_result in zip(trials, [result(1, 1), result(2, 1)]):
            decision = sched.on_trial_result(runner, trial, trial_result)
            self.assertEqual(decision, TrialScheduler.PAUSE)
            runner._pause_trial(trial)
        spy_result = result(0, 1)
        decision = sched.on_trial_result(runner, trials[-1], spy_result)
        self.assertEqual(decision, TrialScheduler.STOP)
        sched.choose_trial_to_run(runner)
        self.assertEqual(runner._search_alg.on_pause.call_count, 2)
        self.assertEqual(runner._search_alg.on_unpause.call_count, 1)
        self.assertIn("hyperband_info", spy_result)
        # `assertEquals` is a deprecated alias that was removed in
        # Python 3.12, so `assertEqual` is used instead.
        self.assertEqual(spy_result["hyperband_info"]["budget"], 1)

    def testCheckTrialInfoUpdateMin(self):
        """With mode="min" the lowest-scoring last trial continues."""

        def result(score, ts):
            return {"episode_reward_mean": score, TRAINING_ITERATION: ts}

        sched = HyperBandForBOHB(max_t=3, reduction_factor=3, mode="min")
        runner = _MockTrialRunner(sched)
        runner._search_alg = MagicMock()
        trials = [Trial("__fake") for i in range(3)]
        for t in trials:
            runner.add_trial(t)
            runner._launch_trial(t)

        # The first two trials reach the end of the rung and are paused.
        for trial, trial_result in zip(trials, [result(1, 1), result(2, 1)]):
            decision = sched.on_trial_result(runner, trial, trial_result)
            self.assertEqual(decision, TrialScheduler.PAUSE)
            runner._pause_trial(trial)
        spy_result = result(0, 1)
        decision = sched.on_trial_result(runner, trials[-1], spy_result)
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        sched.choose_trial_to_run(runner)
        self.assertEqual(runner._search_alg.on_pause.call_count, 2)
        self.assertIn("hyperband_info", spy_result)
        self.assertEqual(spy_result["hyperband_info"]["budget"], 1)
|
||||
|
||||
|
||||
class _MockTrial(Trial):
|
||||
def __init__(self, i, config):
|
||||
self.trainable_name = "trial_{}".format(i)
|
||||
|
||||
Reference in New Issue
Block a user