Reintroduce passing arguments by value to remote functions. (#425)

* Reintroduce passing arguments by value to remote functions.

* Check size of arguments passed by value.

* Fix computation graph visualization.
This commit is contained in:
Robert Nishihara
2016-09-10 21:11:18 -07:00
committed by Philipp Moritz
parent 0191d42751
commit ba56b08474
7 changed files with 203 additions and 50 deletions
+3 -3
View File
@@ -28,7 +28,7 @@ def graph_to_graphviz(computation_graph):
creator_operationid = op.creator_operationid if op.creator_operationid != 2 ** 64 - 1 else "-root"
dot.edge("op" + str(creator_operationid), "op" + str(i), style="dotted", constraint="false")
for arg in op.task.arg:
if not arg.HasField("obj"):
dot.node(str(arg.id))
dot.edge(str(arg.id), "op" + str(i))
if len(arg.serialized_arg) == 0:
dot.node(str(arg.objectid))
dot.edge(str(arg.objectid), "op" + str(i))
return dot
+81
View File
@@ -3,6 +3,87 @@ import pickling
import libraylib as raylib
import libnumbuf
# Python 2 has distinct int/long and str/unicode types; on Python 3 the names
# ``long`` and ``unicode`` do not exist, so fall back to int and str alone.
try:
    _INTEGER_TYPES = (int, long)
    _STRING_TYPES = (str, unicode)
except NameError:
    _INTEGER_TYPES = (int,)
    _STRING_TYPES = (str,)

# Largest list, tuple, dict or string that is still considered serializable.
_MAX_ARGUMENT_LENGTH = 100
# Deepest container nesting that is still considered serializable. The exact
# value is a conservative choice; it only has to be finite.
_MAX_ARGUMENT_NESTING = 32


def is_argument_serializable(value, _depth=0):
    """Checks if value is a small composition of primitive types.

    This will return True if the argument is one of the following:
      - An int (or a long on Python 2)
      - A finite float
      - A bool
      - None
      - A list of length at most 100 whose elements are serializable
      - A tuple of length at most 100 whose elements are serializable
      - A dict of length at most 100 whose keys and values are serializable
      - A string of length at most 100.
      - A unicode string of length at most 100.

    Types are compared exactly, so instances of subclasses of the types above
    are rejected.

    Args:
      value: A Python object.
      _depth: Internal recursion counter holding the number of containers that
        enclose value. Callers should leave it at its default.

    Returns:
      True if the object can be serialized as a composition of primitive types
      and False otherwise.
    """
    t = type(value)
    if value is None or t is bool or t in _INTEGER_TYPES:
        return True
    if t is float:
        # repr() of inf and nan is a bare name ("inf", "nan") rather than a
        # literal, so such a value cannot be rebuilt from its repr string.
        return float("-inf") < value < float("inf")
    if t in _STRING_TYPES:
        return len(value) <= _MAX_ARGUMENT_LENGTH
    if t is not list and t is not tuple and t is not dict:
        return False
    if len(value) > _MAX_ARGUMENT_LENGTH:
        return False
    # Bounding the nesting makes self-referential containers terminate instead
    # of recursing until the interpreter's recursion limit is hit.
    if _depth >= _MAX_ARGUMENT_NESTING:
        return False
    if t is dict:
        return all(
            is_argument_serializable(k, _depth + 1)
            and is_argument_serializable(v, _depth + 1)
            for k, v in value.items()
        )
    return all(is_argument_serializable(e, _depth + 1) for e in value)
def serialize_argument_if_possible(value):
    """Try to turn an argument into a string so it can be passed by value.

    The string produced here is meant to be turned back into an object by
    deserialize_argument.

    Args:
      value: The Python object to serialize.

    Returns:
      The repr string of value, or None when is_argument_serializable rejects
      value or when the repr string is longer than 1000 characters.
    """
    if is_argument_serializable(value):
        text = repr(value)
        # Only representations of at most 1000 characters are passed by value.
        if len(text) <= 1000:
            return text
    # Either the value is not obviously serializable through its repr, or the
    # result is too big.
    return None
def deserialize_argument(serialized_value):
    """This method deserializes arguments that are passed by value.

    The argument will have been serialized by serialize_argument_if_possible,
    i.e. it is the repr string of a composition of ints, floats, bools, None,
    strings, lists, tuples and dicts.

    Args:
      serialized_value: The string form of the argument.

    Returns:
      The Python object described by serialized_value.

    Raises:
      ValueError: If serialized_value contains anything other than Python
        literals (for example a function call or a bare name).
      SyntaxError: If serialized_value is not valid Python syntax.
    """
    import ast
    # NOTE(review): this used to be a plain eval(), which executes arbitrary
    # expressions found in the string. literal_eval only accepts literals,
    # which is all that serialize_argument_if_possible emits, so a malformed
    # or malicious string is rejected instead of being executed.
    return ast.literal_eval(serialized_value)
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
+22 -8
View File
@@ -413,8 +413,20 @@ class Worker(object):
"""
# Convert all of the arguments to object IDs. It is a little strange that we
# are calling put, which is external to this class.
args = [arg if isinstance(arg, raylib.ObjectID) else put(arg, worker=self) for arg in args]
task_capsule = raylib.serialize_task(self.handle, func_name, args)
serialized_args = []
for arg in args:
if isinstance(arg, raylib.ObjectID):
next_arg = arg
else:
serialized_arg = serialization.serialize_argument_if_possible(arg)
if serialized_arg is not None:
# Serialize the argument and pass it by value.
next_arg = serialized_arg
else:
# Put the object in the object store under the hood.
next_arg = put(arg)
serialized_args.append(next_arg)
task_capsule = raylib.serialize_task(self.handle, func_name, serialized_args)
objectids = raylib.submit_task(self.handle, task_capsule)
return objectids
@@ -935,9 +947,9 @@ def main_loop(worker=global_worker):
After the task executes, the worker resets any reusable variables that were
accessed by the task.
"""
function_name, args, return_objectids = task
function_name, serialized_args, return_objectids = task
try:
arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
arguments = get_arguments_for_execution(worker.functions[function_name], serialized_args, worker) # get args from objstore
outputs = worker.functions[function_name].executor(arguments) # execute the function
if len(return_objectids) == 1:
outputs = (outputs,)
@@ -1197,7 +1209,7 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
def get_arguments_for_execution(function, args, worker=global_worker):
def get_arguments_for_execution(function, serialized_args, worker=global_worker):
"""Retrieve the arguments for the remote function.
This retrieves the values for the arguments to the remote function that were
@@ -1207,7 +1219,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
Args:
function (Callable): The remote function whose arguments are being
retrieved.
args (List): The arguments to the function.
serialized_args (List): The arguments to the function. These are either
strings representing serialized objects passed by value or they are
ObjectIDs.
Returns:
The retrieved arguments in addition to the arguments that were passed by
@@ -1218,7 +1232,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
the arguments failed.
"""
arguments = []
for (i, arg) in enumerate(args):
for (i, arg) in enumerate(serialized_args):
if isinstance(arg, raylib.ObjectID):
# get the object from the local object store
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
@@ -1230,7 +1244,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
_logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__))
else:
# pass the argument by value
argument = arg
argument = serialization.deserialize_argument(arg)
arguments.append(argument)
return arguments