diff --git a/doc/source/tune/api_docs/schedulers.rst b/doc/source/tune/api_docs/schedulers.rst index 811da9824..f904a0b82 100644 --- a/doc/source/tune/api_docs/schedulers.rst +++ b/doc/source/tune/api_docs/schedulers.rst @@ -196,3 +196,9 @@ TrialScheduler .. autoclass:: ray.tune.schedulers.TrialScheduler :members: + +Shim Instantiation (tune.create_scheduler) +------------------------------------------ +There is also a shim function that constructs the scheduler based on the provided string. This can be useful if the scheduler you want to use changes often (e.g., specifying the scheduler via a CLI option or config file). + +.. automethod:: ray.tune.create_scheduler diff --git a/doc/source/tune/api_docs/suggestion.rst b/doc/source/tune/api_docs/suggestion.rst index 9df6d4223..ba05429c3 100644 --- a/doc/source/tune/api_docs/suggestion.rst +++ b/doc/source/tune/api_docs/suggestion.rst @@ -79,6 +79,7 @@ Tune also provides helpful utilities to use with Search Algorithms: * :ref:`repeater`: Support for running each *sampled hyperparameter* with multiple random seeds. * :ref:`limiter`: Limits the amount of concurrent trials when running optimization. + * :ref:`shim`: Allows creation of the search algorithm object given a string. Saving and Restoring -------------------- @@ -268,3 +269,11 @@ If you are interested in implementing or contributing a new Search Algorithm, pr :members: :private-members: :show-inheritance: + +.. _shim: + +Shim Instantiation (tune.create_searcher) +----------------------------------------- +There is also a shim function that constructs the search algorithm based on the provided string. This can be useful if the search algorithm you want to use changes often (e.g., specifying the search algorithm via a CLI option or config file). + +.. automethod:: ray.tune.create_searcher diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index b50d96c36..960be824d 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -15,6 +15,8 @@ from ray.tune.progress_reporter import (ProgressReporter, CLIReporter, from ray.tune.sample import (function, sample_from, uniform, quniform, choice, randint, qrandint, randn, qrandn, loguniform, qloguniform) +from ray.tune.suggest import create_searcher +from ray.tune.schedulers import create_scheduler __all__ = [ "Trainable", "DurableTrainable", "TuneError", "grid_search", @@ -24,5 +26,5 @@ __all__ = [ "loguniform", "qloguniform", "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir", - "save_checkpoint", "checkpoint_dir" + "save_checkpoint", "checkpoint_dir", "create_searcher", "create_scheduler" ] diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 70554789d..54b88ca9e 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -7,6 +7,69 @@ from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule from ray.tune.schedulers.pbt import (PopulationBasedTraining, PopulationBasedTrainingReplay) + +def create_scheduler( + scheduler, + metric="episode_reward_mean", + mode="max", + **kwargs, +): + """Instantiate a scheduler based on the given string. + + This is useful for swapping between different schedulers. + + Args: + scheduler (str): The scheduler to use. + metric (str): The training result objective value attribute. Stopping + procedures will use this attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + **kwargs: Additional parameters. + These keyword arguments will be passed to the initialization + function of the chosen class. + Returns: + ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler. + Example: + >>> scheduler = tune.create_scheduler('pbt') + """ + + def _import_async_hyperband_scheduler(): + from ray.tune.schedulers import AsyncHyperBandScheduler + return AsyncHyperBandScheduler + + def _import_median_stopping_rule_scheduler(): + from ray.tune.schedulers import MedianStoppingRule + return MedianStoppingRule + + def _import_hyperband_scheduler(): + from ray.tune.schedulers import HyperBandScheduler + return HyperBandScheduler + + def _import_hb_bohb_scheduler(): + from ray.tune.schedulers import HyperBandForBOHB + return HyperBandForBOHB + + def _import_pbt_search(): + from ray.tune.schedulers import PopulationBasedTraining + return PopulationBasedTraining + + SCHEDULER_IMPORT = { + "async_hyperband": _import_async_hyperband_scheduler, + "median_stopping_rule": _import_median_stopping_rule_scheduler, + "hyperband": _import_hyperband_scheduler, + "hb_bohb": _import_hb_bohb_scheduler, + "pbt": _import_pbt_search, + } + scheduler = scheduler.lower() + if scheduler not in SCHEDULER_IMPORT: + raise ValueError( + f"Search alg must be one of {list(SCHEDULER_IMPORT)}. " + f"Got: {scheduler}") + + SchedulerClass = SCHEDULER_IMPORT[scheduler]() + return SchedulerClass(metric=metric, mode=mode, **kwargs) + + __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index 060edc4b3..a9b5582a9 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -5,6 +5,94 @@ from ray.tune.suggest.search_generator import SearchGenerator from ray.tune.suggest.variant_generator import grid_search from ray.tune.suggest.repeater import Repeater + +def create_searcher( + search_alg, + metric="episode_reward_mean", + mode="max", + **kwargs, +): + """Instantiate a search algorithm based on the given string. + + This is useful for swapping between different search algorithms. + + Args: + search_alg (str): The search algorithm to use. + metric (str): The training result objective value attribute. Stopping + procedures will use this attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + **kwargs: Additional parameters. + These keyword arguments will be passed to the initialization + function of the chosen class. + Returns: + ray.tune.suggest.Searcher: The search algorithm. + Example: + >>> search_alg = tune.create_searcher('ax') + """ + + def _import_ax_search(): + from ray.tune.suggest.ax import AxSearch + return AxSearch + + def _import_dragonfly_search(): + from ray.tune.suggest.dragonfly import DragonflySearch + return DragonflySearch + + def _import_skopt_search(): + from ray.tune.suggest.skopt import SkOptSearch + return SkOptSearch + + def _import_hyperopt_search(): + from ray.tune.suggest.hyperopt import HyperOptSearch + return HyperOptSearch + + def _import_bayesopt_search(): + from ray.tune.suggest.bayesopt import BayesOptSearch + return BayesOptSearch + + def _import_bohb_search(): + from ray.tune.suggest.bohb import TuneBOHB + return TuneBOHB + + def _import_nevergrad_search(): + from ray.tune.suggest.nevergrad import NevergradSearch + return NevergradSearch + + def _import_optuna_search(): + from ray.tune.suggest.optuna import OptunaSearch + return OptunaSearch + + def _import_zoopt_search(): + from ray.tune.suggest.zoopt import ZOOptSearch + return ZOOptSearch + + def _import_sigopt_search(): + from ray.tune.suggest.sigopt import SigOptSearch + return SigOptSearch + + SEARCH_ALG_IMPORT = { + "ax": _import_ax_search, + "dragonfly": _import_dragonfly_search, + "skopt": _import_skopt_search, + "hyperopt": _import_hyperopt_search, + "bayesopt": _import_bayesopt_search, + "bohb": _import_bohb_search, + "nevergrad": _import_nevergrad_search, + "optuna": _import_optuna_search, + "zoopt": _import_zoopt_search, + "sigopt": _import_sigopt_search, + } + search_alg = search_alg.lower() + if search_alg not in SEARCH_ALG_IMPORT: + raise ValueError( + f"Search alg must be one of {list(SEARCH_ALG_IMPORT)}. " + f"Got: {search_alg}") + + SearcherClass = SEARCH_ALG_IMPORT[search_alg]() + return SearcherClass(metric=metric, mode=mode, **kwargs) + + __all__ = [ "SearchAlgorithm", "Searcher", "BasicVariantGenerator", "SearchGenerator", "grid_search", "Repeater", "ConcurrencyLimiter" diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 5c4322658..d25536e34 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -14,7 +14,8 @@ from ray import tune from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper, EarlyStopping) from ray.tune import register_env, register_trainable, run_experiments -from ray.tune.schedulers import TrialScheduler, FIFOScheduler +from ray.tune.schedulers import (TrialScheduler, FIFOScheduler, + AsyncHyperBandScheduler) from ray.tune.trial import Trial from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID, EPISODES_TOTAL, TRAINING_ITERATION, @@ -24,6 +25,8 @@ from ray.tune.logger import Logger from ray.tune.experiment import Experiment from ray.tune.resources import Resources from ray.tune.suggest import grid_search +from ray.tune.suggest.hyperopt import HyperOptSearch +from ray.tune.suggest.ax import AxSearch from ray.tune.suggest._mock import _MockSuggestionAlgorithm from ray.tune.utils import (flatten_dict, get_pinned_object, pin_in_object_store) @@ -1105,6 +1108,30 @@ class TrainableFunctionApiTest(unittest.TestCase): self.assertIn("LOG_STDERR", content) +class ShimCreationTest(unittest.TestCase): + def testCreateScheduler(self): + kwargs = {"metric": "metric_foo", "mode": "min"} + + scheduler = "async_hyperband" + shim_scheduler = tune.create_scheduler(scheduler, **kwargs) + real_scheduler = AsyncHyperBandScheduler(**kwargs) + assert type(shim_scheduler) is type(real_scheduler) + + def testCreateSearcher(self): + kwargs = {"metric": "metric_foo", "mode": "min"} + + searcher_ax = "ax" + shim_searcher_ax = tune.create_searcher(searcher_ax, **kwargs) + real_searcher_ax = AxSearch(space=[], **kwargs) + assert type(shim_searcher_ax) is type(real_searcher_ax) + + searcher_hyperopt = "hyperopt" + shim_searcher_hyperopt = tune.create_searcher(searcher_hyperopt, + **kwargs) + real_searcher_hyperopt = HyperOptSearch({}, **kwargs) + assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt) + + if __name__ == "__main__": import pytest import sys