diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index a8a581da3..940724ef7 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -1,3 +1,4 @@ +from ray.utils import get_function_args from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler from ray.tune.schedulers.hb_bohb import HyperBandForBOHB @@ -53,7 +54,11 @@ def create_scheduler( f"Got: {scheduler}") SchedulerClass = SCHEDULER_IMPORT[scheduler] - return SchedulerClass(**kwargs) + + scheduler_args = get_function_args(SchedulerClass) + trimmed_kwargs = {k: v for k, v in kwargs.items() if k in scheduler_args} + + return SchedulerClass(**trimmed_kwargs) __all__ = [ diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index 97a54a449..6cd0c6c87 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -1,3 +1,4 @@ +from ray.utils import get_function_args from ray.tune.suggest.search import SearchAlgorithm from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter @@ -93,7 +94,11 @@ def create_searcher( f"Got: {search_alg}") SearcherClass = SEARCH_ALG_IMPORT[search_alg]() - return SearcherClass(**kwargs) + + search_alg_args = get_function_args(SearcherClass) + trimmed_kwargs = {k: v for k, v in kwargs.items() if k in search_alg_args} + + return SearcherClass(**trimmed_kwargs) __all__ = [ diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index ea325fa45..cfb1d3f6c 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -1195,6 +1195,15 @@ class ShimCreationTest(unittest.TestCase): real_searcher_hyperopt = HyperOptSearch({}, **kwargs) assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt) + def testExtraParams(self): + kwargs = {"metric": "metric_foo", "mode": "min", "extra_param": "test"} + + scheduler = "async_hyperband" + tune.create_scheduler(scheduler, **kwargs) + + searcher_ax = "ax" + tune.create_searcher(searcher_ax, **kwargs) + class ApiTestFast(unittest.TestCase): @classmethod diff --git a/python/ray/utils.py b/python/ray/utils.py index 6659d7eb9..a3940d6e8 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -3,7 +3,6 @@ import errno import hashlib import logging import multiprocessing -import numpy as np import os import signal import subprocess @@ -12,11 +11,13 @@ import tempfile import threading import time import uuid +from inspect import signature +import numpy as np +import psutil import ray import ray.gcs_utils import ray.ray_constants as ray_constants -import psutil pwd = None if sys.platform != "win32": @@ -806,3 +807,8 @@ def get_user(): return pwd.getpwuid(os.getuid()).pw_name except Exception: return "" + + +def get_function_args(callable): + all_parameters = frozenset(signature(callable).parameters) + return list(all_parameters)