diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 5e51bdab2..3bbb8cc70 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -10,8 +10,6 @@ from ray.tune.schedulers.pbt import (PopulationBasedTraining, def create_scheduler( scheduler, - metric=None, - mode=None, **kwargs, ): """Instantiate a scheduler based on the given string. @@ -20,45 +18,25 @@ def create_scheduler( Args: scheduler (str): The scheduler to use. - metric (str): The training result objective value attribute. Stopping - procedures will use this attribute. - mode (str): One of {min, max}. Determines whether objective is - minimizing or maximizing the metric attribute. - **kwargs: Additional parameters. + **kwargs: Scheduler parameters. These keyword arguments will be passed to the initialization - function of the chosen class. + function of the chosen scheduler. Returns: ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler. Example: - >>> scheduler = tune.create_scheduler('pbt') + >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs) """ - def _import_async_hyperband_scheduler(): - from ray.tune.schedulers import AsyncHyperBandScheduler - return AsyncHyperBandScheduler - - def _import_median_stopping_rule_scheduler(): - from ray.tune.schedulers import MedianStoppingRule - return MedianStoppingRule - - def _import_hyperband_scheduler(): - from ray.tune.schedulers import HyperBandScheduler - return HyperBandScheduler - - def _import_hb_bohb_scheduler(): - from ray.tune.schedulers import HyperBandForBOHB - return HyperBandForBOHB - - def _import_pbt_search(): - from ray.tune.schedulers import PopulationBasedTraining - return PopulationBasedTraining - SCHEDULER_IMPORT = { - "async_hyperband": _import_async_hyperband_scheduler, - "median_stopping_rule": _import_median_stopping_rule_scheduler, - "hyperband": _import_hyperband_scheduler, - "hb_bohb": _import_hb_bohb_scheduler, - "pbt": _import_pbt_search, + "fifo": FIFOScheduler, + "async_hyperband": AsyncHyperBandScheduler, + "asynchyperband": AsyncHyperBandScheduler, + "median_stopping_rule": MedianStoppingRule, + "medianstopping": MedianStoppingRule, + "hyperband": HyperBandScheduler, + "hb_bohb": HyperBandForBOHB, + "pbt": PopulationBasedTraining, + "pbt_replay": PopulationBasedTrainingReplay, } scheduler = scheduler.lower() if scheduler not in SCHEDULER_IMPORT: @@ -66,8 +44,8 @@ def create_scheduler( f"Search alg must be one of {list(SCHEDULER_IMPORT)}. " f"Got: {scheduler}") - SchedulerClass = SCHEDULER_IMPORT[scheduler]() - return SchedulerClass(metric=metric, mode=mode, **kwargs) + SchedulerClass = SCHEDULER_IMPORT[scheduler] + return SchedulerClass(**kwargs) __all__ = [ diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 9c5646786..30dbde888 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -14,18 +14,10 @@ from ray.tune.registry import get_trainable_cls from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig from ray.tune.trial_runner import TrialRunner from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter -from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, - FIFOScheduler, MedianStoppingRule) +from ray.tune.schedulers import FIFOScheduler logger = logging.getLogger(__name__) -_SCHEDULERS = { - "FIFO": FIFOScheduler, - "MedianStopping": MedianStoppingRule, - "HyperBand": HyperBandScheduler, - "AsyncHyperBand": AsyncHyperBandScheduler, -} - try: class_name = get_ipython().__class__.__name__ IS_NOTEBOOK = True if "Terminal" not in class_name else False @@ -33,14 +25,6 @@ except NameError: IS_NOTEBOOK = False -def _make_scheduler(args): - if args.scheduler in _SCHEDULERS: - return _SCHEDULERS[args.scheduler](**args.scheduler_config) - else: - raise TuneError("Unknown scheduler: {}, should be one of {}".format( - args.scheduler, _SCHEDULERS.keys())) - - def _check_default_resources_override(run_identifier): if not isinstance(run_identifier, str): # If obscure dtype, assume it is overridden. diff --git a/rllib/train.py b/rllib/train.py index f16205943..a89a23f18 100755 --- a/rllib/train.py +++ b/rllib/train.py @@ -10,7 +10,8 @@ from ray.cluster_utils import Cluster from ray.tune.config_parser import make_parser from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.resources import resources_to_json -from ray.tune.tune import _make_scheduler, run_experiments +from ray.tune.tune import run_experiments +from ray.tune.schedulers import create_scheduler from ray.rllib.utils.framework import try_import_tf, try_import_torch # Try to import both backends for flag checking/warnings. @@ -207,7 +208,7 @@ def run(args, parser): run_experiments( experiments, - scheduler=_make_scheduler(args), + scheduler=create_scheduler(args.scheduler, **args.scheduler_config), resume=args.resume, queue_trials=args.queue_trials, verbose=verbose,