diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 07bee59d1..0a2bf491a 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -17,6 +17,7 @@ Currently, Tune offers the following search algorithms (and library integrations - `SigOpt `__ - `Nevergrad `__ - `Scikit-Optimize `__ +- `Ax `__ Variant Generation (Grid Search/Random Search) @@ -155,6 +156,29 @@ An example of this can be found in `skopt_example.py `__ to perform sequential model-based hyperparameter optimization. Ax is a platform for understanding, managing, deploying, and automating adaptive experiments. Ax provides an easy to use interface with BoTorch, a flexible, modern library for Bayesian optimization in PyTorch. Note that this class does not extend ``ray.tune.suggest.BasicVariantGenerator``, so you will not be able to use Tune's default variant generation/search space declaration when using AxSearch. + +In order to use this search algorithm, you will need to install PyTorch, Ax, and sqlalchemy. Instructions to install PyTorch locally can be found `here `__. You can install Ax and sqlalchemy via the following command: + +.. code-block:: bash + + $ pip install ax-platform sqlalchemy + +This algorithm requires specifying a search space and objective. You can use `AxSearch` like follows: + +.. code-block:: python + + tune.run(... , search_alg=AxSearch(parameter_dicts, ... )) + +An example of this can be found in `ax_example.py `__. + +.. autoclass:: ray.tune.suggest.ax.AxSearch + :show-inheritance: + :noindex: + Contributing a New Algorithm ---------------------------- diff --git a/python/ray/tune/examples/ax_example.py b/python/ray/tune/examples/ax_example.py new file mode 100644 index 000000000..07bb7f79a --- /dev/null +++ b/python/ray/tune/examples/ax_example.py @@ -0,0 +1,113 @@ +"""This test checks that AxSearch is functional. + +It also checks that it is usable with a separate scheduler. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import ray +from ray.tune import run +from ray.tune.schedulers import AsyncHyperBandScheduler +from ray.tune.suggest.ax import AxSearch + + +def hartmann6(x): + alpha = np.array([1.0, 1.2, 3.0, 3.2]) + A = np.array([ + [10, 3, 17, 3.5, 1.7, 8], + [0.05, 10, 17, 0.1, 8, 14], + [3, 3.5, 1.7, 10, 17, 8], + [17, 8, 0.05, 10, 0.1, 14], + ]) + P = 10**(-4) * np.array([ + [1312, 1696, 5569, 124, 8283, 5886], + [2329, 4135, 8307, 3736, 1004, 9991], + [2348, 1451, 3522, 2883, 3047, 6650], + [4047, 8828, 8732, 5743, 1091, 381], + ]) + y = 0.0 + for j, alpha_j in enumerate(alpha): + t = 0 + for k in range(6): + t += A[j, k] * ((x[k] - P[j, k])**2) + y -= alpha_j * np.exp(-t) + return y + + +def easy_objective(config, reporter): + import time + time.sleep(0.2) + for i in range(config["iterations"]): + x = np.array([config.get(f"x{i+1}") for i in range(6)]) + reporter( + timesteps_total=i, + hartmann6=hartmann6(x), + l2norm=np.sqrt((x**2).sum())) + time.sleep(0.02) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init() + + config = { + "num_samples": 10 if args.smoke_test else 50, + "config": { + "iterations": 100, + }, + "stop": { + "timesteps_total": 100 + } + } + parameters = [ + { + "name": "x1", + "type": "range", + "bounds": [0.0, 1.0], + "value_type": "float", # Optional, defaults to "bounds". + "log_scale": False, # Optional, defaults to False. + }, + { + "name": "x2", + "type": "range", + "bounds": [0.0, 1.0], + }, + { + "name": "x3", + "type": "range", + "bounds": [0.0, 1.0], + }, + { + "name": "x4", + "type": "range", + "bounds": [0.0, 1.0], + }, + { + "name": "x5", + "type": "range", + "bounds": [0.0, 1.0], + }, + { + "name": "x6", + "type": "range", + "bounds": [0.0, 1.0], + }, + ] + algo = AxSearch( + parameters=parameters, + objective_name="hartmann6", + max_concurrent=4, + minimize=True, # Optional, defaults to False. + parameter_constraints=["x1 + x2 <= 2.0"], # Optional. + outcome_constraints=["l2norm <= 1.25"], # Optional. + ) + scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6") + run(easy_objective, name="ax", search_alg=algo, **config) diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py new file mode 100644 index 000000000..a48852e84 --- /dev/null +++ b/python/ray/tune/suggest/ax.py @@ -0,0 +1,112 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +try: + import ax +except ImportError: + ax = None + +from ray.tune.suggest.suggestion import SuggestionAlgorithm + + +class AxSearch(SuggestionAlgorithm): + """A wrapper around Ax to provide trial suggestions. + + Requires Ax to be installed. + Ax is an open source tool from Facebook for configuring and + optimizing experiments. More information can be found in https://ax.dev/. + + Parameters: + parameters (list[dict]): Parameters in the experiment search space. + Required elements in the dictionaries are: "name" (name of + this parameter, string), "type" (type of the parameter: "range", + "fixed", or "choice", string), "bounds" for range parameters + (list of two values, lower bound first), "values" for choice + parameters (list of values), and "value" for fixed parameters + (single value). + objective_name (str): Name of the metric used as objective in this + experiment. This metric must be present in `raw_data` argument + to `log_data`. This metric must also be present in the dict + reported/returned by the Trainable. + max_concurrent (int): Number of maximum concurrent trials. Defaults + to 10. + minimize (bool): Whether this experiment represents a minimization + problem. Defaults to False. + parameter_constraints (list[str]): Parameter constraints, such as + "x3 >= x4" or "x3 + x4 >= 2". + outcome_constraints (list[str]): Outcome constraints of form + "metric_name >= bound", like "m1 <= 3." + + + Example: + >>> parameters = [ + >>> {"name": "x1", "type": "range", "bounds": [0.0, 1.0]}, + >>> {"name": "x2", "type": "range", "bounds": [0.0, 1.0]}, + >>> ] + >>> algo = AxSearch(parameters=parameters, + >>> objective_name="hartmann6", max_concurrent=4) + """ + + def __init__(self, + parameters, + objective_name, + max_concurrent=10, + minimize=False, + parameter_constraints=None, + outcome_constraints=None, + **kwargs): + assert ax is not None, "Ax must be installed!" + from ax.service import ax_client + assert type(max_concurrent) is int and max_concurrent > 0 + self._ax = ax_client.AxClient(enforce_sequential_optimization=False) + self._ax.create_experiment( + name="ax", + parameters=parameters, + objective_name=objective_name, + minimize=minimize, + parameter_constraints=parameter_constraints or [], + outcome_constraints=outcome_constraints or [], + ) + self._max_concurrent = max_concurrent + self._parameters = [d["name"] for d in parameters] + self._objective_name = objective_name + self._live_index_mapping = {} + + super(AxSearch, self).__init__(**kwargs) + + def _suggest(self, trial_id): + if self._num_live_trials() >= self._max_concurrent: + return None + parameters, trial_index = self._ax.get_next_trial() + suggested_config = list(parameters.values()) + self._live_index_mapping[trial_id] = trial_index + return dict(zip(self._parameters, suggested_config)) + + def on_trial_result(self, trial_id, result): + pass + + def on_trial_complete(self, + trial_id, + result=None, + error=False, + early_terminated=False): + """Pass data back to Ax. + + Data of form key value dictionary of metric names and values. + """ + ax_trial_index = self._live_index_mapping.pop(trial_id) + if result: + metric_dict = { + self._objective_name: (result[self._objective_name], 0.0) + } + outcome_names = [ + oc.metric.name for oc in + self._ax.experiment.optimization_config.outcome_constraints + ] + metric_dict.update({on: (result[on], 0.0) for on in outcome_names}) + self._ax.complete_trial( + trial_index=ax_trial_index, raw_data=metric_dict) + + def _num_live_trials(self): + return len(self._live_index_mapping)