mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[tune] implement shim instantiation (#10456)
* Create ray.tune.suggest.create.create_scheduler * Update __init__.py * Resolve conflict in __init__.py * Create ray.tune.schedulers.create.create_scheduler * Update __init__.py * Move create_scheduler to tune.schedulers.__init__ * Move create_searcher to tune.suggest.__init__ * Delete tune.suggest.create * Delete tune.schedulers.create * Update imports for shim functions in tune.__init__ * Remove shim from tune.suggest.__init__.__all__ * Remove shim from tune.schedulers.__init__.__all__ * Add ShimCreationTest * Move ShimCreationTest to test_api * Delete test_shim.py * Add docstring for ray.tune.create_scheduler * Add docstring to ray.tune.create_searcher * Fix typo in ray.tune.create_scheduler docstring * Fix lint errors in tune.schedulers.__init__ * Fix lint errors in tune.suggest.__init__ * Fix lint errors in tune.suggest.__init__ * Fix lint errors in tune.schedulers.__init__ * Fix imports in test_api * Fix lint errors in test_api * Fix kwargs in create_searcher * Fix kwargs in create_scheduler * Merge branch 'master' into shim-instantiation * Update use-case in docs in tune.create_scheduler * Update use-case in docs in tune.create_searcher * Remove duplicate pytest run from test_api * Add check to create_searcher Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Add check to create_scheduler * lint * Compare types of instances in test_api Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Add tune.create_searcher to docs * Fix doc build * Fix tests * Add tune.create_scheduler to docs * Fix tests * Fix lint errors * Update Ax search for master * Fix metric kwarg for Ax in test_api * Fix doc build * Fix HyperOptSearch import in test_api * Fix HyperOptSearch import in create_searcher Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -15,6 +15,8 @@ from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
|
||||
from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
|
||||
randint, qrandint, randn, qrandn, loguniform,
|
||||
qloguniform)
|
||||
from ray.tune.suggest import create_searcher
|
||||
from ray.tune.schedulers import create_scheduler
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "DurableTrainable", "TuneError", "grid_search",
|
||||
@@ -24,5 +26,5 @@ __all__ = [
|
||||
"loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
|
||||
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
|
||||
"get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
|
||||
"save_checkpoint", "checkpoint_dir"
|
||||
"save_checkpoint", "checkpoint_dir", "create_searcher", "create_scheduler"
|
||||
]
|
||||
|
||||
@@ -7,6 +7,69 @@ from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.schedulers.pbt import (PopulationBasedTraining,
|
||||
PopulationBasedTrainingReplay)
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
scheduler,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate a scheduler based on the given string.
|
||||
|
||||
This is useful for swapping between different schedulers.
|
||||
|
||||
Args:
|
||||
scheduler (str): The scheduler to use.
|
||||
metric (str): The training result objective value attribute. Stopping
|
||||
procedures will use this attribute.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
**kwargs: Additional parameters.
|
||||
These keyword arguments will be passed to the initialization
|
||||
function of the chosen class.
|
||||
Returns:
|
||||
ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.
|
||||
Example:
|
||||
>>> scheduler = tune.create_scheduler('pbt')
|
||||
"""
|
||||
|
||||
def _import_async_hyperband_scheduler():
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
return AsyncHyperBandScheduler
|
||||
|
||||
def _import_median_stopping_rule_scheduler():
|
||||
from ray.tune.schedulers import MedianStoppingRule
|
||||
return MedianStoppingRule
|
||||
|
||||
def _import_hyperband_scheduler():
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
return HyperBandScheduler
|
||||
|
||||
def _import_hb_bohb_scheduler():
|
||||
from ray.tune.schedulers import HyperBandForBOHB
|
||||
return HyperBandForBOHB
|
||||
|
||||
def _import_pbt_search():
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
return PopulationBasedTraining
|
||||
|
||||
SCHEDULER_IMPORT = {
|
||||
"async_hyperband": _import_async_hyperband_scheduler,
|
||||
"median_stopping_rule": _import_median_stopping_rule_scheduler,
|
||||
"hyperband": _import_hyperband_scheduler,
|
||||
"hb_bohb": _import_hb_bohb_scheduler,
|
||||
"pbt": _import_pbt_search,
|
||||
}
|
||||
scheduler = scheduler.lower()
|
||||
if scheduler not in SCHEDULER_IMPORT:
|
||||
raise ValueError(
|
||||
f"Search alg must be one of {list(SCHEDULER_IMPORT)}. "
|
||||
f"Got: {scheduler}")
|
||||
|
||||
SchedulerClass = SCHEDULER_IMPORT[scheduler]()
|
||||
return SchedulerClass(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
|
||||
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
|
||||
|
||||
@@ -5,6 +5,94 @@ from ray.tune.suggest.search_generator import SearchGenerator
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
from ray.tune.suggest.repeater import Repeater
|
||||
|
||||
|
||||
def create_searcher(
|
||||
search_alg,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate a search algorithm based on the given string.
|
||||
|
||||
This is useful for swapping between different search algorithms.
|
||||
|
||||
Args:
|
||||
search_alg (str): The search algorithm to use.
|
||||
metric (str): The training result objective value attribute. Stopping
|
||||
procedures will use this attribute.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
**kwargs: Additional parameters.
|
||||
These keyword arguments will be passed to the initialization
|
||||
function of the chosen class.
|
||||
Returns:
|
||||
ray.tune.suggest.Searcher: The search algorithm.
|
||||
Example:
|
||||
>>> search_alg = tune.create_searcher('ax')
|
||||
"""
|
||||
|
||||
def _import_ax_search():
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
return AxSearch
|
||||
|
||||
def _import_dragonfly_search():
|
||||
from ray.tune.suggest.dragonfly import DragonflySearch
|
||||
return DragonflySearch
|
||||
|
||||
def _import_skopt_search():
|
||||
from ray.tune.suggest.skopt import SkOptSearch
|
||||
return SkOptSearch
|
||||
|
||||
def _import_hyperopt_search():
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
return HyperOptSearch
|
||||
|
||||
def _import_bayesopt_search():
|
||||
from ray.tune.suggest.bayesopt import BayesOptSearch
|
||||
return BayesOptSearch
|
||||
|
||||
def _import_bohb_search():
|
||||
from ray.tune.suggest.bohb import TuneBOHB
|
||||
return TuneBOHB
|
||||
|
||||
def _import_nevergrad_search():
|
||||
from ray.tune.suggest.nevergrad import NevergradSearch
|
||||
return NevergradSearch
|
||||
|
||||
def _import_optuna_search():
|
||||
from ray.tune.suggest.optuna import OptunaSearch
|
||||
return OptunaSearch
|
||||
|
||||
def _import_zoopt_search():
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
return ZOOptSearch
|
||||
|
||||
def _import_sigopt_search():
|
||||
from ray.tune.suggest.sigopt import SigOptSearch
|
||||
return SigOptSearch
|
||||
|
||||
SEARCH_ALG_IMPORT = {
|
||||
"ax": _import_ax_search,
|
||||
"dragonfly": _import_dragonfly_search,
|
||||
"skopt": _import_skopt_search,
|
||||
"hyperopt": _import_hyperopt_search,
|
||||
"bayesopt": _import_bayesopt_search,
|
||||
"bohb": _import_bohb_search,
|
||||
"nevergrad": _import_nevergrad_search,
|
||||
"optuna": _import_optuna_search,
|
||||
"zoopt": _import_zoopt_search,
|
||||
"sigopt": _import_sigopt_search,
|
||||
}
|
||||
search_alg = search_alg.lower()
|
||||
if search_alg not in SEARCH_ALG_IMPORT:
|
||||
raise ValueError(
|
||||
f"Search alg must be one of {list(SEARCH_ALG_IMPORT)}. "
|
||||
f"Got: {search_alg}")
|
||||
|
||||
SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
|
||||
return SearcherClass(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm", "Searcher", "BasicVariantGenerator", "SearchGenerator",
|
||||
"grid_search", "Repeater", "ConcurrencyLimiter"
|
||||
|
||||
@@ -14,7 +14,8 @@ from ray import tune
|
||||
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
|
||||
EarlyStopping)
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
|
||||
AsyncHyperBandScheduler)
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
|
||||
EPISODES_TOTAL, TRAINING_ITERATION,
|
||||
@@ -24,6 +25,8 @@ from ray.tune.logger import Logger
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
||||
from ray.tune.utils import (flatten_dict, get_pinned_object,
|
||||
pin_in_object_store)
|
||||
@@ -1105,6 +1108,30 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertIn("LOG_STDERR", content)
|
||||
|
||||
|
||||
class ShimCreationTest(unittest.TestCase):
|
||||
def testCreateScheduler(self):
|
||||
kwargs = {"metric": "metric_foo", "mode": "min"}
|
||||
|
||||
scheduler = "async_hyperband"
|
||||
shim_scheduler = tune.create_scheduler(scheduler, **kwargs)
|
||||
real_scheduler = AsyncHyperBandScheduler(**kwargs)
|
||||
assert type(shim_scheduler) is type(real_scheduler)
|
||||
|
||||
def testCreateSearcher(self):
|
||||
kwargs = {"metric": "metric_foo", "mode": "min"}
|
||||
|
||||
searcher_ax = "ax"
|
||||
shim_searcher_ax = tune.create_searcher(searcher_ax, **kwargs)
|
||||
real_searcher_ax = AxSearch(space=[], **kwargs)
|
||||
assert type(shim_searcher_ax) is type(real_searcher_ax)
|
||||
|
||||
searcher_hyperopt = "hyperopt"
|
||||
shim_searcher_hyperopt = tune.create_searcher(searcher_hyperopt,
|
||||
**kwargs)
|
||||
real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
|
||||
assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
Reference in New Issue
Block a user