[tune] Fix up Ax Search and Examples (#4851)

* update Ax for cleaner API

* docs update
This commit is contained in:
Richard Liaw
2019-05-27 13:23:17 -07:00
committed by GitHub
parent 7a78e1e320
commit 574e1c7695
3 changed files with 24 additions and 29 deletions
+5 -2
View File
@@ -51,11 +51,13 @@ def easy_objective(config, reporter):
if __name__ == "__main__":
import argparse
from ax.service.ax_client import AxClient
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
config = {
@@ -101,13 +103,14 @@ if __name__ == "__main__":
"bounds": [0.0, 1.0],
},
]
algo = AxSearch(
client = AxClient(enforce_sequential_optimization=False)
client.create_experiment(
parameters=parameters,
objective_name="hartmann6",
max_concurrent=4,
minimize=True, # Optional, defaults to False.
parameter_constraints=["x1 + x2 <= 2.0"], # Optional.
outcome_constraints=["l2norm <= 1.25"], # Optional.
)
algo = AxSearch(client, max_concurrent=4)
scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6")
run(easy_objective, name="ax", search_alg=algo, **config)
+16 -26
View File
@@ -6,16 +6,19 @@ try:
import ax
except ImportError:
ax = None
import logging
from ray.tune.suggest.suggestion import SuggestionAlgorithm
logger = logging.getLogger(__name__)
class AxSearch(SuggestionAlgorithm):
"""A wrapper around Ax to provide trial suggestions.
Requires Ax to be installed.
Ax is an open source tool from Facebook for configuring and
optimizing experiments. More information can be found in https://ax.dev/.
Requires Ax to be installed. Ax is an open source tool from
Facebook for configuring and optimizing experiments. More information
can be found in https://ax.dev/.
Parameters:
parameters (list[dict]): Parameters in the experiment search space.
@@ -48,40 +51,27 @@ class AxSearch(SuggestionAlgorithm):
>>> objective_name="hartmann6", max_concurrent=4)
"""
def __init__(self,
parameters,
objective_name,
max_concurrent=10,
minimize=False,
parameter_constraints=None,
outcome_constraints=None,
**kwargs):
def __init__(self, ax_client, max_concurrent=10, **kwargs):
assert ax is not None, "Ax must be installed!"
from ax.service import ax_client
assert type(max_concurrent) is int and max_concurrent > 0
self._ax = ax_client.AxClient(enforce_sequential_optimization=False)
self._ax.create_experiment(
name="ax",
parameters=parameters,
objective_name=objective_name,
minimize=minimize,
parameter_constraints=parameter_constraints or [],
outcome_constraints=outcome_constraints or [],
)
self._ax = ax_client
exp = self._ax.experiment
self._objective_name = exp.optimization_config.objective.metric.name
if self._ax._enforce_sequential_optimization:
logger.warning("Detected sequential enforcement. Setting max "
"concurrency to 1.")
max_concurrent = 1
self._max_concurrent = max_concurrent
self._parameters = [d["name"] for d in parameters]
self._objective_name = objective_name
self._parameters = list(exp.parameters)
self._live_index_mapping = {}
super(AxSearch, self).__init__(**kwargs)
def _suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
parameters, trial_index = self._ax.get_next_trial()
suggested_config = list(parameters.values())
self._live_index_mapping[trial_id] = trial_index
return dict(zip(self._parameters, suggested_config))
return parameters
def on_trial_result(self, trial_id, result):
pass