mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[tune] Add BayesOpt (#3864)
Adds BayesOpt as a Tune suggestion algorithm.
This commit is contained in:
committed by
Richard Liaw
parent
d3551dd8df
commit
62a0a7bdc7
@@ -0,0 +1,61 @@
|
||||
"""This test checks that BayesOpt is functional.
|
||||
|
||||
It also checks that it is usable with a separate scheduler.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments, register_trainable
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
from ray.tune.suggest import BayesOptSearch
|
||||
|
||||
|
||||
def easy_objective(config, reporter):
|
||||
import time
|
||||
time.sleep(0.2)
|
||||
for i in range(config["iterations"]):
|
||||
reporter(
|
||||
timesteps_total=i,
|
||||
neg_mean_loss=-(config["height"] - 14)**2 +
|
||||
abs(config["width"] - 3))
|
||||
time.sleep(0.02)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
|
||||
register_trainable("exp", easy_objective)
|
||||
|
||||
space = {'width': (0, 20), 'height': (-100, 100)}
|
||||
|
||||
config = {
|
||||
"my_exp": {
|
||||
"run": "exp",
|
||||
"num_samples": 10 if args.smoke_test else 1000,
|
||||
"config": {
|
||||
"iterations": 100,
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 100
|
||||
},
|
||||
}
|
||||
}
|
||||
algo = BayesOptSearch(
|
||||
space,
|
||||
max_concurrent=4,
|
||||
reward_attr="neg_mean_loss",
|
||||
utility_kwargs={
|
||||
"kind": "ucb",
|
||||
"kappa": 2.5,
|
||||
"xi": 0.0
|
||||
})
|
||||
scheduler = AsyncHyperBandScheduler(reward_attr="neg_mean_loss")
|
||||
run_experiments(config, search_alg=algo, scheduler=scheduler)
|
||||
@@ -1,6 +1,7 @@
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
from ray.tune.suggest.bayesopt import BayesOptSearch
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.variant_generator import grid_search, function, \
|
||||
sample_from
|
||||
@@ -8,6 +9,7 @@ from ray.tune.suggest.variant_generator import grid_search, function, \
|
||||
__all__ = [
|
||||
"SearchAlgorithm",
|
||||
"BasicVariantGenerator",
|
||||
"BayesOptSearch",
|
||||
"HyperOptSearch",
|
||||
"SuggestionAlgorithm",
|
||||
"grid_search",
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
try:
|
||||
import bayes_opt as byo
|
||||
except Exception:
|
||||
byo = None
|
||||
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
|
||||
class BayesOptSearch(SuggestionAlgorithm):
|
||||
"""A wrapper around BayesOpt to provide trial suggestions.
|
||||
|
||||
Requires BayesOpt to be installed. You can install BayesOpt with the
|
||||
command: `pip install bayesian-optimization`.
|
||||
|
||||
Parameters:
|
||||
space (dict): Continuous search space. Parameters will be sampled from
|
||||
this space which will be used to run trials.
|
||||
max_concurrent (int): Number of maximum concurrent trials. Defaults
|
||||
to 10.
|
||||
reward_attr (str): The training result objective value attribute.
|
||||
This refers to an increasing value.
|
||||
utility_kwargs (dict): Parameters to define the utility function. Must
|
||||
provide values for the keys `kind`, `kappa`, and `xi`.
|
||||
random_state (int): Used to initialize BayesOpt.
|
||||
verbose (int): Sets verbosity level for BayesOpt packages.
|
||||
|
||||
Example:
|
||||
>>> space = {
|
||||
>>> 'width': (0, 20),
|
||||
>>> 'height': (-100, 100),
|
||||
>>> }
|
||||
>>> config = {
|
||||
>>> "my_exp": {
|
||||
>>> "run": "exp",
|
||||
>>> "num_samples": 10 if args.smoke_test else 1000,
|
||||
>>> "stop": {
|
||||
>>> "training_iteration": 100
|
||||
>>> },
|
||||
>>> }
|
||||
>>> }
|
||||
>>> algo = BayesOptSearch(
|
||||
>>> space, max_concurrent=4, reward_attr="neg_mean_loss")
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
space,
|
||||
max_concurrent=10,
|
||||
reward_attr="episode_reward_mean",
|
||||
utility_kwargs=None,
|
||||
random_state=1,
|
||||
verbose=0,
|
||||
**kwargs):
|
||||
assert byo is not None, (
|
||||
"BayesOpt must be installed!. You can install BayesOpt with"
|
||||
" the command: `pip install bayesian-optimization`.")
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
assert utility_kwargs is not None, (
|
||||
"Must define arguments for the utiliy function!")
|
||||
self._max_concurrent = max_concurrent
|
||||
self._reward_attr = reward_attr
|
||||
self._live_trial_mapping = {}
|
||||
|
||||
self.optimizer = byo.BayesianOptimization(
|
||||
f=None, pbounds=space, verbose=verbose, random_state=random_state)
|
||||
|
||||
self.utility = byo.UtilityFunction(**utility_kwargs)
|
||||
|
||||
super(BayesOptSearch, self).__init__(**kwargs)
|
||||
|
||||
def _suggest(self, trial_id):
|
||||
if self._num_live_trials() >= self._max_concurrent:
|
||||
return None
|
||||
|
||||
new_trial = self.optimizer.suggest(self.utility)
|
||||
|
||||
self._live_trial_mapping[trial_id] = new_trial
|
||||
|
||||
return copy.deepcopy(new_trial)
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
pass
|
||||
|
||||
def on_trial_complete(self,
|
||||
trial_id,
|
||||
result=None,
|
||||
error=False,
|
||||
early_terminated=False):
|
||||
"""Passes the result to BayesOpt unless early terminated or errored"""
|
||||
if result:
|
||||
self.optimizer.register(
|
||||
params=self._live_trial_mapping[trial_id],
|
||||
target=result[self._reward_attr])
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_trial_mapping)
|
||||
Reference in New Issue
Block a user