mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
make remote functions pickleable (#281)
This commit is contained in:
committed by
Philipp Moritz
parent
6c96a05ab4
commit
3bf46ce5ac
+64
-31
@@ -533,6 +533,26 @@ def main_loop(worker=global_worker):
|
||||
del task
|
||||
del function
|
||||
|
||||
def _submit_task(func_name, args, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
|
||||
We use this wrapper so that in the remote decorator, we can call _submit_task
|
||||
instead of worker.submit_task. The difference is that when we attempt to
|
||||
serialize remote functions, we don't attempt to serialize the worker object,
|
||||
which cannot be serialized.
|
||||
"""
|
||||
return worker.submit_task(func_name, args)
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
"""This is a wrapper around worker.mode.
|
||||
|
||||
We use this wrapper so that in the remote decorator, we can call _mode()
|
||||
instead of worker.mode. The difference is that when we attempt to serialize
|
||||
remote functions, we don't attempt to serialize the worker object, which
|
||||
cannot be serialized.
|
||||
"""
|
||||
return worker.mode
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
||||
@@ -554,51 +574,60 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if worker.mode == ray.PYTHON_MODE:
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if _mode() == ray.PYTHON_MODE:
|
||||
# In ray.PYTHON_MODE, remote calls simply execute the function. We copy
|
||||
# the arguments to prevent the function call from mutating them and to
|
||||
# match the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
check_arguments(func_call, args) # throws an exception if args are invalid
|
||||
objrefs = worker.submit_task(func_call.func_name, args)
|
||||
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
|
||||
objrefs = _submit_task(func_name, args)
|
||||
if len(objrefs) == 1:
|
||||
return objrefs[0]
|
||||
elif len(objrefs) > 1:
|
||||
return objrefs
|
||||
func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_call.executor = func_executor
|
||||
func_call.arg_types = arg_types
|
||||
func_call.return_types = return_types
|
||||
func_call.is_remote = True
|
||||
func_call.sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
func_call.keyword_defaults = [(k, v.default) for k, v in func_call.sig_params]
|
||||
func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
|
||||
func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
|
||||
check_signature_supported(func_call)
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_call.func_name = func_name
|
||||
func_call.func_doc = func.func_doc
|
||||
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
keyword_defaults = [(k, v.default) for k, v in sig_params]
|
||||
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
|
||||
func_call.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
if to_export is not None:
|
||||
ray.lib.export_function(worker.handle, to_export)
|
||||
return func_call
|
||||
return remote_decorator
|
||||
|
||||
def check_signature_supported(function):
|
||||
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
|
||||
"""Check if we support the signature of this function.
|
||||
|
||||
We currently do not allow remote functions to have **kwargs. We also do not
|
||||
support keyword argumens in conjunction with a *args argument.
|
||||
|
||||
Args:
|
||||
function (Callable): The function to check.
|
||||
has_kwards_param (bool): True if the function being checked has a **kwargs
|
||||
argument.
|
||||
has_vararg_param (bool): True if the function being checked has a *args
|
||||
argument.
|
||||
keyword_defaults (List): A list of the default values for the arguments to
|
||||
the function being checked.
|
||||
name (str): The name of the function to check.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the signature is not supported.
|
||||
"""
|
||||
# check if the user specified kwargs
|
||||
if function.has_kwargs_param:
|
||||
raise "Function {} has a **kwargs argument, which is currently not supported.".format(function.__name__)
|
||||
if has_kwargs_param:
|
||||
raise "Function {} has a **kwargs argument, which is currently not supported.".format(name)
|
||||
# check if the user specified a variable number of arguments and any keyword arguments
|
||||
if function.has_vararg_param and any([d != funcsigs._empty for _, d in function.keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(function.__name__)
|
||||
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
|
||||
|
||||
|
||||
def check_return_values(function, result):
|
||||
@@ -636,14 +665,14 @@ def check_return_values(function, result):
|
||||
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)):
|
||||
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i]))
|
||||
|
||||
def typecheck_arg(arg, expected_type, i, function):
|
||||
def typecheck_arg(arg, expected_type, i, name):
|
||||
"""Check that an argument has the expected type.
|
||||
|
||||
Args:
|
||||
arg: An argument to function.
|
||||
expected_type (type): The expected type of arg.
|
||||
i (int): The position of the argument to the function.
|
||||
function (Callable): The remote function whose argument is being checked.
|
||||
name (str): The name of the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if arg does not have the expected type.
|
||||
@@ -656,32 +685,36 @@ def typecheck_arg(arg, expected_type, i, function):
|
||||
# TODO(mehrdadn): Should long really be convertible to int?
|
||||
pass
|
||||
else:
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, name, type(arg), expected_type))
|
||||
|
||||
def check_arguments(function, args):
|
||||
def check_arguments(arg_types, has_vararg_param, name, args):
|
||||
"""Check that the arguments to the remote function have the right types.
|
||||
|
||||
This is called by the worker that calls the remote function (not the worker
|
||||
that executes the remote function).
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function whose arguments are being checked.
|
||||
args (List): The arguments to the function
|
||||
arg_types (List[type]): A list of the types of the arguments to the function
|
||||
being checked.
|
||||
has_vararg_param (bool): True if the function being checked has a *args
|
||||
argument.
|
||||
name (str): The name of the function.
|
||||
args (List): The arguments to the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised the args do not all have the right types.
|
||||
"""
|
||||
# check the number of args
|
||||
if len(args) != len(function.arg_types) and not function.has_vararg_param:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
elif len(args) < len(function.arg_types) - 1 and function.has_vararg_param:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
if len(args) != len(arg_types) and not has_vararg_param:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args)))
|
||||
elif len(args) < len(arg_types) - 1 and has_vararg_param:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i <= len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif function.has_vararg_param:
|
||||
expected_type = function.arg_types[-1]
|
||||
if i <= len(arg_types) - 1:
|
||||
expected_type = arg_types[i]
|
||||
elif has_vararg_param:
|
||||
expected_type = arg_types[-1]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
@@ -689,7 +722,7 @@ def check_arguments(function, args):
|
||||
# TODO(rkn): When we have type information in the ObjRef, do type checking here.
|
||||
pass
|
||||
else:
|
||||
typecheck_arg(arg, expected_type, i, function)
|
||||
typecheck_arg(arg, expected_type, i, name)
|
||||
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
Reference in New Issue
Block a user