mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:33:06 +08:00
Reintroduce passing arguments by value to remote functions. (#425)
* Reintroduce passing arguments by value to remote functions. * Check size of arguments passed by value. * Fix computation graph visualization.
This commit is contained in:
committed by
Philipp Moritz
parent
0191d42751
commit
ba56b08474
@@ -28,7 +28,7 @@ def graph_to_graphviz(computation_graph):
|
||||
creator_operationid = op.creator_operationid if op.creator_operationid != 2 ** 64 - 1 else "-root"
|
||||
dot.edge("op" + str(creator_operationid), "op" + str(i), style="dotted", constraint="false")
|
||||
for arg in op.task.arg:
|
||||
if not arg.HasField("obj"):
|
||||
dot.node(str(arg.id))
|
||||
dot.edge(str(arg.id), "op" + str(i))
|
||||
if len(arg.serialized_arg) == 0:
|
||||
dot.node(str(arg.objectid))
|
||||
dot.edge(str(arg.objectid), "op" + str(i))
|
||||
return dot
|
||||
|
||||
@@ -3,6 +3,87 @@ import pickling
|
||||
import libraylib as raylib
|
||||
import libnumbuf
|
||||
|
||||
def is_argument_serializable(value):
|
||||
"""Checks if value is a composition of primitive types.
|
||||
|
||||
This will return True if the argument is one of the following:
|
||||
- An int
|
||||
- A float
|
||||
- A bool
|
||||
- None
|
||||
- A list of length at most 100 whose elements are serializable
|
||||
- A tuple of length at most 100 whose elements are serializable
|
||||
- A dict of length at most 100 whose keys and values are serializable
|
||||
- A string of length at most 100.
|
||||
- A unicode string of length at most 100.
|
||||
|
||||
Args:
|
||||
value: A Python object.
|
||||
|
||||
Returns:
|
||||
True if the object can be serialized as a composition of primitive types and
|
||||
False otherwise.
|
||||
"""
|
||||
t = type(value)
|
||||
if t is int or t is float or t is long or t is bool or value is None:
|
||||
return True
|
||||
if t is list:
|
||||
if len(value) <= 100:
|
||||
for element in value:
|
||||
if not is_argument_serializable(element):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if t is tuple:
|
||||
if len(value) <= 100:
|
||||
for element in value:
|
||||
if not is_argument_serializable(element):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if t is dict:
|
||||
if len(value) <= 100:
|
||||
for k, v in value.iteritems():
|
||||
if not is_argument_serializable(k) or not is_argument_serializable(v):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if t is str:
|
||||
return len(value) <= 100
|
||||
if t is unicode:
|
||||
return len(value) <= 100
|
||||
return False
|
||||
|
||||
def serialize_argument_if_possible(value):
|
||||
"""This method serializes arguments that are passed by value.
|
||||
|
||||
The result will be deserialized by deserialize_argument.
|
||||
|
||||
Returns:
|
||||
None if value cannot be efficiently serialized or is too big, and otherwise
|
||||
this returns the serialized value as a string.
|
||||
"""
|
||||
if not is_argument_serializable(value):
|
||||
# The argument is not obviously serializable using __repr__, so we will not
|
||||
# serialize it.
|
||||
return None
|
||||
serialized_value = value.__repr__()
|
||||
if len(serialized_value) > 1000:
|
||||
# The argument is too big, so we will not pass it by value.
|
||||
return None
|
||||
# Return the serialized argument.
|
||||
return serialized_value
|
||||
|
||||
def deserialize_argument(serialized_value):
|
||||
"""This method deserializes arguments that are passed by value.
|
||||
|
||||
The argument will have been serialized by serialize_argument.
|
||||
"""
|
||||
return eval(serialized_value)
|
||||
|
||||
def check_serializable(cls):
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
|
||||
|
||||
@@ -413,8 +413,20 @@ class Worker(object):
|
||||
"""
|
||||
# Convert all of the argumens to object IDs. It is a little strange that we
|
||||
# are calling put, which is external to this class.
|
||||
args = [arg if isinstance(arg, raylib.ObjectID) else put(arg, worker=self) for arg in args]
|
||||
task_capsule = raylib.serialize_task(self.handle, func_name, args)
|
||||
serialized_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, raylib.ObjectID):
|
||||
next_arg = arg
|
||||
else:
|
||||
serialized_arg = serialization.serialize_argument_if_possible(arg)
|
||||
if serialized_arg is not None:
|
||||
# Serialize the argument and pass it by value.
|
||||
next_arg = serialized_arg
|
||||
else:
|
||||
# Put the objet in the object store under the hood.
|
||||
next_arg = put(arg)
|
||||
serialized_args.append(next_arg)
|
||||
task_capsule = raylib.serialize_task(self.handle, func_name, serialized_args)
|
||||
objectids = raylib.submit_task(self.handle, task_capsule)
|
||||
return objectids
|
||||
|
||||
@@ -935,9 +947,9 @@ def main_loop(worker=global_worker):
|
||||
After the task executes, the worker resets any reusable variables that were
|
||||
accessed by the task.
|
||||
"""
|
||||
function_name, args, return_objectids = task
|
||||
function_name, serialized_args, return_objectids = task
|
||||
try:
|
||||
arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
|
||||
arguments = get_arguments_for_execution(worker.functions[function_name], serialized_args, worker) # get args from objstore
|
||||
outputs = worker.functions[function_name].executor(arguments) # execute the function
|
||||
if len(return_objectids) == 1:
|
||||
outputs = (outputs,)
|
||||
@@ -1197,7 +1209,7 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
|
||||
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
|
||||
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
def get_arguments_for_execution(function, serialized_args, worker=global_worker):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
This retrieves the values for the arguments to the remote function that were
|
||||
@@ -1207,7 +1219,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
Args:
|
||||
function (Callable): The remote function whose arguments are being
|
||||
retrieved.
|
||||
args (List): The arguments to the function.
|
||||
serialized_args (List): The arguments to the function. These are either
|
||||
strings representing serialized objects passed by value or they are
|
||||
ObjectIDs.
|
||||
|
||||
Returns:
|
||||
The retrieved arguments in addition to the arguments that were passed by
|
||||
@@ -1218,7 +1232,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
the arguments failed.
|
||||
"""
|
||||
arguments = []
|
||||
for (i, arg) in enumerate(args):
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, raylib.ObjectID):
|
||||
# get the object from the local object store
|
||||
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
|
||||
@@ -1230,7 +1244,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
_logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__))
|
||||
else:
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
argument = serialization.deserialize_argument(arg)
|
||||
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
Reference in New Issue
Block a user