[tune] implement shim instantiation (#10456)

* Create ray.tune.suggest.create.create_scheduler

* Update __init__.py

* Resolve conflict in __init__.py

* Create ray.tune.schedulers.create.create_scheduler

* Update __init__.py

* Move create_scheduler to tune.schedulers.__init__

* Move create_searcher to tune.suggest.__init__

* Delete tune.suggest.create

* Delete tune.schedulers.create

* Update imports for shim functions in tune.__init__

* Remove shim from tune.suggest.__init__.__all__

* Remove shim from tune.schedulers.__init__.__all__

* Add ShimCreationTest

* Move ShimCreationTest to test_api

* Delete test_shim.py

* Add docstring for ray.tune.create_scheduler

* Add docstring to ray.tune.create_searcher

* Fix typo in ray.tune.create_scheduler docstring

* Fix lint errors in tune.schedulers.__init__

* Fix lint errors in tune.suggest.__init__

* Fix lint errors in tune.suggest.__init__

* Fix lint errors in tune.schedulers.__init__

* Fix imports in test_api

* Fix lint errors in test_api

* Fix kwargs in create_searcher

* Fix kwargs in create_scheduler

* Merge branch 'master' into shim-instantiation

* Update use-case in docs in tune.create_scheduler

* Update use-case in docs in tune.create_searcher

* Remove duplicate pytest run from test_api

* Add check to create_searcher


Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Add check to create_scheduler

* lint

* Compare types of instances in test_api

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Add tune.create_searcher to docs

* Fix doc build

* Fix tests

* Add tune.create_scheduler to docs

* Fix tests

* Fix lint errors

* Update Ax search for master

* Fix metric kwarg for Ax in test_api

* Fix doc build

* Fix HyperOptSearch import in test_api

* Fix HyperOptSearch import in create_searcher

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Sumanth Ratna
2020-09-05 12:36:42 -04:00
committed by GitHub
parent f03e91788a
commit 54215ff287
6 changed files with 197 additions and 2 deletions
+6
View File
@@ -196,3 +196,9 @@ TrialScheduler
.. autoclass:: ray.tune.schedulers.TrialScheduler
:members:
Shim Instantiation (tune.create_scheduler)
------------------------------------------
There is also a shim function that constructs the scheduler based on the provided string. This can be useful if the scheduler you want to use changes often (e.g., specifying the scheduler via a CLI option or config file).
.. automethod:: ray.tune.create_scheduler
+9
View File
@@ -79,6 +79,7 @@ Tune also provides helpful utilities to use with Search Algorithms:
* :ref:`repeater`: Support for running each *sampled hyperparameter* with multiple random seeds.
* :ref:`limiter`: Limits the amount of concurrent trials when running optimization.
* :ref:`shim`: Allows creation of the search algorithm object given a string.
Saving and Restoring
--------------------
@@ -268,3 +269,11 @@ If you are interested in implementing or contributing a new Search Algorithm, pr
:members:
:private-members:
:show-inheritance:
.. _shim:
Shim Instantiation (tune.create_searcher)
-----------------------------------------
There is also a shim function that constructs the search algorithm based on the provided string. This can be useful if the search algorithm you want to use changes often (e.g., specifying the search algorithm via a CLI option or config file).
.. automethod:: ray.tune.create_searcher
+3 -1
View File
@@ -15,6 +15,8 @@ from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
randint, qrandint, randn, qrandn, loguniform,
qloguniform)
from ray.tune.suggest import create_searcher
from ray.tune.schedulers import create_scheduler
__all__ = [
"Trainable", "DurableTrainable", "TuneError", "grid_search",
@@ -24,5 +26,5 @@ __all__ = [
"loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
"get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
"save_checkpoint", "checkpoint_dir"
"save_checkpoint", "checkpoint_dir", "create_searcher", "create_scheduler"
]
+63
View File
@@ -7,6 +7,69 @@ from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
from ray.tune.schedulers.pbt import (PopulationBasedTraining,
PopulationBasedTrainingReplay)
def create_scheduler(
scheduler,
metric="episode_reward_mean",
mode="max",
**kwargs,
):
"""Instantiate a scheduler based on the given string.
This is useful for swapping between different schedulers.
Args:
scheduler (str): The scheduler to use.
metric (str): The training result objective value attribute. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
**kwargs: Additional parameters.
These keyword arguments will be passed to the initialization
function of the chosen class.
Returns:
ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.
Example:
>>> scheduler = tune.create_scheduler('pbt')
"""
def _import_async_hyperband_scheduler():
from ray.tune.schedulers import AsyncHyperBandScheduler
return AsyncHyperBandScheduler
def _import_median_stopping_rule_scheduler():
from ray.tune.schedulers import MedianStoppingRule
return MedianStoppingRule
def _import_hyperband_scheduler():
from ray.tune.schedulers import HyperBandScheduler
return HyperBandScheduler
def _import_hb_bohb_scheduler():
from ray.tune.schedulers import HyperBandForBOHB
return HyperBandForBOHB
def _import_pbt_search():
from ray.tune.schedulers import PopulationBasedTraining
return PopulationBasedTraining
SCHEDULER_IMPORT = {
"async_hyperband": _import_async_hyperband_scheduler,
"median_stopping_rule": _import_median_stopping_rule_scheduler,
"hyperband": _import_hyperband_scheduler,
"hb_bohb": _import_hb_bohb_scheduler,
"pbt": _import_pbt_search,
}
scheduler = scheduler.lower()
if scheduler not in SCHEDULER_IMPORT:
raise ValueError(
f"Search alg must be one of {list(SCHEDULER_IMPORT)}. "
f"Got: {scheduler}")
SchedulerClass = SCHEDULER_IMPORT[scheduler]()
return SchedulerClass(metric=metric, mode=mode, **kwargs)
__all__ = [
"TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler",
"ASHAScheduler", "MedianStoppingRule", "FIFOScheduler",
+88
View File
@@ -5,6 +5,94 @@ from ray.tune.suggest.search_generator import SearchGenerator
from ray.tune.suggest.variant_generator import grid_search
from ray.tune.suggest.repeater import Repeater
def create_searcher(
search_alg,
metric="episode_reward_mean",
mode="max",
**kwargs,
):
"""Instantiate a search algorithm based on the given string.
This is useful for swapping between different search algorithms.
Args:
search_alg (str): The search algorithm to use.
metric (str): The training result objective value attribute. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
**kwargs: Additional parameters.
These keyword arguments will be passed to the initialization
function of the chosen class.
Returns:
ray.tune.suggest.Searcher: The search algorithm.
Example:
>>> search_alg = tune.create_searcher('ax')
"""
def _import_ax_search():
from ray.tune.suggest.ax import AxSearch
return AxSearch
def _import_dragonfly_search():
from ray.tune.suggest.dragonfly import DragonflySearch
return DragonflySearch
def _import_skopt_search():
from ray.tune.suggest.skopt import SkOptSearch
return SkOptSearch
def _import_hyperopt_search():
from ray.tune.suggest.hyperopt import HyperOptSearch
return HyperOptSearch
def _import_bayesopt_search():
from ray.tune.suggest.bayesopt import BayesOptSearch
return BayesOptSearch
def _import_bohb_search():
from ray.tune.suggest.bohb import TuneBOHB
return TuneBOHB
def _import_nevergrad_search():
from ray.tune.suggest.nevergrad import NevergradSearch
return NevergradSearch
def _import_optuna_search():
from ray.tune.suggest.optuna import OptunaSearch
return OptunaSearch
def _import_zoopt_search():
from ray.tune.suggest.zoopt import ZOOptSearch
return ZOOptSearch
def _import_sigopt_search():
from ray.tune.suggest.sigopt import SigOptSearch
return SigOptSearch
SEARCH_ALG_IMPORT = {
"ax": _import_ax_search,
"dragonfly": _import_dragonfly_search,
"skopt": _import_skopt_search,
"hyperopt": _import_hyperopt_search,
"bayesopt": _import_bayesopt_search,
"bohb": _import_bohb_search,
"nevergrad": _import_nevergrad_search,
"optuna": _import_optuna_search,
"zoopt": _import_zoopt_search,
"sigopt": _import_sigopt_search,
}
search_alg = search_alg.lower()
if search_alg not in SEARCH_ALG_IMPORT:
raise ValueError(
f"Search alg must be one of {list(SEARCH_ALG_IMPORT)}. "
f"Got: {search_alg}")
SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
return SearcherClass(metric=metric, mode=mode, **kwargs)
__all__ = [
"SearchAlgorithm", "Searcher", "BasicVariantGenerator", "SearchGenerator",
"grid_search", "Repeater", "ConcurrencyLimiter"
+28 -1
View File
@@ -14,7 +14,8 @@ from ray import tune
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
EarlyStopping)
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
from ray.tune.trial import Trial
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
EPISODES_TOTAL, TRAINING_ITERATION,
@@ -24,6 +25,8 @@ from ray.tune.logger import Logger
from ray.tune.experiment import Experiment
from ray.tune.resources import Resources
from ray.tune.suggest import grid_search
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.ax import AxSearch
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.utils import (flatten_dict, get_pinned_object,
pin_in_object_store)
@@ -1105,6 +1108,30 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertIn("LOG_STDERR", content)
class ShimCreationTest(unittest.TestCase):
def testCreateScheduler(self):
kwargs = {"metric": "metric_foo", "mode": "min"}
scheduler = "async_hyperband"
shim_scheduler = tune.create_scheduler(scheduler, **kwargs)
real_scheduler = AsyncHyperBandScheduler(**kwargs)
assert type(shim_scheduler) is type(real_scheduler)
def testCreateSearcher(self):
kwargs = {"metric": "metric_foo", "mode": "min"}
searcher_ax = "ax"
shim_searcher_ax = tune.create_searcher(searcher_ax, **kwargs)
real_searcher_ax = AxSearch(space=[], **kwargs)
assert type(shim_searcher_ax) is type(real_searcher_ax)
searcher_hyperopt = "hyperopt"
shim_searcher_hyperopt = tune.create_searcher(searcher_hyperopt,
**kwargs)
real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)
if __name__ == "__main__":
import pytest
import sys