mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 06:47:13 +08:00
[tune] implement shim instantiation (#10456)
* Create ray.tune.suggest.create.create_scheduler * Update __init__.py * Resolve conflict in __init__.py * Create ray.tune.schedulers.create.create_scheduler * Update __init__.py * Move create_scheduler to tune.schedulers.__init__ * Move create_searcher to tune.suggest.__init__ * Delete tune.suggest.create * Delete tune.schedulers.create * Update imports for shim functions in tune.__init__ * Remove shim from tune.suggest.__init__.__all__ * Remove shim from tune.schedulers.__init__.__all__ * Add ShimCreationTest * Move ShimCreationTest to test_api * Delete test_shim.py * Add docstring for ray.tune.create_scheduler * Add docstring to ray.tune.create_searcher * Fix typo in ray.tune.create_scheduler docstring * Fix lint errors in tune.schedulers.__init__ * Fix lint errors in tune.suggest.__init__ * Fix lint errors in tune.suggest.__init__ * Fix lint errors in tune.schedulers.__init__ * Fix imports in test_api * Fix lint errors in test_api * Fix kwargs in create_searcher * Fix kwargs in create_scheduler * Merge branch 'master' into shim-instantiation * Update use-case in docs in tune.create_scheduler * Update use-case in docs in tune.create_searcher * Remove duplicate pytest run from test_api * Add check to create_searcher Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Add check to create_scheduler * lint * Compare types of instances in test_api Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Add tune.create_searcher to docs * Fix doc build * Fix tests * Add tune.create_scheduler to docs * Fix tests * Fix lint errors * Update Ax search for master * Fix metric kwarg for Ax in test_api * Fix doc build * Fix HyperOptSearch import in test_api * Fix HyperOptSearch import in create_searcher Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -196,3 +196,9 @@ TrialScheduler
|
||||
|
||||
.. autoclass:: ray.tune.schedulers.TrialScheduler
|
||||
:members:
|
||||
|
||||
Shim Instantiation (tune.create_scheduler)
|
||||
------------------------------------------
|
||||
There is also a shim function that constructs the scheduler based on the provided string. This can be useful if the scheduler you want to use changes often (e.g., specifying the scheduler via a CLI option or config file).
|
||||
|
||||
.. automethod:: ray.tune.create_scheduler
|
||||
|
||||
@@ -79,6 +79,7 @@ Tune also provides helpful utilities to use with Search Algorithms:
|
||||
|
||||
* :ref:`repeater`: Support for running each *sampled hyperparameter* with multiple random seeds.
|
||||
* :ref:`limiter`: Limits the amount of concurrent trials when running optimization.
|
||||
* :ref:`shim`: Allows creation of the search algorithm object given a string.
|
||||
|
||||
Saving and Restoring
|
||||
--------------------
|
||||
@@ -268,3 +269,11 @@ If you are interested in implementing or contributing a new Search Algorithm, pr
|
||||
:members:
|
||||
:private-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. _shim:
|
||||
|
||||
Shim Instantiation (tune.create_searcher)
|
||||
-----------------------------------------
|
||||
There is also a shim function that constructs the search algorithm based on the provided string. This can be useful if the search algorithm you want to use changes often (e.g., specifying the search algorithm via a CLI option or config file).
|
||||
|
||||
.. automethod:: ray.tune.create_searcher
|
||||
|
||||
@@ -15,6 +15,8 @@ from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
|
||||
from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
|
||||
randint, qrandint, randn, qrandn, loguniform,
|
||||
qloguniform)
|
||||
from ray.tune.suggest import create_searcher
|
||||
from ray.tune.schedulers import create_scheduler
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "DurableTrainable", "TuneError", "grid_search",
|
||||
@@ -24,5 +26,5 @@ __all__ = [
|
||||
"loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
|
||||
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
|
||||
"get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
|
||||
"save_checkpoint", "checkpoint_dir"
|
||||
"save_checkpoint", "checkpoint_dir", "create_searcher", "create_scheduler"
|
||||
]
|
||||
|
||||
@@ -7,6 +7,69 @@ from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.schedulers.pbt import (PopulationBasedTraining,
|
||||
PopulationBasedTrainingReplay)
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
scheduler,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate a scheduler based on the given string.
|
||||
|
||||
This is useful for swapping between different schedulers.
|
||||
|
||||
Args:
|
||||
scheduler (str): The scheduler to use.
|
||||
metric (str): The training result objective value attribute. Stopping
|
||||
procedures will use this attribute.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
**kwargs: Additional parameters.
|
||||
These keyword arguments will be passed to the initialization
|
||||
function of the chosen class.
|
||||
Returns:
|
||||
ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.
|
||||
Example:
|
||||
>>> scheduler = tune.create_scheduler('pbt')
|
||||
"""
|
||||
|
||||
def _import_async_hyperband_scheduler():
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
return AsyncHyperBandScheduler
|
||||
|
||||
def _import_median_stopping_rule_scheduler():
|
||||
from ray.tune.schedulers import MedianStoppingRule
|
||||
return MedianStoppingRule
|
||||
|
||||
def _import_hyperband_scheduler():
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
return HyperBandScheduler
|
||||
|
||||
def _import_hb_bohb_scheduler():
|
||||
from ray.tune.schedulers import HyperBandForBOHB
|
||||
return HyperBandForBOHB
|
||||
|
||||
def _import_pbt_search():
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
return PopulationBasedTraining
|
||||
|
||||
SCHEDULER_IMPORT = {
|
||||
"async_hyperband": _import_async_hyperband_scheduler,
|
||||
"median_stopping_rule": _import_median_stopping_rule_scheduler,
|
||||
"hyperband": _import_hyperband_scheduler,
|
||||
"hb_bohb": _import_hb_bohb_scheduler,
|
||||
"pbt": _import_pbt_search,
|
||||
}
|
||||
scheduler = scheduler.lower()
|
||||
if scheduler not in SCHEDULER_IMPORT:
|
||||
raise ValueError(
|
||||
f"Search alg must be one of {list(SCHEDULER_IMPORT)}. "
|
||||
f"Got: {scheduler}")
|
||||
|
||||
SchedulerClass = SCHEDULER_IMPORT[scheduler]()
|
||||
return SchedulerClass(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
|
||||
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
|
||||
|
||||
@@ -5,6 +5,94 @@ from ray.tune.suggest.search_generator import SearchGenerator
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
from ray.tune.suggest.repeater import Repeater
|
||||
|
||||
|
||||
def create_searcher(
|
||||
search_alg,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate a search algorithm based on the given string.
|
||||
|
||||
This is useful for swapping between different search algorithms.
|
||||
|
||||
Args:
|
||||
search_alg (str): The search algorithm to use.
|
||||
metric (str): The training result objective value attribute. Stopping
|
||||
procedures will use this attribute.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
**kwargs: Additional parameters.
|
||||
These keyword arguments will be passed to the initialization
|
||||
function of the chosen class.
|
||||
Returns:
|
||||
ray.tune.suggest.Searcher: The search algorithm.
|
||||
Example:
|
||||
>>> search_alg = tune.create_searcher('ax')
|
||||
"""
|
||||
|
||||
def _import_ax_search():
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
return AxSearch
|
||||
|
||||
def _import_dragonfly_search():
|
||||
from ray.tune.suggest.dragonfly import DragonflySearch
|
||||
return DragonflySearch
|
||||
|
||||
def _import_skopt_search():
|
||||
from ray.tune.suggest.skopt import SkOptSearch
|
||||
return SkOptSearch
|
||||
|
||||
def _import_hyperopt_search():
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
return HyperOptSearch
|
||||
|
||||
def _import_bayesopt_search():
|
||||
from ray.tune.suggest.bayesopt import BayesOptSearch
|
||||
return BayesOptSearch
|
||||
|
||||
def _import_bohb_search():
|
||||
from ray.tune.suggest.bohb import TuneBOHB
|
||||
return TuneBOHB
|
||||
|
||||
def _import_nevergrad_search():
|
||||
from ray.tune.suggest.nevergrad import NevergradSearch
|
||||
return NevergradSearch
|
||||
|
||||
def _import_optuna_search():
|
||||
from ray.tune.suggest.optuna import OptunaSearch
|
||||
return OptunaSearch
|
||||
|
||||
def _import_zoopt_search():
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
return ZOOptSearch
|
||||
|
||||
def _import_sigopt_search():
|
||||
from ray.tune.suggest.sigopt import SigOptSearch
|
||||
return SigOptSearch
|
||||
|
||||
SEARCH_ALG_IMPORT = {
|
||||
"ax": _import_ax_search,
|
||||
"dragonfly": _import_dragonfly_search,
|
||||
"skopt": _import_skopt_search,
|
||||
"hyperopt": _import_hyperopt_search,
|
||||
"bayesopt": _import_bayesopt_search,
|
||||
"bohb": _import_bohb_search,
|
||||
"nevergrad": _import_nevergrad_search,
|
||||
"optuna": _import_optuna_search,
|
||||
"zoopt": _import_zoopt_search,
|
||||
"sigopt": _import_sigopt_search,
|
||||
}
|
||||
search_alg = search_alg.lower()
|
||||
if search_alg not in SEARCH_ALG_IMPORT:
|
||||
raise ValueError(
|
||||
f"Search alg must be one of {list(SEARCH_ALG_IMPORT)}. "
|
||||
f"Got: {search_alg}")
|
||||
|
||||
SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
|
||||
return SearcherClass(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm", "Searcher", "BasicVariantGenerator", "SearchGenerator",
|
||||
"grid_search", "Repeater", "ConcurrencyLimiter"
|
||||
|
||||
@@ -14,7 +14,8 @@ from ray import tune
|
||||
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
|
||||
EarlyStopping)
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
|
||||
AsyncHyperBandScheduler)
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
|
||||
EPISODES_TOTAL, TRAINING_ITERATION,
|
||||
@@ -24,6 +25,8 @@ from ray.tune.logger import Logger
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
||||
from ray.tune.utils import (flatten_dict, get_pinned_object,
|
||||
pin_in_object_store)
|
||||
@@ -1105,6 +1108,30 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertIn("LOG_STDERR", content)
|
||||
|
||||
|
||||
class ShimCreationTest(unittest.TestCase):
|
||||
def testCreateScheduler(self):
|
||||
kwargs = {"metric": "metric_foo", "mode": "min"}
|
||||
|
||||
scheduler = "async_hyperband"
|
||||
shim_scheduler = tune.create_scheduler(scheduler, **kwargs)
|
||||
real_scheduler = AsyncHyperBandScheduler(**kwargs)
|
||||
assert type(shim_scheduler) is type(real_scheduler)
|
||||
|
||||
def testCreateSearcher(self):
|
||||
kwargs = {"metric": "metric_foo", "mode": "min"}
|
||||
|
||||
searcher_ax = "ax"
|
||||
shim_searcher_ax = tune.create_searcher(searcher_ax, **kwargs)
|
||||
real_searcher_ax = AxSearch(space=[], **kwargs)
|
||||
assert type(shim_searcher_ax) is type(real_searcher_ax)
|
||||
|
||||
searcher_hyperopt = "hyperopt"
|
||||
shim_searcher_hyperopt = tune.create_searcher(searcher_hyperopt,
|
||||
**kwargs)
|
||||
real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
|
||||
assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
Reference in New Issue
Block a user