diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index 4a2180b5d..e4d349ee9 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -333,8 +333,7 @@ class Function(Domain): domain: "Function", spec: Optional[Union[List[Dict], Dict]] = None, size: int = 1): - pass_spec = len(signature(domain.func).parameters) > 0 - if pass_spec: + if domain.pass_spec: items = [ domain.func(spec[i] if isinstance(spec, list) else spec) for i in range(size) @@ -347,11 +346,23 @@ class Function(Domain): default_sampler_cls = _CallSampler def __init__(self, func: Callable): - if len(signature(func).parameters) > 1: - raise ValueError( - "The function passed to a `Function` parameter must accept " - "either 0 or 1 parameters.") + sig = signature(func) + pass_spec = True # whether we should pass `spec` when calling `func` + try: + sig.bind({}) + except TypeError: + pass_spec = False + + if not pass_spec: + try: + sig.bind() + except TypeError as exc: + raise ValueError( + "The function passed to a `Function` parameter must be " + "callable with either 0 or 1 parameters.") from exc + + self.pass_spec = pass_spec self.func = func def is_function(self): diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index 4c1d1a1cb..1b6821439 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -142,6 +142,32 @@ class SearchSpaceTest(unittest.TestCase): self.assertTrue(any(-4 < s < 4 for s in samples)) self.assertTrue(-2 < np.mean(samples) < 2) + def testFunctionSignature(self): + from functools import partial + + def sample_a(): + return 0 + + def sample_b(spec): + return 1 + + def sample_c(spec, b="ok"): + return 2 + + def sample_d_invalid(spec, b): + return 3 + + sample_d_valid = partial(sample_d_invalid, b="ok") + + for sample_fn in [sample_a, sample_b, sample_c, sample_d_valid]: + fn = tune.sample_from(sample_fn) + sample = fn.sample(None) + self.assertIsNotNone(sample) + + with self.assertRaises(ValueError): + fn = tune.sample_from(sample_d_invalid) + print(fn.sample(None)) + def testQuantized(self): bounded_positive = tune.sample.Float(1e-4, 1e-1)