mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[tune] Add hyperopt warm start feature (#5372)
This commit is contained in:
committed by
Richard Liaw
parent
18f1e904de
commit
7e8a4a62ea
@@ -5,6 +5,8 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import copy
|
||||
import logging
|
||||
from functools import partial
|
||||
import pickle
|
||||
try:
|
||||
hyperopt_logger = logging.getLogger("hyperopt")
|
||||
hyperopt_logger.setLevel(logging.WARNING)
|
||||
@@ -24,7 +26,9 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
Requires HyperOpt to be installed from source.
|
||||
Uses the Tree-structured Parzen Estimators algorithm, although can be
|
||||
trivially extended to support any algorithm HyperOpt uses. Externally
|
||||
added trials will not be tracked by HyperOpt.
|
||||
added trials will not be tracked by HyperOpt. Trials of the current run
|
||||
can be saved using save method, trials of a previous run can be loaded
|
||||
using restore method, thus enabling a warm start feature.
|
||||
|
||||
Parameters:
|
||||
space (dict): HyperOpt configuration. Parameters will be sampled
|
||||
@@ -42,6 +46,13 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
a list of dict of hyperopt-named variables.
|
||||
Choice variables should be indicated by their index in the
|
||||
list (see example)
|
||||
n_initial_points (int): number of random evaluations of the
|
||||
objective function before starting to aproximate it with
|
||||
tree parzen estimators. Defaults to 20.
|
||||
random_state_seed (int, array_like, None): seed for reproducible
|
||||
results. Defaults to None.
|
||||
gamma (float in range (0,1)): parameter governing the tree parzen
|
||||
estimators suggestion algorithm. Defaults to 0.25.
|
||||
|
||||
Example:
|
||||
>>> space = {
|
||||
@@ -66,6 +77,9 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
points_to_evaluate=None,
|
||||
n_initial_points=20,
|
||||
random_state_seed=None,
|
||||
gamma=0.25,
|
||||
**kwargs):
|
||||
assert hpo is not None, "HyperOpt must be installed!"
|
||||
from hyperopt.fmin import generate_trials_to_calculate
|
||||
@@ -87,7 +101,13 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
self._metric_op = -1.
|
||||
elif mode == "min":
|
||||
self._metric_op = 1.
|
||||
self.algo = hpo.tpe.suggest
|
||||
if n_initial_points is None:
|
||||
self.algo = hpo.tpe.suggest
|
||||
else:
|
||||
self.algo = partial(
|
||||
hpo.tpe.suggest, n_startup_jobs=n_initial_points)
|
||||
if gamma is not None:
|
||||
self.algo = partial(self.algo, gamma=gamma)
|
||||
self.domain = hpo.Domain(lambda spc: spc, space)
|
||||
if points_to_evaluate is None:
|
||||
self._hpopt_trials = hpo.Trials()
|
||||
@@ -99,7 +119,10 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
self._hpopt_trials.refresh()
|
||||
self._points_to_evaluate = len(points_to_evaluate)
|
||||
self._live_trial_mapping = {}
|
||||
self.rstate = np.random.RandomState()
|
||||
if random_state_seed is None:
|
||||
self.rstate = np.random.RandomState()
|
||||
else:
|
||||
self.rstate = np.random.RandomState(random_state_seed)
|
||||
|
||||
super(HyperOptSearch, self).__init__(**kwargs)
|
||||
|
||||
@@ -183,3 +206,14 @@ class HyperOptSearch(SuggestionAlgorithm):
|
||||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_trial_mapping)
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
trials_object = (self._hpopt_trials, self.rstate.get_state())
|
||||
with open(checkpoint_dir, "wb") as output:
|
||||
pickle.dump(trials_object, output)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as input:
|
||||
trials_object = pickle.load(input)
|
||||
self._hpopt_trials = trials_object[0]
|
||||
self.rstate.set_state(trials_object[1])
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from hyperopt import hp
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -13,6 +14,7 @@ from ray import tune
|
||||
from ray.tests.utils import recursive_fnmatch
|
||||
from ray.tune.util import validate_save_restore
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
|
||||
|
||||
class TuneRestoreTest(unittest.TestCase):
|
||||
@@ -112,5 +114,60 @@ class AutoInitTest(unittest.TestCase):
|
||||
_register_all()
|
||||
|
||||
|
||||
class HyperoptWarmStartTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(local_mode=True)
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
ray.shutdown()
|
||||
_register_all()
|
||||
|
||||
def set_basic_conf(self):
|
||||
space = {
|
||||
"x": hp.uniform("x", 0, 10),
|
||||
"y": hp.uniform("y", -10, 10),
|
||||
"z": hp.uniform("z", -10, 0)
|
||||
}
|
||||
|
||||
def cost(space, reporter):
|
||||
loss = space["x"]**2 + space["y"]**2 + space["z"]**2
|
||||
reporter(loss=loss)
|
||||
|
||||
search_alg = HyperOptSearch(
|
||||
space,
|
||||
max_concurrent=1,
|
||||
metric="loss",
|
||||
mode="min",
|
||||
random_state_seed=5)
|
||||
return search_alg, cost
|
||||
|
||||
def run_exp_1(self):
|
||||
search_alg, cost = self.set_basic_conf()
|
||||
results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg)
|
||||
self.log_dir = os.path.join(self.tmpdir, "trials_algo1.pkl")
|
||||
search_alg.save(self.log_dir)
|
||||
return results_exp_1
|
||||
|
||||
def run_exp_2(self):
|
||||
search_alg2, cost = self.set_basic_conf()
|
||||
search_alg2.restore(self.log_dir)
|
||||
return tune.run(cost, num_samples=15, search_alg=search_alg2)
|
||||
|
||||
def run_exp_3(self):
|
||||
search_alg3, cost = self.set_basic_conf()
|
||||
return tune.run(cost, num_samples=30, search_alg=search_alg3)
|
||||
|
||||
def testHyperoptWarmStart(self):
|
||||
results_exp_1 = self.run_exp_1()
|
||||
results_exp_2 = self.run_exp_2()
|
||||
results_exp_3 = self.run_exp_3()
|
||||
trials_1_config = [trial.config for trial in results_exp_1.trials]
|
||||
trials_2_config = [trial.config for trial in results_exp_2.trials]
|
||||
trials_3_config = [trial.config for trial in results_exp_3.trials]
|
||||
self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user