mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:53:14 +08:00
[tune] support hierarchical search spaces for hyperopt (#11431)
* support hierarchical search spaces for hyperopt * Reduce num samples * Fix prefix
This commit is contained in:
@@ -304,7 +304,7 @@ class HyperOptSearch(Searcher):
|
||||
self.set_state(trials_object)
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict) -> Dict:
|
||||
def convert_search_space(spec: Dict, prefix: str = "") -> Dict:
|
||||
spec = copy.deepcopy(spec)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -361,7 +361,17 @@ class HyperOptSearch(Searcher):
|
||||
return hpo.hp.randint(par, domain.upper)
|
||||
elif isinstance(domain, Categorical):
|
||||
if isinstance(sampler, Uniform):
|
||||
return hpo.hp.choice(par, domain.categories)
|
||||
return hpo.hp.choice(par, [
|
||||
HyperOptSearch.convert_search_space(
|
||||
category, prefix=par)
|
||||
if isinstance(category, dict) else
|
||||
HyperOptSearch.convert_search_space(
|
||||
dict(enumerate(category)), prefix=f"{par}/{i}")
|
||||
if isinstance(category, list) else resolve_value(
|
||||
f"{par}/{i}", category)
|
||||
if isinstance(category, Domain) else category
|
||||
for i, category in enumerate(domain.categories)
|
||||
])
|
||||
|
||||
raise ValueError("HyperOpt does not support parameters of type "
|
||||
"`{}` with samplers of type `{}`".format(
|
||||
@@ -369,7 +379,8 @@ class HyperOptSearch(Searcher):
|
||||
type(domain.sampler).__name__))
|
||||
|
||||
for path, domain in domain_vars:
|
||||
par = "/".join(path)
|
||||
par = "/".join(
|
||||
[str(p) for p in ((prefix, ) + path if prefix else path)])
|
||||
value = resolve_value(par, domain)
|
||||
assign_value(spec, path, value)
|
||||
|
||||
|
||||
@@ -408,6 +408,52 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
assert trial.config["a"] in [2, 3, 4]
|
||||
|
||||
def testConvertHyperOptNested(self):
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
|
||||
config = {
|
||||
"a": 1,
|
||||
"dict_nested": tune.sample.Categorical([{
|
||||
"a": tune.sample.Categorical(["M", "N"]),
|
||||
"b": tune.sample.Categorical(["O", "P"])
|
||||
}]).uniform(),
|
||||
"list_nested": tune.sample.Categorical([
|
||||
[
|
||||
tune.sample.Categorical(["M", "N"]),
|
||||
tune.sample.Categorical(["O", "P"])
|
||||
],
|
||||
[
|
||||
tune.sample.Categorical(["Q", "R"]),
|
||||
tune.sample.Categorical(["S", "T"])
|
||||
],
|
||||
]).uniform(),
|
||||
"domain_nested": tune.sample.Categorical([
|
||||
tune.sample.Categorical(["M", "N"]),
|
||||
tune.sample.Categorical(["O", "P"])
|
||||
]).uniform(),
|
||||
}
|
||||
|
||||
searcher = HyperOptSearch(metric="a", mode="max")
|
||||
analysis = tune.run(
|
||||
_mock_objective,
|
||||
config=config,
|
||||
search_alg=searcher,
|
||||
num_samples=10)
|
||||
|
||||
for trial in analysis.trials:
|
||||
config = trial.config
|
||||
|
||||
self.assertIn(config["dict_nested"]["a"], ["M", "N"])
|
||||
self.assertIn(config["dict_nested"]["b"], ["O", "P"])
|
||||
|
||||
if config["list_nested"][0] in ["M", "N"]:
|
||||
self.assertIn(config["list_nested"][1], ["O", "P"])
|
||||
else:
|
||||
self.assertIn(config["list_nested"][0], ["Q", "R"])
|
||||
self.assertIn(config["list_nested"][1], ["S", "T"])
|
||||
|
||||
self.assertIn(config["domain_nested"], ["M", "N", "O", "P"])
|
||||
|
||||
def testConvertNevergrad(self):
|
||||
from ray.tune.suggest.nevergrad import NevergradSearch
|
||||
import nevergrad as ng
|
||||
|
||||
@@ -251,7 +251,7 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
|
||||
"Found delimiter `{}` in key when trying to "
|
||||
"flatten array. Please avoid using the delimiter "
|
||||
"in your specification.")
|
||||
add[delimiter.join([key, subkey])] = v
|
||||
add[delimiter.join([key, str(subkey)])] = v
|
||||
remove.append(key)
|
||||
dt.update(add)
|
||||
for k in remove:
|
||||
|
||||
Reference in New Issue
Block a user