mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 04:59:12 +08:00
[tune] Add Ax to Tune (#4731)
This commit is contained in:
committed by
Richard Liaw
parent
0421cba4e8
commit
28d381373d
@@ -0,0 +1,113 @@
|
||||
"""This test checks that AxSearch is functional.
|
||||
|
||||
It also checks that it is usable with a separate scheduler.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import run
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
|
||||
|
||||
def hartmann6(x):
|
||||
alpha = np.array([1.0, 1.2, 3.0, 3.2])
|
||||
A = np.array([
|
||||
[10, 3, 17, 3.5, 1.7, 8],
|
||||
[0.05, 10, 17, 0.1, 8, 14],
|
||||
[3, 3.5, 1.7, 10, 17, 8],
|
||||
[17, 8, 0.05, 10, 0.1, 14],
|
||||
])
|
||||
P = 10**(-4) * np.array([
|
||||
[1312, 1696, 5569, 124, 8283, 5886],
|
||||
[2329, 4135, 8307, 3736, 1004, 9991],
|
||||
[2348, 1451, 3522, 2883, 3047, 6650],
|
||||
[4047, 8828, 8732, 5743, 1091, 381],
|
||||
])
|
||||
y = 0.0
|
||||
for j, alpha_j in enumerate(alpha):
|
||||
t = 0
|
||||
for k in range(6):
|
||||
t += A[j, k] * ((x[k] - P[j, k])**2)
|
||||
y -= alpha_j * np.exp(-t)
|
||||
return y
|
||||
|
||||
|
||||
def easy_objective(config, reporter):
|
||||
import time
|
||||
time.sleep(0.2)
|
||||
for i in range(config["iterations"]):
|
||||
x = np.array([config.get(f"x{i+1}") for i in range(6)])
|
||||
reporter(
|
||||
timesteps_total=i,
|
||||
hartmann6=hartmann6(x),
|
||||
l2norm=np.sqrt((x**2).sum()))
|
||||
time.sleep(0.02)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
config = {
|
||||
"num_samples": 10 if args.smoke_test else 50,
|
||||
"config": {
|
||||
"iterations": 100,
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 100
|
||||
}
|
||||
}
|
||||
parameters = [
|
||||
{
|
||||
"name": "x1",
|
||||
"type": "range",
|
||||
"bounds": [0.0, 1.0],
|
||||
"value_type": "float", # Optional, defaults to "bounds".
|
||||
"log_scale": False, # Optional, defaults to False.
|
||||
},
|
||||
{
|
||||
"name": "x2",
|
||||
"type": "range",
|
||||
"bounds": [0.0, 1.0],
|
||||
},
|
||||
{
|
||||
"name": "x3",
|
||||
"type": "range",
|
||||
"bounds": [0.0, 1.0],
|
||||
},
|
||||
{
|
||||
"name": "x4",
|
||||
"type": "range",
|
||||
"bounds": [0.0, 1.0],
|
||||
},
|
||||
{
|
||||
"name": "x5",
|
||||
"type": "range",
|
||||
"bounds": [0.0, 1.0],
|
||||
},
|
||||
{
|
||||
"name": "x6",
|
||||
"type": "range",
|
||||
"bounds": [0.0, 1.0],
|
||||
},
|
||||
]
|
||||
algo = AxSearch(
|
||||
parameters=parameters,
|
||||
objective_name="hartmann6",
|
||||
max_concurrent=4,
|
||||
minimize=True, # Optional, defaults to False.
|
||||
parameter_constraints=["x1 + x2 <= 2.0"], # Optional.
|
||||
outcome_constraints=["l2norm <= 1.25"], # Optional.
|
||||
)
|
||||
scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6")
|
||||
run(easy_objective, name="ax", search_alg=algo, **config)
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
try:
|
||||
import ax
|
||||
except ImportError:
|
||||
ax = None
|
||||
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
|
||||
class AxSearch(SuggestionAlgorithm):
|
||||
"""A wrapper around Ax to provide trial suggestions.
|
||||
|
||||
Requires Ax to be installed.
|
||||
Ax is an open source tool from Facebook for configuring and
|
||||
optimizing experiments. More information can be found in https://ax.dev/.
|
||||
|
||||
Parameters:
|
||||
parameters (list[dict]): Parameters in the experiment search space.
|
||||
Required elements in the dictionaries are: "name" (name of
|
||||
this parameter, string), "type" (type of the parameter: "range",
|
||||
"fixed", or "choice", string), "bounds" for range parameters
|
||||
(list of two values, lower bound first), "values" for choice
|
||||
parameters (list of values), and "value" for fixed parameters
|
||||
(single value).
|
||||
objective_name (str): Name of the metric used as objective in this
|
||||
experiment. This metric must be present in `raw_data` argument
|
||||
to `log_data`. This metric must also be present in the dict
|
||||
reported/returned by the Trainable.
|
||||
max_concurrent (int): Number of maximum concurrent trials. Defaults
|
||||
to 10.
|
||||
minimize (bool): Whether this experiment represents a minimization
|
||||
problem. Defaults to False.
|
||||
parameter_constraints (list[str]): Parameter constraints, such as
|
||||
"x3 >= x4" or "x3 + x4 >= 2".
|
||||
outcome_constraints (list[str]): Outcome constraints of form
|
||||
"metric_name >= bound", like "m1 <= 3."
|
||||
|
||||
|
||||
Example:
|
||||
>>> parameters = [
|
||||
>>> {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
|
||||
>>> {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
|
||||
>>> ]
|
||||
>>> algo = AxSearch(parameters=parameters,
|
||||
>>> objective_name="hartmann6", max_concurrent=4)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
parameters,
|
||||
objective_name,
|
||||
max_concurrent=10,
|
||||
minimize=False,
|
||||
parameter_constraints=None,
|
||||
outcome_constraints=None,
|
||||
**kwargs):
|
||||
assert ax is not None, "Ax must be installed!"
|
||||
from ax.service import ax_client
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
self._ax = ax_client.AxClient(enforce_sequential_optimization=False)
|
||||
self._ax.create_experiment(
|
||||
name="ax",
|
||||
parameters=parameters,
|
||||
objective_name=objective_name,
|
||||
minimize=minimize,
|
||||
parameter_constraints=parameter_constraints or [],
|
||||
outcome_constraints=outcome_constraints or [],
|
||||
)
|
||||
self._max_concurrent = max_concurrent
|
||||
self._parameters = [d["name"] for d in parameters]
|
||||
self._objective_name = objective_name
|
||||
self._live_index_mapping = {}
|
||||
|
||||
super(AxSearch, self).__init__(**kwargs)
|
||||
|
||||
def _suggest(self, trial_id):
|
||||
if self._num_live_trials() >= self._max_concurrent:
|
||||
return None
|
||||
parameters, trial_index = self._ax.get_next_trial()
|
||||
suggested_config = list(parameters.values())
|
||||
self._live_index_mapping[trial_id] = trial_index
|
||||
return dict(zip(self._parameters, suggested_config))
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
pass
|
||||
|
||||
def on_trial_complete(self,
|
||||
trial_id,
|
||||
result=None,
|
||||
error=False,
|
||||
early_terminated=False):
|
||||
"""Pass data back to Ax.
|
||||
|
||||
Data of form key value dictionary of metric names and values.
|
||||
"""
|
||||
ax_trial_index = self._live_index_mapping.pop(trial_id)
|
||||
if result:
|
||||
metric_dict = {
|
||||
self._objective_name: (result[self._objective_name], 0.0)
|
||||
}
|
||||
outcome_names = [
|
||||
oc.metric.name for oc in
|
||||
self._ax.experiment.optimization_config.outcome_constraints
|
||||
]
|
||||
metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
|
||||
self._ax.complete_trial(
|
||||
trial_index=ax_trial_index, raw_data=metric_dict)
|
||||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_index_mapping)
|
||||
Reference in New Issue
Block a user