mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:54:34 +08:00
[tune] allow tune search spaces to be passed to search algorithms (#11503)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ax.service.ax_client import AxClient
|
||||
from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
|
||||
Quantized, Uniform
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
@@ -103,7 +104,7 @@ class AxSearch(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
space: Optional[List[Dict]] = None,
|
||||
space: Optional[Union[Dict, List[Dict]]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
parameter_constraints: Optional[List] = None,
|
||||
@@ -122,6 +123,15 @@ class AxSearch(Searcher):
|
||||
use_early_stopped_trials=use_early_stopped_trials)
|
||||
|
||||
self._ax = ax_client
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space)
|
||||
|
||||
self._space = space
|
||||
self._parameter_constraints = parameter_constraints
|
||||
self._outcome_constraints = outcome_constraints
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Dict, Optional, Tuple
|
||||
|
||||
from ray.tune import ExperimentAnalysis
|
||||
from ray.tune.sample import Domain, Float, Quantized
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
|
||||
@@ -186,6 +187,14 @@ class BayesOptSearch(Searcher):
|
||||
if analysis is not None:
|
||||
self.register_analysis(analysis)
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space, join=True)
|
||||
|
||||
self._space = space
|
||||
self._verbose = verbose
|
||||
self._random_state = random_state
|
||||
@@ -354,7 +363,7 @@ class BayesOptSearch(Searcher):
|
||||
self._config_counter) = pickle.load(f)
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict) -> Dict:
|
||||
def convert_search_space(spec: Dict, join: bool = False) -> Dict:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -387,4 +396,8 @@ class BayesOptSearch(Searcher):
|
||||
for path, domain in domain_vars
|
||||
}
|
||||
|
||||
if join:
|
||||
spec.update(bounds)
|
||||
bounds = spec
|
||||
|
||||
return bounds
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import ConfigSpace
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
@@ -11,6 +11,7 @@ from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Quantized, \
|
||||
Uniform
|
||||
from ray.tune.suggest import Searcher
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
@@ -93,7 +94,8 @@ class TuneBOHB(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
space: Optional[ConfigSpace.ConfigurationSpace] = None,
|
||||
space: Optional[Union[Dict,
|
||||
ConfigSpace.ConfigurationSpace]] = None,
|
||||
bohb_config: Optional[Dict] = None,
|
||||
max_concurrent: int = 10,
|
||||
metric: Optional[str] = None,
|
||||
@@ -109,6 +111,15 @@ class TuneBOHB(Searcher):
|
||||
self._metric = metric
|
||||
|
||||
self._bohb_config = bohb_config
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space)
|
||||
|
||||
self._space = space
|
||||
|
||||
super(TuneBOHB, self).__init__(metric=self._metric, mode=mode)
|
||||
|
||||
@@ -5,9 +5,10 @@ from __future__ import print_function
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ray.tune.sample import Domain, Float, Quantized
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import flatten_dict
|
||||
|
||||
@@ -53,7 +54,7 @@ class DragonflySearch(Searcher):
|
||||
domain (str): Optional domain. Should only be set if you don't pass
|
||||
an optimizer as the `optimizer` argument.
|
||||
Has to be one of [cartesian, euclidean].
|
||||
space (list): Search space. Should only be set if you don't pass
|
||||
space (list|dict): Search space. Should only be set if you don't pass
|
||||
an optimizer as the `optimizer` argument. Defines the search space
|
||||
and requires a `domain` to be set. Can be automatically converted
|
||||
from the `config` dict passed to `tune.run()`.
|
||||
@@ -131,7 +132,7 @@ class DragonflySearch(Searcher):
|
||||
def __init__(self,
|
||||
optimizer: Optional[BlackboxOptimiser] = None,
|
||||
domain: Optional[str] = None,
|
||||
space: Optional[List[Dict]] = None,
|
||||
space: Optional[Union[Dict, List[Dict]]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[List]] = None,
|
||||
@@ -148,6 +149,15 @@ class DragonflySearch(Searcher):
|
||||
|
||||
self._opt_arg = optimizer
|
||||
self._domain = domain
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space)
|
||||
|
||||
self._space = space
|
||||
self._points_to_evaluate = points_to_evaluate
|
||||
self._evaluated_rewards = evaluated_rewards
|
||||
|
||||
@@ -10,6 +10,7 @@ from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Normal, \
|
||||
Quantized, \
|
||||
Uniform
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
|
||||
|
||||
try:
|
||||
@@ -168,7 +169,13 @@ class HyperOptSearch(Searcher):
|
||||
self.rstate = np.random.RandomState(random_state_seed)
|
||||
|
||||
self.domain = None
|
||||
if space:
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space)
|
||||
self.domain = hpo.Domain(lambda spc: spc, space)
|
||||
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, Optional, Union
|
||||
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Quantized
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
@@ -93,7 +94,7 @@ class NevergradSearch(Searcher):
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Union[None, Optimizer, ConfiguredOptimizer] = None,
|
||||
space: Optional[Parameter] = None,
|
||||
space: Optional[Union[Dict, Parameter]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
@@ -109,6 +110,14 @@ class NevergradSearch(Searcher):
|
||||
self._opt_factory = None
|
||||
self._nevergrad_opt = None
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space)
|
||||
|
||||
if isinstance(optimizer, Optimizer):
|
||||
if space is not None or isinstance(space, list):
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Quantized, Uniform
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
@@ -103,7 +104,7 @@ class OptunaSearch(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
space: Optional[List[Tuple]] = None,
|
||||
space: Optional[Union[Dict, List[Tuple]]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
sampler: Optional[BaseSampler] = None):
|
||||
@@ -115,6 +116,14 @@ class OptunaSearch(Searcher):
|
||||
max_concurrent=None,
|
||||
use_early_stopped_trials=None)
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space)
|
||||
|
||||
self._space = space
|
||||
|
||||
self._study_name = "optuna" # Fixed study name for in-memory storage
|
||||
|
||||
@@ -3,6 +3,7 @@ import pickle
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
@@ -152,6 +153,14 @@ class SkOptSearch(Searcher):
|
||||
self._parameter_names = None
|
||||
self._parameter_ranges = None
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="space", cls=type(self)))
|
||||
space = self.convert_search_space(space, join=True)
|
||||
|
||||
self._space = space
|
||||
|
||||
if self._space:
|
||||
@@ -269,7 +278,7 @@ class SkOptSearch(Searcher):
|
||||
self._skopt_opt = trials_object[1]
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict) -> Dict:
|
||||
def convert_search_space(spec: Dict, join: bool = False) -> Dict:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -311,4 +320,8 @@ class SkOptSearch(Searcher):
|
||||
for path, domain in domain_vars
|
||||
}
|
||||
|
||||
if join:
|
||||
spec.update(space)
|
||||
space = spec
|
||||
|
||||
return space
|
||||
|
||||
@@ -8,6 +8,13 @@ from ray.util.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UNRESOLVED_SEARCH_SPACE = str(
|
||||
"You passed a `{par}` parameter to {cls} that contained unresolved search "
|
||||
"space definitions. {cls} should however be instantiated with fully "
|
||||
"configured search spaces only. To use Ray Tune's automatic search space "
|
||||
"conversion, pass the space definition as part of the `config` argument "
|
||||
"to `tune.run()` instead.")
|
||||
|
||||
|
||||
class Searcher:
|
||||
"""Abstract class for wrapping suggesting algorithms.
|
||||
|
||||
@@ -6,6 +6,7 @@ import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized, \
|
||||
Uniform
|
||||
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
from zoopt import ValueType
|
||||
@@ -140,6 +141,15 @@ class ZOOptSearch(Searcher):
|
||||
], "`algo` must be in ['asracos', 'sracos'] currently"
|
||||
|
||||
self._algo = _algo
|
||||
|
||||
if isinstance(dim_dict, dict) and dim_dict:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(dim_dict)
|
||||
if domain_vars or grid_vars:
|
||||
logger.warning(
|
||||
UNRESOLVED_SEARCH_SPACE.format(
|
||||
par="dim_dict", cls=type(self)))
|
||||
dim_dict = self.convert_search_space(dim_dict, join=True)
|
||||
|
||||
self._dim_dict = dim_dict
|
||||
self._budget = budget
|
||||
|
||||
@@ -243,12 +253,13 @@ class ZOOptSearch(Searcher):
|
||||
self.optimizer = trials_object
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict) -> Dict[str, Tuple]:
|
||||
def convert_search_space(spec: Dict,
|
||||
join: bool = False) -> Dict[str, Tuple]:
|
||||
spec = copy.deepcopy(spec)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
if not domain_vars and not grid_vars:
|
||||
return []
|
||||
return {}
|
||||
|
||||
if grid_vars:
|
||||
raise ValueError(
|
||||
@@ -287,9 +298,13 @@ class ZOOptSearch(Searcher):
|
||||
type(domain).__name__,
|
||||
type(domain.sampler).__name__))
|
||||
|
||||
spec = {
|
||||
conv_spec = {
|
||||
"/".join(path): resolve_value(domain)
|
||||
for path, domain in domain_vars
|
||||
}
|
||||
|
||||
return spec
|
||||
if join:
|
||||
spec.update(conv_spec)
|
||||
conv_spec = spec
|
||||
|
||||
return conv_spec
|
||||
|
||||
@@ -216,6 +216,12 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
assert trial.config["a"] in [2, 3, 4]
|
||||
|
||||
mixed_config = {"a": tune.uniform(5, 6), "b": tune.uniform(8, 9)}
|
||||
searcher = AxSearch(space=mixed_config, metric="a", mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertBayesOpt(self):
|
||||
from ray.tune.suggest.bayesopt import BayesOptSearch
|
||||
|
||||
@@ -258,6 +264,12 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
self.assertLess(trial.config["b"]["z"], 1e-2)
|
||||
|
||||
mixed_config = {"a": tune.uniform(5, 6), "b": (8., 9.)}
|
||||
searcher = BayesOptSearch(space=mixed_config, metric="a", mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertBOHB(self):
|
||||
from ray.tune.suggest.bohb import TuneBOHB
|
||||
import ConfigSpace
|
||||
@@ -302,6 +314,15 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
self.assertIn(trial.config["a"], [2, 3, 4])
|
||||
self.assertEqual(trial.config["b"]["y"], 4)
|
||||
|
||||
mixed_config = {
|
||||
"a": tune.uniform(5, 6),
|
||||
"b": tune.uniform(8, 9) # Cannot mix ConfigSpace and Dict
|
||||
}
|
||||
searcher = TuneBOHB(space=mixed_config, metric="a", mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertDragonfly(self):
|
||||
from ray.tune.suggest.dragonfly import DragonflySearch
|
||||
|
||||
@@ -365,6 +386,21 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
self.assertLess(trial.config["point"], 1e-2)
|
||||
|
||||
mixed_config = {
|
||||
"a": tune.uniform(5, 6),
|
||||
"b": tune.uniform(8, 9) # Cannot mix List and Dict
|
||||
}
|
||||
searcher = DragonflySearch(
|
||||
space=mixed_config,
|
||||
optimizer="bandit",
|
||||
domain="euclidean",
|
||||
metric="a",
|
||||
mode="max")
|
||||
config = searcher.suggest("0")
|
||||
|
||||
self.assertTrue(5 <= config["point"][0] <= 6)
|
||||
self.assertTrue(8 <= config["point"][1] <= 9)
|
||||
|
||||
def testConvertHyperOpt(self):
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from hyperopt import hp
|
||||
@@ -408,6 +444,12 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
assert trial.config["a"] in [2, 3, 4]
|
||||
|
||||
mixed_config = {"a": tune.uniform(5, 6), "b": hp.uniform("b", 8, 9)}
|
||||
searcher = HyperOptSearch(space=mixed_config, metric="a", mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertHyperOptNested(self):
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
|
||||
@@ -496,6 +538,19 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
assert trial.config["a"] in [2, 3, 4]
|
||||
|
||||
mixed_config = {
|
||||
"a": tune.uniform(5, 6),
|
||||
"b": tune.uniform(8, 9) # Cannot mix Nevergrad cfg and tune
|
||||
}
|
||||
searcher = NevergradSearch(
|
||||
space=mixed_config,
|
||||
optimizer=ng.optimizers.OnePlusOne,
|
||||
metric="a",
|
||||
mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertOptuna(self):
|
||||
from ray.tune.suggest.optuna import OptunaSearch, param
|
||||
from optuna.samplers import RandomSampler
|
||||
@@ -536,6 +591,15 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
assert trial.config["a"] in [2, 3, 4]
|
||||
|
||||
mixed_config = {
|
||||
"a": tune.uniform(5, 6),
|
||||
"b": tune.uniform(8, 9) # Cannot mix List and Dict
|
||||
}
|
||||
searcher = OptunaSearch(space=mixed_config, metric="a", mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertSkOpt(self):
|
||||
from ray.tune.suggest.skopt import SkOptSearch
|
||||
|
||||
@@ -571,6 +635,12 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
self.assertIn(trial.config["a"], [2, 3, 4])
|
||||
self.assertEqual(trial.config["b"]["y"], 4)
|
||||
|
||||
mixed_config = {"a": tune.uniform(5, 6), "b": (8, 9)}
|
||||
searcher = SkOptSearch(space=mixed_config, metric="a", mode="max")
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
def testConvertZOOpt(self):
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
from zoopt import ValueType
|
||||
@@ -627,6 +697,20 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
trial = analysis.trials[0]
|
||||
self.assertIn(trial.config["b"]["y"], [2, 4, 6, 8])
|
||||
|
||||
mixed_config = {
|
||||
"a": tune.uniform(5, 6),
|
||||
"b": (ValueType.CONTINUOUS, [8, 9], 1e-4)
|
||||
}
|
||||
searcher = ZOOptSearch(
|
||||
dim_dict=mixed_config,
|
||||
budget=5,
|
||||
metric="a",
|
||||
mode="max",
|
||||
**zoopt_search_config)
|
||||
config = searcher.suggest("0")
|
||||
self.assertTrue(5 <= config["a"] <= 6)
|
||||
self.assertTrue(8 <= config["b"] <= 9)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user