mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[tune] better signature check for tune.sample_from (#13171)
* [tune] better signature check for `tune.sample_from` * Update python/ray/tune/sample.py Co-authored-by: Sumanth Ratna <sumanthratna@gmail.com> Co-authored-by: Sumanth Ratna <sumanthratna@gmail.com>
This commit is contained in:
@@ -333,8 +333,7 @@ class Function(Domain):
|
||||
domain: "Function",
|
||||
spec: Optional[Union[List[Dict], Dict]] = None,
|
||||
size: int = 1):
|
||||
pass_spec = len(signature(domain.func).parameters) > 0
|
||||
if pass_spec:
|
||||
if domain.pass_spec:
|
||||
items = [
|
||||
domain.func(spec[i] if isinstance(spec, list) else spec)
|
||||
for i in range(size)
|
||||
@@ -347,11 +346,23 @@ class Function(Domain):
|
||||
default_sampler_cls = _CallSampler
|
||||
|
||||
def __init__(self, func: Callable):
|
||||
if len(signature(func).parameters) > 1:
|
||||
raise ValueError(
|
||||
"The function passed to a `Function` parameter must accept "
|
||||
"either 0 or 1 parameters.")
|
||||
sig = signature(func)
|
||||
|
||||
pass_spec = True # whether we should pass `spec` when calling `func`
|
||||
try:
|
||||
sig.bind({})
|
||||
except TypeError:
|
||||
pass_spec = False
|
||||
|
||||
if not pass_spec:
|
||||
try:
|
||||
sig.bind()
|
||||
except TypeError as exc:
|
||||
raise ValueError(
|
||||
"The function passed to a `Function` parameter must be "
|
||||
"callable with either 0 or 1 parameters.") from exc
|
||||
|
||||
self.pass_spec = pass_spec
|
||||
self.func = func
|
||||
|
||||
def is_function(self):
|
||||
|
||||
@@ -142,6 +142,32 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
self.assertTrue(any(-4 < s < 4 for s in samples))
|
||||
self.assertTrue(-2 < np.mean(samples) < 2)
|
||||
|
||||
def testFunctionSignature(self):
|
||||
from functools import partial
|
||||
|
||||
def sample_a():
|
||||
return 0
|
||||
|
||||
def sample_b(spec):
|
||||
return 1
|
||||
|
||||
def sample_c(spec, b="ok"):
|
||||
return 2
|
||||
|
||||
def sample_d_invalid(spec, b):
|
||||
return 3
|
||||
|
||||
sample_d_valid = partial(sample_d_invalid, b="ok")
|
||||
|
||||
for sample_fn in [sample_a, sample_b, sample_c, sample_d_valid]:
|
||||
fn = tune.sample_from(sample_fn)
|
||||
sample = fn.sample(None)
|
||||
self.assertIsNotNone(sample)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
fn = tune.sample_from(sample_d_invalid)
|
||||
print(fn.sample(None))
|
||||
|
||||
def testQuantized(self):
|
||||
bounded_positive = tune.sample.Float(1e-4, 1e-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user