diff --git a/python/ray/signature.py b/python/ray/signature.py index c4ae60aa3..5c6da1766 100644 --- a/python/ray/signature.py +++ b/python/ray/signature.py @@ -4,6 +4,7 @@ from __future__ import print_function from collections import namedtuple import funcsigs +from funcsigs import Parameter from ray.utils import is_cython @@ -14,15 +15,16 @@ FunctionSignature = namedtuple("FunctionSignature", [ """This class is used to represent a function signature. Attributes: - keyword_names: The names of the functions keyword arguments. This is used - to test if an incorrect keyword argument has been passed to the - function. + arg_names: A list containing the name of all arguments. arg_defaults: A dictionary mapping from argument name to argument default value. If the argument is not a keyword argument, the default value will be funcsigs._empty. arg_is_positionals: A dictionary mapping from argument name to a bool. The bool will be true if the argument is a *args argument. Otherwise it will be false. + keyword_names: A set containing the names of the keyword arguments. + Note most arguments in Python can be called as positional or keyword + arguments, so this overlaps (sometimes completely) with arg_names. function_name: The name of the function whose signature is being inspected. This is used for printing better error messages. """ @@ -85,16 +87,13 @@ def check_signature_supported(func, warn=False): function_name = func.__name__ sig_params = get_signature_params(func) - has_vararg_param = False has_kwargs_param = False - has_keyword_arg = False + has_kwonly_param = False for keyword_name, parameter in sig_params: - if parameter.kind == parameter.VAR_KEYWORD: + if parameter.kind == Parameter.VAR_KEYWORD: has_kwargs_param = True - if parameter.kind == parameter.VAR_POSITIONAL: - has_vararg_param = True - if parameter.default != funcsigs._empty: - has_keyword_arg = True + if parameter.kind == Parameter.KEYWORD_ONLY: + has_kwonly_param = True if has_kwargs_param: message = ("The function {} has a **kwargs argument, which is " @@ -103,12 +102,11 @@ def check_signature_supported(func, warn=False): print(message) else: raise Exception(message) - # Check if the user specified a variable number of arguments and any - # keyword arguments. - if has_vararg_param and has_keyword_arg: - message = ("Function {} has a *args argument as well as a keyword " - "argument, which is currently not supported." - .format(function_name)) + + if has_kwonly_param: + message = ("The function {} has a keyword only argument " + "(defined after * or *args), which is currently " + "not supported.".format(function_name)) if warn: print(message) else: @@ -136,20 +134,18 @@ def extract_signature(func, ignore_first=False): func.__name__)) sig_params = sig_params[1:] - # Extract the names of the keyword arguments. - keyword_names = set() - for keyword_name, parameter in sig_params: - if parameter.default != funcsigs._empty: - keyword_names.add(keyword_name) - # Construct the argument default values and other argument information. arg_names = [] arg_defaults = [] arg_is_positionals = [] - for keyword_name, parameter in sig_params: - arg_names.append(keyword_name) + keyword_names = set() + for arg_name, parameter in sig_params: + arg_names.append(arg_name) arg_defaults.append(parameter.default) arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL) + if parameter.kind == Parameter.POSITIONAL_OR_KEYWORD: + # Note KEYWORD_ONLY arguments currently unsupported. + keyword_names.add(arg_name) return FunctionSignature(arg_names, arg_defaults, arg_is_positionals, keyword_names, func.__name__) @@ -189,8 +185,14 @@ def extend_args(function_signature, args, kwargs): keyword_name, function_name)) # Fill in the remaining arguments. - zipped_info = list(zip(arg_names, arg_defaults, - arg_is_positionals))[len(args):] + for skipped_name in arg_names[0:len(args)]: + if skipped_name in kwargs: + raise Exception("Positional and keyword value provided for the " + "argument '{}' for the function '{}'".format( + keyword_name, function_name)) + + zipped_info = zip(arg_names, arg_defaults, arg_is_positionals) + zipped_info = list(zipped_info)[len(args):] for keyword_name, default_value, is_positional in zipped_info: if keyword_name in kwargs: args.append(kwargs[keyword_name]) @@ -206,9 +208,8 @@ def extend_args(function_signature, args, kwargs): "'{}' for the function '{}'.".format( keyword_name, function_name)) - too_many_arguments = (len(args) > len(arg_names) - and (len(arg_is_positionals) == 0 - or not arg_is_positionals[-1])) + no_positionals = len(arg_is_positionals) == 0 or not arg_is_positionals[-1] + too_many_arguments = len(args) > len(arg_names) and no_positionals if too_many_arguments: raise Exception("Too many arguments were passed to the function '{}'" .format(function_name)) diff --git a/python/ray/test/test_functions.py b/python/ray/test/test_functions.py index 4883624fe..b2b7ac7d1 100644 --- a/python/ray/test/test_functions.py +++ b/python/ray/test/test_functions.py @@ -68,16 +68,6 @@ try: except Exception: kwargs_exception_thrown = True -try: - - @ray.remote - def varargs_and_kwargs_throw_exception(a, b="hi", *c): - return "{} {} {}".format(a, b, c) - - varargs_and_kwargs_exception_thrown = False -except Exception: - varargs_and_kwargs_exception_thrown = True - # test throwing an exception diff --git a/test/actor_test.py b/test/actor_test.py index 51e37e846..3d99c6349 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -57,6 +57,9 @@ class ActorAPI(unittest.TestCase): self.assertEqual( ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), (1, 2, "cd")) + self.assertEqual( + ray.get(actor.get_values.remote(arg2="d", arg1=0, arg0=2)), + (3, 2, "cd")) # Make sure we get an exception if the constructor is called # incorrectly. @@ -66,6 +69,9 @@ class ActorAPI(unittest.TestCase): with self.assertRaises(Exception): actor = Actor.remote(0, 1, 2, arg3=3) + with self.assertRaises(Exception): + actor = Actor.remote(0, arg0=1) + # Make sure we get an exception if the method is called incorrectly. actor = Actor.remote(1) with self.assertRaises(Exception): diff --git a/test/runtest.py b/test/runtest.py index b47e586af..694574b00 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -529,6 +529,8 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(x), "1 hi") x = test_functions.keyword_fct1.remote(1, b="world") self.assertEqual(ray.get(x), "1 world") + x = test_functions.keyword_fct1.remote(a=1, b="world") + self.assertEqual(ray.get(x), "1 world") x = test_functions.keyword_fct2.remote(a="w", b="hi") self.assertEqual(ray.get(x), "w hi") @@ -545,6 +547,10 @@ class APITest(unittest.TestCase): x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi") self.assertEqual(ray.get(x), "0 1 w hi") + x = test_functions.keyword_fct3.remote(0, b=1, c="w", d="hi") + self.assertEqual(ray.get(x), "0 1 w hi") + x = test_functions.keyword_fct3.remote(a=0, b=1, c="w", d="hi") + self.assertEqual(ray.get(x), "0 1 w hi") x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w") self.assertEqual(ray.get(x), "0 1 w hi") x = test_functions.keyword_fct3.remote(0, 1, c="w") @@ -553,6 +559,8 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(x), "0 1 hello hi") x = test_functions.keyword_fct3.remote(0, 1) self.assertEqual(ray.get(x), "0 1 hello world") + x = test_functions.keyword_fct3.remote(a=0, b=1) + self.assertEqual(ray.get(x), "0 1 hello world") # Check that we cannot pass invalid keyword arguments to functions. @ray.remote @@ -573,6 +581,9 @@ class APITest(unittest.TestCase): with self.assertRaises(Exception): f2.remote(0, w=0) + with self.assertRaises(Exception): + f2.remote(3, x=3) + # Make sure we get an exception if too many arguments are passed in. with self.assertRaises(Exception): f2.remote(1, 2, 3, 4) @@ -593,7 +604,6 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(x), "1 2") self.assertTrue(test_functions.kwargs_exception_thrown) - self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown) @ray.remote def f1(*args):