[tune] better signature check for tune.sample_from (#13171)

* [tune] better signature check for `tune.sample_from`

* Update python/ray/tune/sample.py

Co-authored-by: Sumanth Ratna <sumanthratna@gmail.com>

Co-authored-by: Sumanth Ratna <sumanthratna@gmail.com>
This commit is contained in:
Kai Fricke
2021-01-05 17:04:18 +01:00
committed by GitHub
parent e8162f1b1f
commit 96c2d3d2b5
2 changed files with 43 additions and 6 deletions
+17 -6
View File
@@ -333,8 +333,7 @@ class Function(Domain):
domain: "Function",
spec: Optional[Union[List[Dict], Dict]] = None,
size: int = 1):
pass_spec = len(signature(domain.func).parameters) > 0
if pass_spec:
if domain.pass_spec:
items = [
domain.func(spec[i] if isinstance(spec, list) else spec)
for i in range(size)
@@ -347,11 +346,23 @@ class Function(Domain):
default_sampler_cls = _CallSampler
def __init__(self, func: Callable):
if len(signature(func).parameters) > 1:
raise ValueError(
"The function passed to a `Function` parameter must accept "
"either 0 or 1 parameters.")
sig = signature(func)
pass_spec = True # whether we should pass `spec` when calling `func`
try:
sig.bind({})
except TypeError:
pass_spec = False
if not pass_spec:
try:
sig.bind()
except TypeError as exc:
raise ValueError(
"The function passed to a `Function` parameter must be "
"callable with either 0 or 1 parameters.") from exc
self.pass_spec = pass_spec
self.func = func
def is_function(self):
+26
View File
@@ -142,6 +142,32 @@ class SearchSpaceTest(unittest.TestCase):
self.assertTrue(any(-4 < s < 4 for s in samples))
self.assertTrue(-2 < np.mean(samples) < 2)
def testFunctionSignature(self):
from functools import partial
def sample_a():
return 0
def sample_b(spec):
return 1
def sample_c(spec, b="ok"):
return 2
def sample_d_invalid(spec, b):
return 3
sample_d_valid = partial(sample_d_invalid, b="ok")
for sample_fn in [sample_a, sample_b, sample_c, sample_d_valid]:
fn = tune.sample_from(sample_fn)
sample = fn.sample(None)
self.assertIsNotNone(sample)
with self.assertRaises(ValueError):
fn = tune.sample_from(sample_d_invalid)
print(fn.sample(None))
def testQuantized(self):
bounded_positive = tune.sample.Float(1e-4, 1e-1)