diff --git a/python/ray/actor.py b/python/ray/actor.py index 04f476e0a..6688a57df 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -11,6 +11,7 @@ import traceback import ray.local_scheduler import ray.pickling as pickling +import ray.signature as signature import ray.worker import ray.experimental.state as state @@ -206,12 +207,12 @@ def actor(*args, **kwargs): def make_actor(Class): # The function actor_method_call gets called if somebody tries to call a # method on their local actor stub object. - def actor_method_call(actor_id, attr, *args, **kwargs): + def actor_method_call(actor_id, attr, function_signature, *args, + **kwargs): ray.worker.check_connected() ray.worker.check_main_thread() - args = list(args) - if len(kwargs) > 0: - raise Exception("Actors currently do not support **kwargs.") + args = signature.extend_args(function_signature, args, kwargs) + function_id = get_actor_method_function_id(attr) # TODO(pcm): Extend args with keyword args. object_ids = ray.worker.global_worker.submit_task(function_id, "", @@ -229,12 +230,27 @@ def actor(*args, **kwargs): k: v for (k, v) in inspect.getmembers( Class, predicate=(lambda x: (inspect.isfunction(x) or inspect.ismethod(x))))} + # Extract the signatures of each of the methods. This will be used to + # catch some errors if the methods are called with inappropriate + # arguments. + self._ray_method_signatures = dict() + for k, v in self._ray_actor_methods.items(): + # Print a warning message if the method signature is not supported. + # We don't raise an exception because if the actor inherits from a + # class that has a method whose signature we don't support, we + # there may not be much the user can do about it. + signature.check_signature_supported(v, warn=True) + self._ray_method_signatures[k] = signature.extract_signature( + v, ignore_first=True) + export_actor(self._ray_actor_id, Class, self._ray_actor_methods.keys(), num_cpus, num_gpus, ray.worker.global_worker) # Call __init__ as a remote function. if "__init__" in self._ray_actor_methods.keys(): - actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs) + actor_method_call(self._ray_actor_id, "__init__", + self._ray_method_signatures["__init__"], + *args, **kwargs) else: print("WARNING: this object has no __init__ method.") @@ -244,11 +260,13 @@ def actor(*args, **kwargs): def __getattribute__(self, attr): # The following is needed so we can still access self.actor_methods. - if attr in ["_ray_actor_id", "_ray_actor_methods"]: + if attr in ["_ray_actor_id", "_ray_actor_methods", + "_ray_method_signatures"]: return super(NewClass, self).__getattribute__(attr) if attr in self._ray_actor_methods.keys(): return lambda *args, **kwargs: actor_method_call( - self._ray_actor_id, attr, *args, **kwargs) + self._ray_actor_id, attr, self._ray_method_signatures[attr], + *args, **kwargs) # There is no method with this name, so raise an exception. raise AttributeError("'{}' Actor object has no attribute '{}'" .format(Class, attr)) diff --git a/python/ray/signature.py b/python/ray/signature.py new file mode 100644 index 000000000..4c28a024b --- /dev/null +++ b/python/ray/signature.py @@ -0,0 +1,172 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import funcsigs + +FunctionSignature = namedtuple("FunctionSignature", ["arg_names", + "arg_defaults", + "arg_is_positionals", + "keyword_names", + "function_name"]) +"""This class is used to represent a function signature. + +Attributes: + keyword_names: The names of the functions keyword arguments. This is used to + test if an incorrect keyword argument has been passed to the function. + arg_defaults: A dictionary mapping from argument name to argument default + value. If the argument is not a keyword argument, the default value will be + funcsigs._empty. + arg_is_positionals: A dictionary mapping from argument name to a bool. The + bool will be true if the argument is a *args argument. Otherwise it will be + false. + function_name: The name of the function whose signature is being inspected. + This is used for printing better error messages. +""" + + +def check_signature_supported(func, warn=False): + """Check if we support the signature of this function. + + We currently do not allow remote functions to have **kwargs. We also do not + support keyword arguments in conjunction with a *args argument. + + Args: + func: The function whose signature should be checked. + warn: If this is true, a warning will be printed if the signature is not + supported. If it is false, an exception will be raised if the signature + is not supported. + + Raises: + Exception: An exception is raised if the signature is not supported. + """ + function_name = func.__name__ + sig_params = [(k, v) for k, v + in funcsigs.signature(func).parameters.items()] + + has_vararg_param = False + has_kwargs_param = False + has_keyword_arg = False + for keyword_name, parameter in sig_params: + if parameter.kind == parameter.VAR_KEYWORD: + has_kwargs_param = True + if parameter.kind == parameter.VAR_POSITIONAL: + has_vararg_param = True + if parameter.default != funcsigs._empty: + has_keyword_arg = True + + if has_kwargs_param: + message = ("The function {} has a **kwargs argument, which is " + "currently not supported.".format(function_name)) + if warn: + print(message) + else: + raise Exception(message) + # Check if the user specified a variable number of arguments and any keyword + # arguments. + if has_vararg_param and has_keyword_arg: + message = ("Function {} has a *args argument as well as a keyword " + "argument, which is currently not supported." + .format(function_name)) + if warn: + print(message) + else: + raise Exception(message) + + +def extract_signature(func, ignore_first=False): + """Extract the function signature from the function. + + Args: + func: The function whose signature should be extracted. + ignore_first: True if the first argument should be ignored. This should be + used when func is a method of a class. + + Returns: + A function signature object, which includes the names of the keyword + arguments as well as their default values. + """ + sig_params = [(k, v) for k, v + in funcsigs.signature(func).parameters.items()] + + if ignore_first: + if len(sig_params) == 0: + raise Exception("Methods must take a 'self' argument, but the method " + "'{}' does not have one.".format(func.__name__)) + sig_params = sig_params[1:] + + # Extract the names of the keyword arguments. + keyword_names = set() + for keyword_name, parameter in sig_params: + if parameter.default != funcsigs._empty: + keyword_names.add(keyword_name) + + # Construct the argument default values and other argument information. + arg_names = [] + arg_defaults = [] + arg_is_positionals = [] + for keyword_name, parameter in sig_params: + arg_names.append(keyword_name) + arg_defaults.append(parameter.default) + arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL) + + return FunctionSignature(arg_names, arg_defaults, arg_is_positionals, + keyword_names, func.__name__) + + +def extend_args(function_signature, args, kwargs): + """Extend the arguments that were passed into a function. + + This extends the arguments that were passed into a function with the default + arguments provided in the function definition. + + Args: + function_signature: The function signature of the function being called. + args: The non-keyword arguments passed into the function. + kwargs: The keyword arguments passed into the function. + + Returns: + An extended list of arguments to pass into the function. + + Raises: + Exception: An exception may be raised if the function cannot be called with + these arguments. + """ + arg_names = function_signature.arg_names + arg_defaults = function_signature.arg_defaults + arg_is_positionals = function_signature.arg_is_positionals + keyword_names = function_signature.keyword_names + function_name = function_signature.function_name + + args = list(args) + + for keyword_name in kwargs: + if keyword_name not in keyword_names: + raise Exception("The name '{}' is not a valid keyword argument for the " + "function '{}'.".format(keyword_name, function_name)) + + # Fill in the remaining arguments. + zipped_info = list(zip(arg_names, arg_defaults, + arg_is_positionals))[len(args):] + for keyword_name, default_value, is_positional in zipped_info: + if keyword_name in kwargs: + args.append(kwargs[keyword_name]) + else: + if default_value != funcsigs._empty: + args.append(default_value) + else: + # This means that there is a missing argument. Unless this is the last + # argument and it is a *args argument in which case it can be omitted. + if not is_positional: + raise Exception("No value was provided for the argument '{}' for " + "the function '{}'.".format(keyword_name, + function_name)) + + too_many_arguments = (len(args) > len(arg_names) and + (len(arg_is_positionals) == 0 or + not arg_is_positionals[-1])) + if too_many_arguments: + raise Exception("Too many arguments were passed to the function '{}'" + .format(function_name)) + return args diff --git a/python/ray/worker.py b/python/ray/worker.py index a56507818..33f3dc9a0 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -6,7 +6,6 @@ import atexit import collections import colorama import copy -import funcsigs import hashlib import inspect import json @@ -24,6 +23,7 @@ import ray.experimental.state as state import ray.pickling as pickling import ray.serialization as serialization import ray.services as services +import ray.signature as signature import ray.numbuf import ray.local_scheduler import ray.plasma @@ -2067,13 +2067,8 @@ def remote(*args, **kwargs): """This gets run immediately when a worker calls a remote function.""" check_connected() check_main_thread() - args = list(args) - # Fill in the remaining arguments. - args.extend([kwargs[keyword] if keyword in kwargs else default - for keyword, default in keyword_defaults[len(args):]]) - if any([arg is funcsigs._empty for arg in args]): - raise Exception("Not enough arguments were provided to {}." - .format(func_name)) + args = signature.extend_args(function_signature, args, kwargs) + if _mode() == PYTHON_MODE: # In PYTHON_MODE, remote calls simply execute the function. We copy # the arguments to prevent the function call from mutating them and @@ -2111,15 +2106,8 @@ def remote(*args, **kwargs): else: func_invoker.func_doc = func.func_doc - sig_params = [(k, v) for k, v - in funcsigs.signature(func).parameters.items()] - keyword_defaults = [(k, v.default) for k, v in sig_params] - has_vararg_param = any([v.kind == v.VAR_POSITIONAL - for k, v in sig_params]) - func_invoker.has_vararg_param = has_vararg_param - has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params]) - check_signature_supported(has_kwargs_param, has_vararg_param, - keyword_defaults, func_name) + signature.check_signature_supported(func) + function_signature = signature.extract_signature(func) # Everything ready - export the function if worker.mode in [SCRIPT_MODE, SILENT_MODE]: @@ -2162,37 +2150,6 @@ def remote(*args, **kwargs): return make_remote_decorator(num_return_vals, num_cpus, num_gpus) -def check_signature_supported(has_kwargs_param, has_vararg_param, - keyword_defaults, name): - """Check if we support the signature of this function. - - We currently do not allow remote functions to have **kwargs. We also do not - support keyword arguments in conjunction with a *args argument. - - Args: - has_kwards_param (bool): True if the function being checked has a **kwargs - argument. - has_vararg_param (bool): True if the function being checked has a *args - argument. - keyword_defaults (List): A list of the default values for the arguments to - the function being checked. - name (str): The name of the function to check. - - Raises: - Exception: An exception is raised if the signature is not supported. - """ - # Check if the user specified kwargs. - if has_kwargs_param: - raise ("Function {} has a **kwargs argument, which is currently not " - "supported.".format(name)) - # Check if the user specified a variable number of arguments and any keyword - # arguments. - if has_vararg_param and any([d != funcsigs._empty - for _, d in keyword_defaults]): - raise ("Function {} has a *args argument as well as a keyword argument, " - "which is currently not supported.".format(name)) - - def get_arguments_for_execution(function_name, serialized_args, worker=global_worker): """Retrieve the arguments for the remote function. diff --git a/test/actor_test.py b/test/actor_test.py index 9f3f95302..901ff496a 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -32,17 +32,27 @@ class ActorAPI(unittest.TestCase): actor = Actor(1, 2, "c") self.assertEqual(ray.get(actor.get_values(2, 3, "d")), (3, 5, "cd")) + actor = Actor(1, arg2="c") + self.assertEqual(ray.get(actor.get_values(0, arg2="d")), (1, 3, "cd")) + self.assertEqual(ray.get(actor.get_values(0, arg2="d", arg1=0)), + (1, 1, "cd")) + + actor = Actor(1, arg2="c", arg1=2) + self.assertEqual(ray.get(actor.get_values(0, arg2="d")), (1, 4, "cd")) + self.assertEqual(ray.get(actor.get_values(0, arg2="d", arg1=0)), + (1, 2, "cd")) + # Make sure we get an exception if the constructor is called incorrectly. - actor = Actor() with self.assertRaises(Exception): - ray.get(ray.get(actor.get_values(1))) + actor = Actor() + with self.assertRaises(Exception): - ray.get(ray.get(actor.get_values())) + actor = Actor(0, 1, 2, arg3=3) # Make sure we get an exception if the method is called incorrectly. actor = Actor(1) with self.assertRaises(Exception): - ray.get(ray.get(actor.get_values())) + ray.get(actor.get_values()) ray.worker.cleanup() @@ -73,6 +83,21 @@ class ActorAPI(unittest.TestCase): self.assertEqual(ray.get(actor.get_values(2, 3, 1, 2, 3, 4)), (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4))) + @ray.actor + class Actor(object): + def __init__(self, *args): + self.args = args + + def get_values(self, *args): + return self.args, args + + a = Actor() + self.assertEqual(ray.get(a.get_values()), ((), ())) + a = Actor(1) + self.assertEqual(ray.get(a.get_values(2)), ((1,), (2,))) + a = Actor(1, 2) + self.assertEqual(ray.get(a.get_values(3, 4)), ((1, 2), (3, 4))) + ray.worker.cleanup() def testNoArgs(self): diff --git a/test/failure_test.py b/test/failure_test.py index b9d6312c6..f213d26ab 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -297,55 +297,25 @@ class ActorTest(unittest.TestCase): pass # Make sure that we get errors if we call the constructor incorrectly. - # TODO(rkn): These errors should instead be thrown when the method is - # called. # Create an actor with too few arguments. - a = Actor() - wait_for_errors(b"task", 1) - self.assertEqual(len(ray.error_info()), 1) - if sys.version_info >= (3, 0): - self.assertIn("missing 1 required", - ray.error_info()[0][b"message"].decode("ascii")) - else: - self.assertIn("takes exactly 2 arguments", - ray.error_info()[0][b"message"].decode("ascii")) + with self.assertRaises(Exception): + a = Actor() # Create an actor with too many arguments. - a = Actor(1, 2) - wait_for_errors(b"task", 2) - self.assertEqual(len(ray.error_info()), 2) - if sys.version_info >= (3, 0): - self.assertIn("but 3 were given", - ray.error_info()[1][b"message"].decode("ascii")) - else: - self.assertIn("takes exactly 2 arguments", - ray.error_info()[1][b"message"].decode("ascii")) + with self.assertRaises(Exception): + a = Actor(1, 2) # Create an actor the correct number of arguments. a = Actor(1) # Call a method with too few arguments. - a.get_val() - wait_for_errors(b"task", 3) - self.assertEqual(len(ray.error_info()), 3) - if sys.version_info >= (3, 0): - self.assertIn("missing 1 required", - ray.error_info()[2][b"message"].decode("ascii")) - else: - self.assertIn("takes exactly 2 arguments", - ray.error_info()[2][b"message"].decode("ascii")) + with self.assertRaises(Exception): + a.get_val() # Call a method with too many arguments. - a.get_val(1, 2) - wait_for_errors(b"task", 4) - self.assertEqual(len(ray.error_info()), 4) - if sys.version_info >= (3, 0): - self.assertIn("but 3 were given", - ray.error_info()[3][b"message"].decode("ascii")) - else: - self.assertIn("takes exactly 2 arguments", - ray.error_info()[3][b"message"].decode("ascii")) + with self.assertRaises(Exception): + a.get_val(1, 2) # Call a method that doesn't exist. with self.assertRaises(AttributeError): a.nonexistent_method() diff --git a/test/runtest.py b/test/runtest.py index 3e4ccd45b..888500386 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -312,6 +312,35 @@ class APITest(unittest.TestCase): x = test_functions.keyword_fct3.remote(0, 1) self.assertEqual(ray.get(x), "0 1 hello world") + # Check that we cannot pass invalid keyword arguments to functions. + @ray.remote + def f1(): + return + + @ray.remote + def f2(x, y=0, z=0): + return + + # Make sure we get an exception if too many arguments are passed in. + with self.assertRaises(Exception): + f1.remote(3) + + with self.assertRaises(Exception): + f1.remote(x=3) + + with self.assertRaises(Exception): + f2.remote(0, w=0) + + # Make sure we get an exception if too many arguments are passed in. + with self.assertRaises(Exception): + f2.remote(1, 2, 3, 4) + + @ray.remote + def f3(x): + return x + + self.assertEqual(ray.get(f3.remote(4)), 4) + ray.worker.cleanup() def testVariableNumberOfArgs(self): @@ -326,6 +355,25 @@ class APITest(unittest.TestCase): self.assertTrue(test_functions.kwargs_exception_thrown) self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown) + @ray.remote + def f1(*args): + return args + + @ray.remote + def f2(x, y, *args): + return x, y, args + + self.assertEqual(ray.get(f1.remote()), ()) + self.assertEqual(ray.get(f1.remote(1)), (1,)) + self.assertEqual(ray.get(f1.remote(1, 2, 3)), (1, 2, 3)) + with self.assertRaises(Exception): + f2.remote() + with self.assertRaises(Exception): + f2.remote(1) + self.assertEqual(ray.get(f2.remote(1, 2)), (1, 2, ())) + self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,))) + self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4))) + ray.worker.cleanup() def testNoArgs(self): @@ -1389,13 +1437,18 @@ class GlobalStateAPI(unittest.TestCase): x_id = ray.put(1) result_id = f.remote(1, "hi", x_id) - # Wait for one additional task for the driver. - wait_for_num_tasks(1 + 1) - task_table = ray.global_state.task_table() - self.assertEqual(len(task_table), 1 + 1) - task_id_set = set(task_table.keys()) - task_id_set.remove(driver_task_id) - task_id = list(task_id_set)[0] + # Wait for one additional task to complete. + start_time = time.time() + while time.time() - start_time < 10: + wait_for_num_tasks(1 + 1) + task_table = ray.global_state.task_table() + self.assertEqual(len(task_table), 1 + 1) + task_id_set = set(task_table.keys()) + task_id_set.remove(driver_task_id) + task_id = list(task_id_set)[0] + if task_table[task_id]["State"] == "DONE": + break + time.sleep(0.1) self.assertEqual(task_table[task_id]["TaskSpec"]["ActorID"], ID_SIZE * "ff") self.assertEqual(task_table[task_id]["TaskSpec"]["Args"], [1, "hi", x_id])