diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index f129ee129..d3424dafb 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -304,7 +304,7 @@ class HyperOptSearch(Searcher): self.set_state(trials_object) @staticmethod - def convert_search_space(spec: Dict) -> Dict: + def convert_search_space(spec: Dict, prefix: str = "") -> Dict: spec = copy.deepcopy(spec) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -361,7 +361,17 @@ class HyperOptSearch(Searcher): return hpo.hp.randint(par, domain.upper) elif isinstance(domain, Categorical): if isinstance(sampler, Uniform): - return hpo.hp.choice(par, domain.categories) + return hpo.hp.choice(par, [ + HyperOptSearch.convert_search_space( + category, prefix=par) + if isinstance(category, dict) else + HyperOptSearch.convert_search_space( + dict(enumerate(category)), prefix=f"{par}/{i}") + if isinstance(category, list) else resolve_value( + f"{par}/{i}", category) + if isinstance(category, Domain) else category + for i, category in enumerate(domain.categories) + ]) raise ValueError("HyperOpt does not support parameters of type " "`{}` with samplers of type `{}`".format( @@ -369,7 +379,8 @@ class HyperOptSearch(Searcher): type(domain.sampler).__name__)) for path, domain in domain_vars: - par = "/".join(path) + par = "/".join( + [str(p) for p in ((prefix, ) + path if prefix else path)]) value = resolve_value(par, domain) assign_value(spec, path, value) diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index ca61627fd..3cae66ed5 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -408,6 +408,52 @@ class SearchSpaceTest(unittest.TestCase): trial = analysis.trials[0] assert trial.config["a"] in [2, 3, 4] + def testConvertHyperOptNested(self): + from ray.tune.suggest.hyperopt import HyperOptSearch + + config = { + "a": 1, + "dict_nested": tune.sample.Categorical([{ + "a": tune.sample.Categorical(["M", "N"]), + "b": tune.sample.Categorical(["O", "P"]) + }]).uniform(), + "list_nested": tune.sample.Categorical([ + [ + tune.sample.Categorical(["M", "N"]), + tune.sample.Categorical(["O", "P"]) + ], + [ + tune.sample.Categorical(["Q", "R"]), + tune.sample.Categorical(["S", "T"]) + ], + ]).uniform(), + "domain_nested": tune.sample.Categorical([ + tune.sample.Categorical(["M", "N"]), + tune.sample.Categorical(["O", "P"]) + ]).uniform(), + } + + searcher = HyperOptSearch(metric="a", mode="max") + analysis = tune.run( + _mock_objective, + config=config, + search_alg=searcher, + num_samples=10) + + for trial in analysis.trials: + config = trial.config + + self.assertIn(config["dict_nested"]["a"], ["M", "N"]) + self.assertIn(config["dict_nested"]["b"], ["O", "P"]) + + if config["list_nested"][0] in ["M", "N"]: + self.assertIn(config["list_nested"][1], ["O", "P"]) + else: + self.assertIn(config["list_nested"][0], ["Q", "R"]) + self.assertIn(config["list_nested"][1], ["S", "T"]) + + self.assertIn(config["domain_nested"], ["M", "N", "O", "P"]) + def testConvertNevergrad(self): from ray.tune.suggest.nevergrad import NevergradSearch import nevergrad as ng diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 45b8fca09..e77449a76 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -251,7 +251,7 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False): "Found delimiter `{}` in key when trying to " "flatten array. Please avoid using the delimiter " "in your specification.") - add[delimiter.join([key, subkey])] = v + add[delimiter.join([key, str(subkey)])] = v remove.append(key) dt.update(add) for k in remove: