mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
[tune] Fix up Ax Search and Examples (#4851)
* update Ax for cleaner API * docs update
This commit is contained in:
@@ -51,11 +51,13 @@ def easy_objective(config, reporter):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from ax.service.ax_client import AxClient
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
ray.init()
|
||||
|
||||
config = {
|
||||
@@ -101,13 +103,14 @@ if __name__ == "__main__":
|
||||
"bounds": [0.0, 1.0],
|
||||
},
|
||||
]
|
||||
algo = AxSearch(
|
||||
client = AxClient(enforce_sequential_optimization=False)
|
||||
client.create_experiment(
|
||||
parameters=parameters,
|
||||
objective_name="hartmann6",
|
||||
max_concurrent=4,
|
||||
minimize=True, # Optional, defaults to False.
|
||||
parameter_constraints=["x1 + x2 <= 2.0"], # Optional.
|
||||
outcome_constraints=["l2norm <= 1.25"], # Optional.
|
||||
)
|
||||
algo = AxSearch(client, max_concurrent=4)
|
||||
scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6")
|
||||
run(easy_objective, name="ax", search_alg=algo, **config)
|
||||
|
||||
@@ -6,16 +6,19 @@ try:
|
||||
import ax
|
||||
except ImportError:
|
||||
ax = None
|
||||
import logging
|
||||
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AxSearch(SuggestionAlgorithm):
|
||||
"""A wrapper around Ax to provide trial suggestions.
|
||||
|
||||
Requires Ax to be installed.
|
||||
Ax is an open source tool from Facebook for configuring and
|
||||
optimizing experiments. More information can be found in https://ax.dev/.
|
||||
Requires Ax to be installed. Ax is an open source tool from
|
||||
Facebook for configuring and optimizing experiments. More information
|
||||
can be found in https://ax.dev/.
|
||||
|
||||
Parameters:
|
||||
parameters (list[dict]): Parameters in the experiment search space.
|
||||
@@ -48,40 +51,27 @@ class AxSearch(SuggestionAlgorithm):
|
||||
>>> objective_name="hartmann6", max_concurrent=4)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
parameters,
|
||||
objective_name,
|
||||
max_concurrent=10,
|
||||
minimize=False,
|
||||
parameter_constraints=None,
|
||||
outcome_constraints=None,
|
||||
**kwargs):
|
||||
def __init__(self, ax_client, max_concurrent=10, **kwargs):
|
||||
assert ax is not None, "Ax must be installed!"
|
||||
from ax.service import ax_client
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
self._ax = ax_client.AxClient(enforce_sequential_optimization=False)
|
||||
self._ax.create_experiment(
|
||||
name="ax",
|
||||
parameters=parameters,
|
||||
objective_name=objective_name,
|
||||
minimize=minimize,
|
||||
parameter_constraints=parameter_constraints or [],
|
||||
outcome_constraints=outcome_constraints or [],
|
||||
)
|
||||
self._ax = ax_client
|
||||
exp = self._ax.experiment
|
||||
self._objective_name = exp.optimization_config.objective.metric.name
|
||||
if self._ax._enforce_sequential_optimization:
|
||||
logger.warning("Detected sequential enforcement. Setting max "
|
||||
"concurrency to 1.")
|
||||
max_concurrent = 1
|
||||
self._max_concurrent = max_concurrent
|
||||
self._parameters = [d["name"] for d in parameters]
|
||||
self._objective_name = objective_name
|
||||
self._parameters = list(exp.parameters)
|
||||
self._live_index_mapping = {}
|
||||
|
||||
super(AxSearch, self).__init__(**kwargs)
|
||||
|
||||
def _suggest(self, trial_id):
|
||||
if self._num_live_trials() >= self._max_concurrent:
|
||||
return None
|
||||
parameters, trial_index = self._ax.get_next_trial()
|
||||
suggested_config = list(parameters.values())
|
||||
self._live_index_mapping[trial_id] = trial_index
|
||||
return dict(zip(self._parameters, suggested_config))
|
||||
return parameters
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user