make remote functions pickleable (#281)

This commit is contained in:
Robert Nishihara
2016-07-19 16:05:45 -07:00
committed by Philipp Moritz
parent 6c96a05ab4
commit 3bf46ce5ac
2 changed files with 80 additions and 33 deletions
+64 -31
View File
@@ -533,6 +533,26 @@ def main_loop(worker=global_worker):
del task
del function
def _submit_task(func_name, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.
We use this wrapper so that in the remote decorator, we can call _submit_task
instead of worker.submit_task. The difference is that when we attempt to
serialize remote functions, we don't attempt to serialize the worker object,
which cannot be serialized.
"""
return worker.submit_task(func_name, args)
def _mode(worker=global_worker):
"""This is a wrapper around worker.mode.
We use this wrapper so that in the remote decorator, we can call _mode()
instead of worker.mode. The difference is that when we attempt to serialize
remote functions, we don't attempt to serialize the worker object, which
cannot be serialized.
"""
return worker.mode
def remote(arg_types, return_types, worker=global_worker):
"""This decorator is used to create remote functions.
@@ -554,51 +574,60 @@ def remote(arg_types, return_types, worker=global_worker):
def func_call(*args, **kwargs):
"""This gets run immediately when a worker calls a remote function."""
args = list(args)
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments
if worker.mode == ray.PYTHON_MODE:
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
if _mode() == ray.PYTHON_MODE:
# In ray.PYTHON_MODE, remote calls simply execute the function. We copy
# the arguments to prevent the function call from mutating them and to
# match the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
check_arguments(func_call, args) # throws an exception if args are invalid
objrefs = worker.submit_task(func_call.func_name, args)
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
objrefs = _submit_task(func_name, args)
if len(objrefs) == 1:
return objrefs[0]
elif len(objrefs) > 1:
return objrefs
func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
func_call.executor = func_executor
func_call.arg_types = arg_types
func_call.return_types = return_types
func_call.is_remote = True
func_call.sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
func_call.keyword_defaults = [(k, v.default) for k, v in func_call.sig_params]
func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
check_signature_supported(func_call)
func_name = "{}.{}".format(func.__module__, func.__name__)
func_call.func_name = func_name
func_call.func_doc = func.func_doc
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
keyword_defaults = [(k, v.default) for k, v in sig_params]
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
func_call.has_vararg_param = has_vararg_param
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
if to_export is not None:
ray.lib.export_function(worker.handle, to_export)
return func_call
return remote_decorator
def check_signature_supported(function):
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
"""Check if we support the signature of this function.
We currently do not allow remote functions to have **kwargs. We also do not
support keyword argumens in conjunction with a *args argument.
Args:
function (Callable): The function to check.
has_kwards_param (bool): True if the function being checked has a **kwargs
argument.
has_vararg_param (bool): True if the function being checked has a *args
argument.
keyword_defaults (List): A list of the default values for the arguments to
the function being checked.
name (str): The name of the function to check.
Raises:
Exception: An exception is raised if the signature is not supported.
"""
# check if the user specified kwargs
if function.has_kwargs_param:
raise "Function {} has a **kwargs argument, which is currently not supported.".format(function.__name__)
if has_kwargs_param:
raise "Function {} has a **kwargs argument, which is currently not supported.".format(name)
# check if the user specified a variable number of arguments and any keyword arguments
if function.has_vararg_param and any([d != funcsigs._empty for _, d in function.keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(function.__name__)
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
def check_return_values(function, result):
@@ -636,14 +665,14 @@ def check_return_values(function, result):
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)):
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i]))
def typecheck_arg(arg, expected_type, i, function):
def typecheck_arg(arg, expected_type, i, name):
"""Check that an argument has the expected type.
Args:
arg: An argument to function.
expected_type (type): The expected type of arg.
i (int): The position of the argument to the function.
function (Callable): The remote function whose argument is being checked.
name (str): The name of the function.
Raises:
Exception: An exception is raised if arg does not have the expected type.
@@ -656,32 +685,36 @@ def typecheck_arg(arg, expected_type, i, function):
# TODO(mehrdadn): Should long really be convertible to int?
pass
else:
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, name, type(arg), expected_type))
def check_arguments(function, args):
def check_arguments(arg_types, has_vararg_param, name, args):
"""Check that the arguments to the remote function have the right types.
This is called by the worker that calls the remote function (not the worker
that executes the remote function).
Args:
function (Callable): The remote function whose arguments are being checked.
args (List): The arguments to the function
arg_types (List[type]): A list of the types of the arguments to the function
being checked.
has_vararg_param (bool): True if the function being checked has a *args
argument.
name (str): The name of the function.
args (List): The arguments to the function.
Raises:
Exception: An exception is raised the args do not all have the right types.
"""
# check the number of args
if len(args) != len(function.arg_types) and not function.has_vararg_param:
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
elif len(args) < len(function.arg_types) - 1 and function.has_vararg_param:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
if len(args) != len(arg_types) and not has_vararg_param:
raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args)))
elif len(args) < len(arg_types) - 1 and has_vararg_param:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args)))
for (i, arg) in enumerate(args):
if i <= len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif function.has_vararg_param:
expected_type = function.arg_types[-1]
if i <= len(arg_types) - 1:
expected_type = arg_types[i]
elif has_vararg_param:
expected_type = arg_types[-1]
else:
assert False, "This code should be unreachable."
@@ -689,7 +722,7 @@ def check_arguments(function, args):
# TODO(rkn): When we have type information in the ObjRef, do type checking here.
pass
else:
typecheck_arg(arg, expected_type, i, function)
typecheck_arg(arg, expected_type, i, name)
def get_arguments_for_execution(function, args, worker=global_worker):
"""Retrieve the arguments for the remote function.