[tune] trim kwargs in shim instantiation functions (#12544)

This commit is contained in:
Kaushik B
2020-12-03 01:37:00 +05:30
committed by GitHub
parent da42bf29d0
commit 7422abddb4
4 changed files with 29 additions and 4 deletions
+6 -1
View File
@@ -1,3 +1,4 @@
from ray.utils import get_function_args
from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler
from ray.tune.schedulers.hyperband import HyperBandScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
@@ -53,7 +54,11 @@ def create_scheduler(
f"Got: {scheduler}")
SchedulerClass = SCHEDULER_IMPORT[scheduler]
return SchedulerClass(**kwargs)
scheduler_args = get_function_args(SchedulerClass)
trimmed_kwargs = {k: v for k, v in kwargs.items() if k in scheduler_args}
return SchedulerClass(**trimmed_kwargs)
__all__ = [
+6 -1
View File
@@ -1,3 +1,4 @@
from ray.utils import get_function_args
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
@@ -93,7 +94,11 @@ def create_searcher(
f"Got: {search_alg}")
SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
return SearcherClass(**kwargs)
search_alg_args = get_function_args(SearcherClass)
trimmed_kwargs = {k: v for k, v in kwargs.items() if k in search_alg_args}
return SearcherClass(**trimmed_kwargs)
__all__ = [
+9
View File
@@ -1195,6 +1195,15 @@ class ShimCreationTest(unittest.TestCase):
real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)
def testExtraParams(self):
kwargs = {"metric": "metric_foo", "mode": "min", "extra_param": "test"}
scheduler = "async_hyperband"
tune.create_scheduler(scheduler, **kwargs)
searcher_ax = "ax"
tune.create_searcher(searcher_ax, **kwargs)
class ApiTestFast(unittest.TestCase):
@classmethod
+8 -2
View File
@@ -3,7 +3,6 @@ import errno
import hashlib
import logging
import multiprocessing
import numpy as np
import os
import signal
import subprocess
@@ -12,11 +11,13 @@ import tempfile
import threading
import time
import uuid
from inspect import signature
import numpy as np
import psutil
import ray
import ray.gcs_utils
import ray.ray_constants as ray_constants
import psutil
pwd = None
if sys.platform != "win32":
@@ -806,3 +807,8 @@ def get_user():
return pwd.getpwuid(os.getuid()).pw_name
except Exception:
return ""
def get_function_args(callable):
all_parameters = frozenset(signature(callable).parameters)
return list(all_parameters)