diff --git a/lib/python/ray/graph.py b/lib/python/ray/graph.py index 981f4ad99..e11a7824e 100644 --- a/lib/python/ray/graph.py +++ b/lib/python/ray/graph.py @@ -28,7 +28,7 @@ def graph_to_graphviz(computation_graph): creator_operationid = op.creator_operationid if op.creator_operationid != 2 ** 64 - 1 else "-root" dot.edge("op" + str(creator_operationid), "op" + str(i), style="dotted", constraint="false") for arg in op.task.arg: - if not arg.HasField("obj"): - dot.node(str(arg.id)) - dot.edge(str(arg.id), "op" + str(i)) + if len(arg.serialized_arg) == 0: + dot.node(str(arg.objectid)) + dot.edge(str(arg.objectid), "op" + str(i)) return dot diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index 65bae2358..bff61c3f0 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -3,6 +3,87 @@ import pickling import libraylib as raylib import libnumbuf +def is_argument_serializable(value): + """Checks if value is a composition of primitive types. + + This will return True if the argument is one of the following: + - An int + - A float + - A bool + - None + - A list of length at most 100 whose elements are serializable + - A tuple of length at most 100 whose elements are serializable + - A dict of length at most 100 whose keys and values are serializable + - A string of length at most 100. + - A unicode string of length at most 100. + + Args: + value: A Python object. + + Returns: + True if the object can be serialized as a composition of primitive types and + False otherwise. 
+ """ + t = type(value) + if t is int or t is float or t is long or t is bool or value is None: + return True + if t is list: + if len(value) <= 100: + for element in value: + if not is_argument_serializable(element): + return False + return True + else: + return False + if t is tuple: + if len(value) <= 100: + for element in value: + if not is_argument_serializable(element): + return False + return True + else: + return False + if t is dict: + if len(value) <= 100: + for k, v in value.iteritems(): + if not is_argument_serializable(k) or not is_argument_serializable(v): + return False + return True + else: + return False + if t is str: + return len(value) <= 100 + if t is unicode: + return len(value) <= 100 + return False + +def serialize_argument_if_possible(value): + """This method serializes arguments that are passed by value. + + The result will be deserialized by deserialize_argument. + + Returns: + None if value cannot be efficiently serialized or is too big, and otherwise + this returns the serialized value as a string. + """ + if not is_argument_serializable(value): + # The argument is not obviously serializable using __repr__, so we will not + # serialize it. + return None + serialized_value = value.__repr__() + if len(serialized_value) > 1000: + # The argument is too big, so we will not pass it by value. + return None + # Return the serialized argument. + return serialized_value + +def deserialize_argument(serialized_value): + """This method deserializes arguments that are passed by value. + + The argument will have been serialized by serialize_argument_if_possible. + """ + return eval(serialized_value) + def check_serializable(cls): """Throws an exception if Ray cannot serialize this class efficiently. diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 6ceaaaeee..f48e7882e 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -413,8 +413,20 @@ class Worker(object): """ # Convert all of the argumens to object IDs. 
It is a little strange that we # are calling put, which is external to this class. - args = [arg if isinstance(arg, raylib.ObjectID) else put(arg, worker=self) for arg in args] - task_capsule = raylib.serialize_task(self.handle, func_name, args) + serialized_args = [] + for arg in args: + if isinstance(arg, raylib.ObjectID): + next_arg = arg + else: + serialized_arg = serialization.serialize_argument_if_possible(arg) + if serialized_arg is not None: + # Serialize the argument and pass it by value. + next_arg = serialized_arg + else: + # Put the object in the object store under the hood. + next_arg = put(arg, worker=self) + serialized_args.append(next_arg) + task_capsule = raylib.serialize_task(self.handle, func_name, serialized_args) objectids = raylib.submit_task(self.handle, task_capsule) return objectids @@ -935,9 +947,9 @@ def main_loop(worker=global_worker): After the task executes, the worker resets any reusable variables that were accessed by the task. """ - function_name, args, return_objectids = task + function_name, serialized_args, return_objectids = task try: - arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore + arguments = get_arguments_for_execution(worker.functions[function_name], serialized_args, worker) # get args from objstore outputs = worker.functions[function_name].executor(arguments) # execute the function if len(return_objectids) == 1: outputs = (outputs,) @@ -1197,7 +1209,7 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]): raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name) -def get_arguments_for_execution(function, args, worker=global_worker): +def get_arguments_for_execution(function, serialized_args, worker=global_worker): """Retrieve the arguments for the remote function. 
This retrieves the values for the arguments to the remote function that were @@ -1207,7 +1219,9 @@ def get_arguments_for_execution(function, args, worker=global_worker): Args: function (Callable): The remote function whose arguments are being retrieved. - args (List): The arguments to the function. + serialized_args (List): The arguments to the function. These are either + strings representing serialized objects passed by value or they are + ObjectIDs. Returns: The retrieved arguments in addition to the arguments that were passed by @@ -1218,7 +1232,7 @@ def get_arguments_for_execution(function, args, worker=global_worker): the arguments failed. """ arguments = [] - for (i, arg) in enumerate(args): + for (i, arg) in enumerate(serialized_args): if isinstance(arg, raylib.ObjectID): # get the object from the local object store _logger().info("Getting argument {} for function {}.".format(i, function.__name__)) @@ -1230,7 +1244,7 @@ def get_arguments_for_execution(function, args, worker=global_worker): _logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__)) else: # pass the argument by value - argument = arg + argument = serialization.deserialize_argument(arg) arguments.append(argument) return arguments diff --git a/protos/graph.proto b/protos/graph.proto index a1044899b..00ff96266 100644 --- a/protos/graph.proto +++ b/protos/graph.proto @@ -1,8 +1,13 @@ syntax = "proto3"; +message Arg { + uint64 objectid = 1; // The objectid for the argument. + string serialized_arg = 2; // A serialized representation of an argument passed by value. +} + message Task { string name = 1; // Name of the function call. Must not be empty. - repeated uint64 arg = 2; // List of object IDs of the arguments to the function. + repeated Arg arg = 2; // List of arguments to the function; each is an object ID or a serialized value. 
repeated uint64 result = 3; // Object IDs for result } diff --git a/src/raylib.cc b/src/raylib.cc index 96c53436e..58f94680d 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -578,9 +578,21 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) { if (PyList_Check(arguments)) { for (size_t i = 0, size = PyList_Size(arguments); i < size; ++i) { PyObject* element = PyList_GetItem(arguments, i); - ObjectID objectid = ((PyObjectID*) element)->id; - task->add_arg(objectid); - objectids.push_back(objectid); + if (PyObject_IsInstance(element, (PyObject*)&PyObjectIDType)) { + // Handle the case where the argument to the task is an ObjectID. + ObjectID objectid = ((PyObjectID*) element)->id; + task->add_arg()->set_objectid(objectid); + objectids.push_back(objectid); + } else if (PyString_CheckExact(element)) { + // Handle the case where the argument to the task is being passed by + // value and we receive an argument serialized as a string here. + char* buffer; + Py_ssize_t length; + PyString_AsStringAndSize(element, &buffer, &length); + task->add_arg()->set_serialized_arg(std::string(buffer, length)); + } else { + RAY_CHECK(false, "This code should be unreachable."); + } } } else { PyErr_SetString(RayError, "serialize_task: second argument needs to be a list"); @@ -595,17 +607,6 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) { std::string output; task->SerializeToString(&output); int task_size = output.length(); - if (task_size > 1024) { - // Large objects should not be passed to tasks by value. Instead, they - // should be placed in the object store and passed by object - // reference. - RAY_LOG(RAY_INFO, "Warning: attempting to serialize a task with size " << task_size << "."); - PyErr_SetString(RaySizeError, "serialize_task: This task is too large (greater than 1024 bytes). " - "Please do not pass large objects by value to remote functions. 
" - "Instead, put large objects in the object store and pass them by " - "object reference to the remote function."); - return NULL; - } return PyCapsule_New(static_cast(task), "task", &TaskCapsule_Destructor); } @@ -615,8 +616,13 @@ static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) { int argsize = task.arg_size(); PyObject* arglist = PyList_New(argsize); for (int i = 0; i < argsize; ++i) { - PyList_SetItem(arglist, i, make_pyobjectid(worker_capsule, task.arg(i))); - objectids.push_back(task.arg(i)); + if (task.arg(i).serialized_arg().empty()) { + PyList_SetItem(arglist, i, make_pyobjectid(worker_capsule, task.arg(i).objectid())); + objectids.push_back(task.arg(i).objectid()); + } else { + PyObject* serialized_arg = PyString_FromStringAndSize(task.arg(i).serialized_arg().data(), task.arg(i).serialized_arg().size()); + PyList_SetItem(arglist, i, serialized_arg); + } } Worker* worker; PyObjectToWorker(worker_capsule, &worker); diff --git a/src/scheduler.cc b/src/scheduler.cc index 1ce5d0a01..d6f090cfd 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -690,13 +690,15 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c AckReply reply; RAY_LOG(RAY_INFO, "starting to send arguments"); for (size_t i = 0; i < task.arg_size(); ++i) { - ObjectID objectid = task.arg(i); - ObjectID canonical_objectid = get_canonical_objectid(objectid); - // Notify the relevant objstore about potential aliasing when it's ready - GET(alias_notification_queue_)->push_back(std::make_pair(objstoreid, std::make_pair(objectid, canonical_objectid))); - attempt_notify_alias(objstoreid, objectid, canonical_objectid); - RAY_LOG(RAY_DEBUG, "task contains object ref " << canonical_objectid); - deliver_object_async_if_necessary(canonical_objectid, pick_objstore(canonical_objectid), objstoreid); + if (task.arg(i).serialized_arg().empty()) { + ObjectID objectid = task.arg(i).objectid(); + ObjectID canonical_objectid = 
get_canonical_objectid(objectid); + // Notify the relevant objstore about potential aliasing when it's ready + GET(alias_notification_queue_)->push_back(std::make_pair(objstoreid, std::make_pair(objectid, canonical_objectid))); + attempt_notify_alias(objstoreid, objectid, canonical_objectid); + RAY_LOG(RAY_DEBUG, "task contains object ref " << canonical_objectid); + deliver_object_async_if_necessary(canonical_objectid, pick_objstore(canonical_objectid), objstoreid); + } } { auto workers = GET(workers_); @@ -709,13 +711,15 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c bool SchedulerService::can_run(const Task& task) { auto objtable = GET(objtable_); for (int i = 0; i < task.arg_size(); ++i) { - ObjectID objectid = task.arg(i); - if (!has_canonical_objectid(objectid)) { - return false; - } - ObjectID canonical_objectid = get_canonical_objectid(objectid); - if (canonical_objectid >= objtable->size() || (*objtable)[canonical_objectid].size() == 0) { - return false; + if (task.arg(i).serialized_arg().empty()) { + ObjectID objectid = task.arg(i).objectid(); + if (!has_canonical_objectid(objectid)) { + return false; + } + ObjectID canonical_objectid = get_canonical_objectid(objectid); + if (canonical_objectid >= objtable->size() || (*objtable)[canonical_objectid].size() == 0) { + return false; + } } } return true; @@ -952,14 +956,16 @@ void SchedulerService::schedule_tasks_location_aware() { // determine how many objects would need to be shipped size_t num_shipped_objects = 0; for (int j = 0; j < task.arg_size(); ++j) { - ObjectID objectid = task.arg(j); - RAY_CHECK(has_canonical_objectid(objectid), "no canonical object ref found even though task is ready; that should not be possible!"); - ObjectID canonical_objectid = get_canonical_objectid(objectid); - { - // check if the object is already in the local object store - auto objtable = GET(objtable_); - if (!std::binary_search((*objtable)[canonical_objectid].begin(), 
(*objtable)[canonical_objectid].end(), objstoreid)) { - num_shipped_objects += 1; + if (task.arg(j).serialized_arg().empty()) { + ObjectID objectid = task.arg(j).objectid(); + RAY_CHECK(has_canonical_objectid(objectid), "no canonical object ref found even though task is ready; that should not be possible!"); + ObjectID canonical_objectid = get_canonical_objectid(objectid); + { + // check if the object is already in the local object store + auto objtable = GET(objtable_); + if (!std::binary_search((*objtable)[canonical_objectid].begin(), (*objtable)[canonical_objectid].end(), objstoreid)) { + num_shipped_objects += 1; + } } } } diff --git a/test/runtest.py b/test/runtest.py index 801e33746..1a40b6abd 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -201,6 +201,29 @@ class WorkerTest(unittest.TestCase): class APITest(unittest.TestCase): + def testPassingArgumentsByValue(self): + ray.init(start_ray_local=True, num_workers=0) + + # The types that can be passed by value are defined by + # is_argument_serializable in serialization.py. 
+ class Foo(object): + pass + CAN_PASS_BY_VALUE = [1, 1L, 1.0, True, False, None, [1L, 1.0, True, None], + ([1, 2, 3], {False: [1.0, u"hi", ()]}), 100 * ["a"]] + CANNOT_PASS_BY_VALUE = [int, np.int64(0), np.float64(0), Foo(), [Foo()], + (Foo(),), {0: Foo()}, [[[int]]], 101 * [1], + np.zeros(10)] + + for obj in CAN_PASS_BY_VALUE: + self.assertTrue(ray.serialization.is_argument_serializable(obj)) + self.assertEqual(obj, ray.serialization.deserialize_argument(ray.serialization.serialize_argument_if_possible(obj))) + + for obj in CANNOT_PASS_BY_VALUE: + self.assertFalse(ray.serialization.is_argument_serializable(obj)) + self.assertEqual(None, ray.serialization.serialize_argument_if_possible(obj)) + + ray.worker.cleanup() + def testRegisterClass(self): ray.init(start_ray_local=True, num_workers=0) @@ -408,6 +431,24 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testComputationGraph(self): + ray.init(start_ray_local=True, num_workers=1) + + @ray.remote + def f(x): + return x + @ray.remote + def g(x, y): + return x, y + a = f.remote(1) + b = f.remote(1) + c = g.remote(a, b) + c = g.remote(a, 1) + # Make sure that we can produce a computation_graph visualization. + ray.visualize_computation_graph(view=False) + + ray.worker.cleanup() + class ReferenceCountingTest(unittest.TestCase): def testDeallocation(self):