diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index cb0e34dff..97b02339b 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -533,6 +533,26 @@ def main_loop(worker=global_worker): del task del function +def _submit_task(func_name, args, worker=global_worker): + """This is a wrapper around worker.submit_task. + + We use this wrapper so that in the remote decorator, we can call _submit_task + instead of worker.submit_task. The difference is that when we attempt to + serialize remote functions, we don't attempt to serialize the worker object, + which cannot be serialized. + """ + return worker.submit_task(func_name, args) + +def _mode(worker=global_worker): + """This is a wrapper around worker.mode. + + We use this wrapper so that in the remote decorator, we can call _mode() + instead of worker.mode. The difference is that when we attempt to serialize + remote functions, we don't attempt to serialize the worker object, which + cannot be serialized. + """ + return worker.mode + def remote(arg_types, return_types, worker=global_worker): """This decorator is used to create remote functions. @@ -554,51 +574,60 @@ def remote(arg_types, return_types, worker=global_worker): def func_call(*args, **kwargs): """This gets run immediately when a worker calls a remote function.""" args = list(args) - args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments - if worker.mode == ray.PYTHON_MODE: + args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments + if _mode() == ray.PYTHON_MODE: # In ray.PYTHON_MODE, remote calls simply execute the function. We copy # the arguments to prevent the function call from mutating them and to # match the usual behavior of immutable remote objects. return func(*copy.deepcopy(args)) - check_arguments(func_call, args) # throws an exception if args are invalid - objrefs = worker.submit_task(func_call.func_name, args) + check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid + objrefs = _submit_task(func_name, args) if len(objrefs) == 1: return objrefs[0] elif len(objrefs) > 1: return objrefs - func_call.func_name = "{}.{}".format(func.__module__, func.__name__) func_call.executor = func_executor func_call.arg_types = arg_types func_call.return_types = return_types func_call.is_remote = True - func_call.sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()] - func_call.keyword_defaults = [(k, v.default) for k, v in func_call.sig_params] - func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params]) - func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params]) - check_signature_supported(func_call) + func_name = "{}.{}".format(func.__module__, func.__name__) + func_call.func_name = func_name + func_call.func_doc = func.func_doc + sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()] + keyword_defaults = [(k, v.default) for k, v in sig_params] + has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params]) + func_call.has_vararg_param = has_vararg_param + has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params]) + check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name) if to_export is not None: ray.lib.export_function(worker.handle, to_export) return func_call return remote_decorator -def check_signature_supported(function): +def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name): """Check if we support the signature of this function. We currently do not allow remote functions to have **kwargs. We also do not support keyword argumens in conjunction with a *args argument. Args: - function (Callable): The function to check. + has_kwards_param (bool): True if the function being checked has a **kwargs + argument. + has_vararg_param (bool): True if the function being checked has a *args + argument. + keyword_defaults (List): A list of the default values for the arguments to + the function being checked. + name (str): The name of the function to check. Raises: Exception: An exception is raised if the signature is not supported. """ # check if the user specified kwargs - if function.has_kwargs_param: - raise "Function {} has a **kwargs argument, which is currently not supported.".format(function.__name__) + if has_kwargs_param: + raise "Function {} has a **kwargs argument, which is currently not supported.".format(name) # check if the user specified a variable number of arguments and any keyword arguments - if function.has_vararg_param and any([d != funcsigs._empty for _, d in function.keyword_defaults]): - raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(function.__name__) + if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]): + raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name) def check_return_values(function, result): @@ -636,14 +665,14 @@ def check_return_values(function, result): if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)): raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i])) -def typecheck_arg(arg, expected_type, i, function): +def typecheck_arg(arg, expected_type, i, name): """Check that an argument has the expected type. Args: arg: An argument to function. expected_type (type): The expected type of arg. i (int): The position of the argument to the function. - function (Callable): The remote function whose argument is being checked. + name (str): The name of the function. Raises: Exception: An exception is raised if arg does not have the expected type. @@ -656,32 +685,36 @@ def typecheck_arg(arg, expected_type, i, function): # TODO(mehrdadn): Should long really be convertible to int? pass else: - raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type)) + raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, name, type(arg), expected_type)) -def check_arguments(function, args): +def check_arguments(arg_types, has_vararg_param, name, args): """Check that the arguments to the remote function have the right types. This is called by the worker that calls the remote function (not the worker that executes the remote function). Args: - function (Callable): The remote function whose arguments are being checked. - args (List): The arguments to the function + arg_types (List[type]): A list of the types of the arguments to the function + being checked. + has_vararg_param (bool): True if the function being checked has a *args + argument. + name (str): The name of the function. + args (List): The arguments to the function. Raises: Exception: An exception is raised the args do not all have the right types. """ # check the number of args - if len(args) != len(function.arg_types) and not function.has_vararg_param: - raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args))) - elif len(args) < len(function.arg_types) - 1 and function.has_vararg_param: - raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args))) + if len(args) != len(arg_types) and not has_vararg_param: + raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args))) + elif len(args) < len(arg_types) - 1 and has_vararg_param: + raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args))) for (i, arg) in enumerate(args): - if i <= len(function.arg_types) - 1: - expected_type = function.arg_types[i] - elif function.has_vararg_param: - expected_type = function.arg_types[-1] + if i <= len(arg_types) - 1: + expected_type = arg_types[i] + elif has_vararg_param: + expected_type = arg_types[-1] else: assert False, "This code should be unreachable." @@ -689,7 +722,7 @@ def check_arguments(function, args): # TODO(rkn): When we have type information in the ObjRef, do type checking here. pass else: - typecheck_arg(arg, expected_type, i, function) + typecheck_arg(arg, expected_type, i, name) def get_arguments_for_execution(function, args, worker=global_worker): """Retrieve the arguments for the remote function. diff --git a/test/runtest.py b/test/runtest.py index 4ad01a845..ba7087e02 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -274,7 +274,7 @@ class APITest(unittest.TestCase): def testDefiningRemoteFunctions(self): worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py") - ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.SCRIPT_MODE) + ray.services.start_ray_local(num_workers=2, worker_path=worker_path, driver_mode=ray.SCRIPT_MODE) # Test that we can define a remote function in the shell. @ray.remote([int], [int]) @@ -305,6 +305,20 @@ class APITest(unittest.TestCase): return time.time() ray.get(j()) + # Test that we can define remote functions that call other remote functions. + @ray.remote([int], [int]) + def k(x): + return x + 1 + @ray.remote([int], [int]) + def l(x): + return k(x) + @ray.remote([int], [int]) + def m(x): + return ray.get(l(x)) + self.assertEqual(ray.get(k(1)), 2) + self.assertEqual(ray.get(l(1)), 2) + self.assertEqual(ray.get(m(1)), 2) + ray.services.cleanup() class TaskStatusTest(unittest.TestCase): @@ -448,7 +462,7 @@ class ReferenceCountingTest(unittest.TestCase): class PythonModeTest(unittest.TestCase): - def testObjRefAliasing(self): + def testPythonMode(self): ray.services.start_ray_local(driver_mode=ray.PYTHON_MODE) xref = test_functions.test_alias_h()