mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 07:23:55 +08:00
[tune] move _SCHEDULERS to tune.schedulers and add all available schedulers (#11218)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -10,8 +10,6 @@ from ray.tune.schedulers.pbt import (PopulationBasedTraining,
|
||||
|
||||
def create_scheduler(
|
||||
scheduler,
|
||||
metric=None,
|
||||
mode=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate a scheduler based on the given string.
|
||||
@@ -20,45 +18,25 @@ def create_scheduler(
|
||||
|
||||
Args:
|
||||
scheduler (str): The scheduler to use.
|
||||
metric (str): The training result objective value attribute. Stopping
|
||||
procedures will use this attribute.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
**kwargs: Additional parameters.
|
||||
**kwargs: Scheduler parameters.
|
||||
These keyword arguments will be passed to the initialization
|
||||
function of the chosen class.
|
||||
function of the chosen scheduler.
|
||||
Returns:
|
||||
ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.
|
||||
Example:
|
||||
>>> scheduler = tune.create_scheduler('pbt')
|
||||
>>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs)
|
||||
"""
|
||||
|
||||
def _import_async_hyperband_scheduler():
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
return AsyncHyperBandScheduler
|
||||
|
||||
def _import_median_stopping_rule_scheduler():
|
||||
from ray.tune.schedulers import MedianStoppingRule
|
||||
return MedianStoppingRule
|
||||
|
||||
def _import_hyperband_scheduler():
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
return HyperBandScheduler
|
||||
|
||||
def _import_hb_bohb_scheduler():
|
||||
from ray.tune.schedulers import HyperBandForBOHB
|
||||
return HyperBandForBOHB
|
||||
|
||||
def _import_pbt_search():
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
return PopulationBasedTraining
|
||||
|
||||
SCHEDULER_IMPORT = {
|
||||
"async_hyperband": _import_async_hyperband_scheduler,
|
||||
"median_stopping_rule": _import_median_stopping_rule_scheduler,
|
||||
"hyperband": _import_hyperband_scheduler,
|
||||
"hb_bohb": _import_hb_bohb_scheduler,
|
||||
"pbt": _import_pbt_search,
|
||||
"fifo": FIFOScheduler,
|
||||
"async_hyperband": AsyncHyperBandScheduler,
|
||||
"asynchyperband": AsyncHyperBandScheduler,
|
||||
"median_stopping_rule": MedianStoppingRule,
|
||||
"medianstopping": MedianStoppingRule,
|
||||
"hyperband": HyperBandScheduler,
|
||||
"hb_bohb": HyperBandForBOHB,
|
||||
"pbt": PopulationBasedTraining,
|
||||
"pbt_replay": PopulationBasedTrainingReplay,
|
||||
}
|
||||
scheduler = scheduler.lower()
|
||||
if scheduler not in SCHEDULER_IMPORT:
|
||||
@@ -66,8 +44,8 @@ def create_scheduler(
|
||||
f"Search alg must be one of {list(SCHEDULER_IMPORT)}. "
|
||||
f"Got: {scheduler}")
|
||||
|
||||
SchedulerClass = SCHEDULER_IMPORT[scheduler]()
|
||||
return SchedulerClass(metric=metric, mode=mode, **kwargs)
|
||||
SchedulerClass = SCHEDULER_IMPORT[scheduler]
|
||||
return SchedulerClass(**kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
+1
-17
@@ -14,18 +14,10 @@ from ray.tune.registry import get_trainable_cls
|
||||
from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
FIFOScheduler, MedianStoppingRule)
|
||||
from ray.tune.schedulers import FIFOScheduler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
"HyperBand": HyperBandScheduler,
|
||||
"AsyncHyperBand": AsyncHyperBandScheduler,
|
||||
}
|
||||
|
||||
try:
|
||||
class_name = get_ipython().__class__.__name__
|
||||
IS_NOTEBOOK = True if "Terminal" not in class_name else False
|
||||
@@ -33,14 +25,6 @@ except NameError:
|
||||
IS_NOTEBOOK = False
|
||||
|
||||
|
||||
def _make_scheduler(args):
|
||||
if args.scheduler in _SCHEDULERS:
|
||||
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
|
||||
else:
|
||||
raise TuneError("Unknown scheduler: {}, should be one of {}".format(
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def _check_default_resources_override(run_identifier):
|
||||
if not isinstance(run_identifier, str):
|
||||
# If obscure dtype, assume it is overridden.
|
||||
|
||||
Reference in New Issue
Block a user