[tune] move _SCHEDULERS to tune.schedulers and add all available schedulers (#11218)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Thomas Tumiel
2020-10-08 23:10:23 +00:00
committed by GitHub
parent 6cb00208f7
commit 587319debc
3 changed files with 18 additions and 55 deletions
+14 -36
View File
@@ -10,8 +10,6 @@ from ray.tune.schedulers.pbt import (PopulationBasedTraining,
def create_scheduler(
scheduler,
metric=None,
mode=None,
**kwargs,
):
"""Instantiate a scheduler based on the given string.
@@ -20,45 +18,25 @@ def create_scheduler(
Args:
scheduler (str): The scheduler to use.
metric (str): The training result objective value attribute. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
**kwargs: Additional parameters.
**kwargs: Scheduler parameters.
These keyword arguments will be passed to the initialization
function of the chosen class.
function of the chosen scheduler.
Returns:
ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.
Example:
>>> scheduler = tune.create_scheduler('pbt')
>>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs)
"""
def _import_async_hyperband_scheduler():
from ray.tune.schedulers import AsyncHyperBandScheduler
return AsyncHyperBandScheduler
def _import_median_stopping_rule_scheduler():
from ray.tune.schedulers import MedianStoppingRule
return MedianStoppingRule
def _import_hyperband_scheduler():
from ray.tune.schedulers import HyperBandScheduler
return HyperBandScheduler
def _import_hb_bohb_scheduler():
from ray.tune.schedulers import HyperBandForBOHB
return HyperBandForBOHB
def _import_pbt_search():
from ray.tune.schedulers import PopulationBasedTraining
return PopulationBasedTraining
SCHEDULER_IMPORT = {
"async_hyperband": _import_async_hyperband_scheduler,
"median_stopping_rule": _import_median_stopping_rule_scheduler,
"hyperband": _import_hyperband_scheduler,
"hb_bohb": _import_hb_bohb_scheduler,
"pbt": _import_pbt_search,
"fifo": FIFOScheduler,
"async_hyperband": AsyncHyperBandScheduler,
"asynchyperband": AsyncHyperBandScheduler,
"median_stopping_rule": MedianStoppingRule,
"medianstopping": MedianStoppingRule,
"hyperband": HyperBandScheduler,
"hb_bohb": HyperBandForBOHB,
"pbt": PopulationBasedTraining,
"pbt_replay": PopulationBasedTrainingReplay,
}
scheduler = scheduler.lower()
if scheduler not in SCHEDULER_IMPORT:
@@ -66,8 +44,8 @@ def create_scheduler(
f"Search alg must be one of {list(SCHEDULER_IMPORT)}. "
f"Got: {scheduler}")
SchedulerClass = SCHEDULER_IMPORT[scheduler]()
return SchedulerClass(metric=metric, mode=mode, **kwargs)
SchedulerClass = SCHEDULER_IMPORT[scheduler]
return SchedulerClass(**kwargs)
__all__ = [
+1 -17
View File
@@ -14,18 +14,10 @@ from ray.tune.registry import get_trainable_cls
from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
from ray.tune.trial_runner import TrialRunner
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
FIFOScheduler, MedianStoppingRule)
from ray.tune.schedulers import FIFOScheduler
logger = logging.getLogger(__name__)
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
"HyperBand": HyperBandScheduler,
"AsyncHyperBand": AsyncHyperBandScheduler,
}
try:
class_name = get_ipython().__class__.__name__
IS_NOTEBOOK = True if "Terminal" not in class_name else False
@@ -33,14 +25,6 @@ except NameError:
IS_NOTEBOOK = False
def _make_scheduler(args):
if args.scheduler in _SCHEDULERS:
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
else:
raise TuneError("Unknown scheduler: {}, should be one of {}".format(
args.scheduler, _SCHEDULERS.keys()))
def _check_default_resources_override(run_identifier):
if not isinstance(run_identifier, str):
# If obscure dtype, assume it is overridden.