mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] trim kwargs in shim instantiation functions (#12544)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from ray.utils import get_function_args
|
||||
from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.schedulers.hyperband import HyperBandScheduler
|
||||
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
|
||||
@@ -53,7 +54,11 @@ def create_scheduler(
|
||||
f"Got: {scheduler}")
|
||||
|
||||
SchedulerClass = SCHEDULER_IMPORT[scheduler]
|
||||
return SchedulerClass(**kwargs)
|
||||
|
||||
scheduler_args = get_function_args(SchedulerClass)
|
||||
trimmed_kwargs = {k: v for k, v in kwargs.items() if k in scheduler_args}
|
||||
|
||||
return SchedulerClass(**trimmed_kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from ray.utils import get_function_args
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||
@@ -93,7 +94,11 @@ def create_searcher(
|
||||
f"Got: {search_alg}")
|
||||
|
||||
SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
|
||||
return SearcherClass(**kwargs)
|
||||
|
||||
search_alg_args = get_function_args(SearcherClass)
|
||||
trimmed_kwargs = {k: v for k, v in kwargs.items() if k in search_alg_args}
|
||||
|
||||
return SearcherClass(**trimmed_kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1195,6 +1195,15 @@ class ShimCreationTest(unittest.TestCase):
|
||||
real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
|
||||
assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)
|
||||
|
||||
def testExtraParams(self):
|
||||
kwargs = {"metric": "metric_foo", "mode": "min", "extra_param": "test"}
|
||||
|
||||
scheduler = "async_hyperband"
|
||||
tune.create_scheduler(scheduler, **kwargs)
|
||||
|
||||
searcher_ax = "ax"
|
||||
tune.create_searcher(searcher_ax, **kwargs)
|
||||
|
||||
|
||||
class ApiTestFast(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
+8
-2
@@ -3,7 +3,6 @@ import errno
|
||||
import hashlib
|
||||
import logging
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
@@ -12,11 +11,13 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from inspect import signature
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import ray
|
||||
import ray.gcs_utils
|
||||
import ray.ray_constants as ray_constants
|
||||
import psutil
|
||||
|
||||
pwd = None
|
||||
if sys.platform != "win32":
|
||||
@@ -806,3 +807,8 @@ def get_user():
|
||||
return pwd.getpwuid(os.getuid()).pw_name
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def get_function_args(callable):
|
||||
all_parameters = frozenset(signature(callable).parameters)
|
||||
return list(all_parameters)
|
||||
|
||||
Reference in New Issue
Block a user