[tune] Add Ax to Tune (#4731)

This commit is contained in:
Adi Zimmerman
2019-05-08 15:54:29 -07:00
committed by Richard Liaw
parent 0421cba4e8
commit 28d381373d
3 changed files with 249 additions and 0 deletions
+113
View File
@@ -0,0 +1,113 @@
"""This test checks that AxSearch is functional.
It also checks that it is usable with a separate scheduler.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray
from ray.tune import run
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest.ax import AxSearch
def hartmann6(x):
alpha = np.array([1.0, 1.2, 3.0, 3.2])
A = np.array([
[10, 3, 17, 3.5, 1.7, 8],
[0.05, 10, 17, 0.1, 8, 14],
[3, 3.5, 1.7, 10, 17, 8],
[17, 8, 0.05, 10, 0.1, 14],
])
P = 10**(-4) * np.array([
[1312, 1696, 5569, 124, 8283, 5886],
[2329, 4135, 8307, 3736, 1004, 9991],
[2348, 1451, 3522, 2883, 3047, 6650],
[4047, 8828, 8732, 5743, 1091, 381],
])
y = 0.0
for j, alpha_j in enumerate(alpha):
t = 0
for k in range(6):
t += A[j, k] * ((x[k] - P[j, k])**2)
y -= alpha_j * np.exp(-t)
return y
def easy_objective(config, reporter):
import time
time.sleep(0.2)
for i in range(config["iterations"]):
x = np.array([config.get(f"x{i+1}") for i in range(6)])
reporter(
timesteps_total=i,
hartmann6=hartmann6(x),
l2norm=np.sqrt((x**2).sum()))
time.sleep(0.02)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
config = {
"num_samples": 10 if args.smoke_test else 50,
"config": {
"iterations": 100,
},
"stop": {
"timesteps_total": 100
}
}
parameters = [
{
"name": "x1",
"type": "range",
"bounds": [0.0, 1.0],
"value_type": "float", # Optional, defaults to "bounds".
"log_scale": False, # Optional, defaults to False.
},
{
"name": "x2",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x3",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x4",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x5",
"type": "range",
"bounds": [0.0, 1.0],
},
{
"name": "x6",
"type": "range",
"bounds": [0.0, 1.0],
},
]
algo = AxSearch(
parameters=parameters,
objective_name="hartmann6",
max_concurrent=4,
minimize=True, # Optional, defaults to False.
parameter_constraints=["x1 + x2 <= 2.0"], # Optional.
outcome_constraints=["l2norm <= 1.25"], # Optional.
)
scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6")
run(easy_objective, name="ax", search_alg=algo, **config)
+112
View File
@@ -0,0 +1,112 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
import ax
except ImportError:
ax = None
from ray.tune.suggest.suggestion import SuggestionAlgorithm
class AxSearch(SuggestionAlgorithm):
"""A wrapper around Ax to provide trial suggestions.
Requires Ax to be installed.
Ax is an open source tool from Facebook for configuring and
optimizing experiments. More information can be found in https://ax.dev/.
Parameters:
parameters (list[dict]): Parameters in the experiment search space.
Required elements in the dictionaries are: "name" (name of
this parameter, string), "type" (type of the parameter: "range",
"fixed", or "choice", string), "bounds" for range parameters
(list of two values, lower bound first), "values" for choice
parameters (list of values), and "value" for fixed parameters
(single value).
objective_name (str): Name of the metric used as objective in this
experiment. This metric must be present in `raw_data` argument
to `log_data`. This metric must also be present in the dict
reported/returned by the Trainable.
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
minimize (bool): Whether this experiment represents a minimization
problem. Defaults to False.
parameter_constraints (list[str]): Parameter constraints, such as
"x3 >= x4" or "x3 + x4 >= 2".
outcome_constraints (list[str]): Outcome constraints of form
"metric_name >= bound", like "m1 <= 3."
Example:
>>> parameters = [
>>> {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
>>> {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
>>> ]
>>> algo = AxSearch(parameters=parameters,
>>> objective_name="hartmann6", max_concurrent=4)
"""
def __init__(self,
parameters,
objective_name,
max_concurrent=10,
minimize=False,
parameter_constraints=None,
outcome_constraints=None,
**kwargs):
assert ax is not None, "Ax must be installed!"
from ax.service import ax_client
assert type(max_concurrent) is int and max_concurrent > 0
self._ax = ax_client.AxClient(enforce_sequential_optimization=False)
self._ax.create_experiment(
name="ax",
parameters=parameters,
objective_name=objective_name,
minimize=minimize,
parameter_constraints=parameter_constraints or [],
outcome_constraints=outcome_constraints or [],
)
self._max_concurrent = max_concurrent
self._parameters = [d["name"] for d in parameters]
self._objective_name = objective_name
self._live_index_mapping = {}
super(AxSearch, self).__init__(**kwargs)
def _suggest(self, trial_id):
if self._num_live_trials() >= self._max_concurrent:
return None
parameters, trial_index = self._ax.get_next_trial()
suggested_config = list(parameters.values())
self._live_index_mapping[trial_id] = trial_index
return dict(zip(self._parameters, suggested_config))
def on_trial_result(self, trial_id, result):
pass
def on_trial_complete(self,
trial_id,
result=None,
error=False,
early_terminated=False):
"""Pass data back to Ax.
Data of form key value dictionary of metric names and values.
"""
ax_trial_index = self._live_index_mapping.pop(trial_id)
if result:
metric_dict = {
self._objective_name: (result[self._objective_name], 0.0)
}
outcome_names = [
oc.metric.name for oc in
self._ax.experiment.optimization_config.outcome_constraints
]
metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
self._ax.complete_trial(
trial_index=ax_trial_index, raw_data=metric_dict)
def _num_live_trials(self):
return len(self._live_index_mapping)