From 11a89146841396aaa40adcfa01bff671076166be Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 6 Sep 2016 13:28:24 -0700 Subject: [PATCH] Allow users to serialize custom classes. (#393) * Allow serialization of custom classes. * Add documentation and test cases, also fix pickle case. * Don't allow old-style classes. --- doc/install-on-macosx.md | 3 +- doc/install-on-ubuntu.md | 3 +- doc/remote-functions.md | 29 +-- docker/deploy/Dockerfile | 3 +- docker/devel/Dockerfile | 3 +- docker/test-base/Dockerfile | 3 +- install-dependencies.sh | 6 +- lib/python/ray/__init__.py | 2 +- lib/python/ray/array/distributed/core.py | 8 - lib/python/ray/serialization.py | 166 +++++++++++++---- lib/python/ray/worker.py | 220 ++++++++++++----------- protos/graph.proto | 4 +- protos/types.proto | 5 - src/raylib.cc | 81 +++------ src/scheduler.cc | 53 +++--- src/worker.cc | 50 ++---- src/worker.h | 4 +- test/array_test.py | 17 +- test/failure_test.py | 22 ++- test/microbenchmarks.py | 1 - test/runtest.py | 215 +++++++++++++--------- thirdparty/numbuf | 2 +- 22 files changed, 497 insertions(+), 403 deletions(-) diff --git a/doc/install-on-macosx.md b/doc/install-on-macosx.md index aa0e8b398..4cedea64b 100644 --- a/doc/install-on-macosx.md +++ b/doc/install-on-macosx.md @@ -19,7 +19,8 @@ brew update brew install git cmake automake autoconf libtool boost graphviz sudo easy_install pip sudo pip install ipython --user -sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six +sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz --ignore-installed six +sudo pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. ``` ## Build diff --git a/doc/install-on-ubuntu.md b/doc/install-on-ubuntu.md index 19e831fee..dc94e7821 100644 --- a/doc/install-on-ubuntu.md +++ b/doc/install-on-ubuntu.md @@ -15,7 +15,8 @@ First install the dependencies. We currently do not support Python 3. ``` sudo apt-get update sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz -sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle +sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz +sudo pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. ``` ## Build diff --git a/doc/remote-functions.md b/doc/remote-functions.md index 29590ad20..c0f409f2e 100644 --- a/doc/remote-functions.md +++ b/doc/remote-functions.md @@ -52,28 +52,11 @@ types in the object store. **The serializable types are:** 3. Object IDs 4. Lists, tuples, and dictionaries of other serializable types, but excluding custom classes (for example, `[1, 1.0, "hello"]`, `{True: "hi", 1: ["hi"]}`) -5. Custom classes where the user has provided `serialize` and `desererialize` -methods +5. Custom classes in many cases. You must explicitly register the class. -If you wish to define a custom class and to allow it to be serialized in the -object store, you must implement `serialize` and `deserialize` methods which -convert the object to and from primitive data types. A simple example is shown -below. + ```python + class Foo(object): + pass -```python -BLOCK_SIZE = 1000 - -class ExampleClass(object): - def __init__(self, field1, field2): - # This example assumes that field1 and field2 are serializable types. - self.field1 = field1 - self.field2 = field2 - - @staticmethod - def deserialize(primitives): - (field1, field2) = primitives - return ExampleClass(field1, field2) - - def serialize(self): - return (self.field1, self.field2) -``` + ray.register_class(Foo) + ``` diff --git a/docker/deploy/Dockerfile b/docker/deploy/Dockerfile index 5dd48fcbd..f5e5be190 100644 --- a/docker/deploy/Dockerfile +++ b/docker/deploy/Dockerfile @@ -6,7 +6,8 @@ RUN apt-get update RUN apt-get -y install apt-utils RUN apt-get -y install sudo RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz -RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle +RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz +RUN pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user RUN adduser ray-user sudo RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers diff --git a/docker/devel/Dockerfile b/docker/devel/Dockerfile index d2f35461a..c99055c82 100644 --- a/docker/devel/Dockerfile +++ b/docker/devel/Dockerfile @@ -6,7 +6,8 @@ RUN apt-get update RUN apt-get -y install apt-utils RUN apt-get -y install sudo RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz -RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle +RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz +RUN pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user --uid 500 RUN adduser ray-user sudo RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers diff --git a/docker/test-base/Dockerfile b/docker/test-base/Dockerfile index 66deac4b7..d011f2438 100644 --- a/docker/test-base/Dockerfile +++ b/docker/test-base/Dockerfile @@ -7,7 +7,8 @@ RUN apt-get update RUN apt-get -y install apt-utils RUN apt-get -y install sudo RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz -RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle +RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz +RUN pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user RUN adduser ray-user sudo RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers diff --git a/install-dependencies.sh b/install-dependencies.sh index be91454b7..815f930a2 100755 --- a/install-dependencies.sh +++ b/install-dependencies.sh @@ -31,11 +31,13 @@ if [[ $platform == "linux" ]]; then # These commands must be kept in sync with the installation instructions. sudo apt-get update sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz - sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle + sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz + sudo pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. elif [[ $platform == "macosx" ]]; then # These commands must be kept in sync with the installation instructions. brew install git cmake automake autoconf libtool boost graphviz sudo easy_install pip sudo pip install ipython --user - sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six + sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz --ignore-installed six + sudo pip install git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples. fi diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index b2a843eea..d00b54ef1 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"): import config import serialization -from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local +from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local from worker import Reusable, reusables from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE from libraylib import ObjectID diff --git a/lib/python/ray/array/distributed/core.py b/lib/python/ray/array/distributed/core.py index 4e6b41ef6..a17e9f03e 100644 --- a/lib/python/ray/array/distributed/core.py +++ b/lib/python/ray/array/distributed/core.py @@ -16,14 +16,6 @@ class DistArray(object): if self.num_blocks != list(self.objectids.shape): raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape))) - @staticmethod - def deserialize(primitives): - (shape, objectids) = primitives - return DistArray(shape, objectids) - - def serialize(self): - return (self.shape, self.objectids) - @staticmethod def compute_block_lower(index, shape): if len(index) != len(shape): diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index 92f3053ac..65bae2358 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -1,42 +1,138 @@ -import importlib import numpy as np - +import pickling import libraylib as raylib +import libnumbuf -def to_primitive(obj): - if hasattr(obj, "serialize"): - primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize()) - else: - primitive_obj = ("primitive", obj) - return primitive_obj +def check_serializable(cls): + """Throws an exception if Ray cannot serialize this class efficiently. -def from_primitive(primitive_obj): - if primitive_obj[0] == "primitive": - obj = primitive_obj[1] + Args: + cls (type): The class to be serialized. + + Raises: + Exception: An exception is raised if Ray cannot serialize this class + efficiently. + """ + if is_named_tuple(cls): + # This case works. + return + if not hasattr(cls, "__new__"): + raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + try: + obj = cls.__new__(cls) + except: + raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + if not hasattr(obj, "__dict__"): + raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + if hasattr(obj, "__slots__"): + raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls)) + +# This field keeps track of a whitelisted set of classes that Ray will +# serialize. +whitelisted_classes = {} +classes_to_pickle = set() +custom_serializers = {} +custom_deserializers = {} + +def class_identifier(typ): + """Return a string that identifies this type.""" + return "{}.{}".format(typ.__module__, typ.__name__) + +def is_named_tuple(cls): + """Return True if cls is a namedtuple and False otherwise.""" + b = cls.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(cls, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) + +def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None): + """Add cls to the list of classes that we can serialize. + + Args: + cls (type): The class that we can serialize. + pickle (bool): True if the serialization should be done with pickle. False + if it should be done efficiently with Ray. + custom_serializer: This argument is optional, but can be provided to + serialize objects of the class in a particular way. + custom_deserializer: This argument is optional, but can be provided to + deserialize objects of the class in a particular way. + """ + class_id = class_identifier(cls) + whitelisted_classes[class_id] = cls + if pickle: + classes_to_pickle.add(class_id) + if custom_serializer is not None: + custom_serializers[class_id] = custom_serializer + custom_deserializers[class_id] = custom_deserializer + +# Here we define a custom serializer and deserializer for handling numpy +# arrays that contain objects. +def array_custom_serializer(obj): + return obj.tolist(), obj.dtype.str +def array_custom_deserializer(serialized_obj): + return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1])) +add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer) + +def serialize(obj): + """This is the callback that will be used by numbuf. + + If numbuf does not know how to serialize an object, it will call this method. + + Args: + obj (object): A Python object. + + Returns: + A dictionary that has the key "_pyttype_" to identify the class, and + contains all information needed to reconstruct the object. + """ + class_id = class_identifier(type(obj)) + if class_id not in whitelisted_classes: + raise Exception("Ray does not know how to serialize the object {}. To fix this, call 'ray.register_class' on the class of the object.".format(obj)) + if class_id in classes_to_pickle: + serialized_obj = {"data": pickling.dumps(obj)} + elif class_id in custom_serializers.keys(): + serialized_obj = {"data": custom_serializers[class_id](obj)} else: - # This code assumes that the type module.__dict__[type_name] knows how to deserialize itself - type_module, type_name = primitive_obj[0] - module = importlib.import_module(type_module) - obj = module.__dict__[type_name].deserialize(primitive_obj[1]) + if not hasattr(obj, "__dict__"): + raise Exception("We do not know how to serialize the object '{}'".format(obj)) + serialized_obj = obj.__dict__ + if is_named_tuple(type(obj)): + # Handle the namedtuple case. + serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__() + result = dict(serialized_obj, **{"_pytype_": class_id}) + return result + +def deserialize(serialized_obj): + """This is the callback that will be used by numbuf. + + If numbuf encounters a dictionary that contains the key "_pytype_" during + deserialization, it will ask this callback to deserialize the object. + + Args: + serialized_obj (object): A dictionary that contains the key "_pytype_". + + Returns: + A Python object. + """ + class_id = serialized_obj["_pytype_"] + cls = whitelisted_classes[class_id] + if class_id in classes_to_pickle: + obj = pickling.loads(serialized_obj["data"]) + elif class_id in custom_deserializers.keys(): + obj = custom_deserializers[class_id](serialized_obj["data"]) + else: + # In this case, serialized_obj should just be the __dict__ field. + if "_ray_getnewargs_" in serialized_obj: + obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"]) + serialized_obj.pop("_ray_getnewargs_") + else: + obj = cls.__new__(cls) + serialized_obj.pop("_pytype_") + obj.__dict__.update(serialized_obj) return obj -def is_arrow_serializable(value): - return isinstance(value, np.ndarray) and value.dtype.name in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64"] - -def serialize(worker_capsule, obj): - primitive_obj = to_primitive(obj) - obj_capsule, contained_objectids = raylib.serialize_object(worker_capsule, primitive_obj) # contained_objectids is a list of the objectids contained in obj - return obj_capsule, contained_objectids - -def deserialize(worker_capsule, capsule): - primitive_obj = raylib.deserialize_object(worker_capsule, capsule) - return from_primitive(primitive_obj) - -def serialize_task(worker_capsule, func_name, args): - primitive_args = [(arg if isinstance(arg, raylib.ObjectID) else to_primitive(arg)) for arg in args] - return raylib.serialize_task(worker_capsule, func_name, primitive_args) - -def deserialize_task(worker_capsule, task): - func_name, primitive_args, return_objectids = task - args = [(arg if isinstance(arg, raylib.ObjectID) else from_primitive(arg)) for arg in primitive_args] - return func_name, args, return_objectids +# Register the callbacks with numbuf. +libnumbuf.register_callbacks(serialize, deserialize) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 2ebdb1063..051d63c08 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -22,13 +22,31 @@ import services import libnumbuf import libraylib as raylib +contained_objectids = [] +def numbuf_serialize(value): + """This serializes a value and tracks the object IDs inside the value. + + We also define a custom ObjectID serializer which also closes over the global + variable contained_objectids, and whenever the custom serializer is called, it + adds the releevant ObjectID to the list contained_objectids. The list + contained_objectids should be reset between calls to numbuf_serialize. + + Args: + value: A Python object that will be serialized. + + Returns: + The serialized object. + """ + assert len(contained_objectids) == 0, "This should be unreachable." + return libnumbuf.serialize_list([value]) + class RayTaskError(Exception): """An object used internally to represent a task that threw an exception. If a task throws an exception during execution, a RayTaskError is stored in the object store for each of the task's outputs. When an object is retrieved from the object store, the Python method that retrieved it checks to see if - the object is a RayTaskError and if it is then an exceptionis thrown + the object is a RayTaskError and if it is then an exception is thrown propagating the error message. Currently, we either use the exception attribute or the traceback attribute @@ -50,32 +68,6 @@ class RayTaskError(Exception): self.exception = None self.traceback_str = traceback_str - @staticmethod - def deserialize(primitives): - """Create a RayTaskError from a primitive object.""" - function_name, exception, traceback_str = primitives - if exception[0] == "RayGetError": - exception = RayGetError.deserialize(exception[1]) - elif exception[0] == "RayGetArgumentError": - exception = RayGetArgumentError.deserialize(exception[1]) - elif exception[0] == "None": - exception = None - else: - assert False, "This code should be unreachable." - return RayTaskError(function_name, exception, traceback_str) - - def serialize(self): - """Turn a RayTaskError into a primitive object.""" - if isinstance(self.exception, RayGetError): - serialized_exception = ("RayGetError", self.exception.serialize()) - elif isinstance(self.exception, RayGetArgumentError): - serialized_exception = ("RayGetArgumentError", self.exception.serialize()) - elif self.exception is None: - serialized_exception = ("None",) - else: - assert False, "This code should be unreachable." - return (self.function_name, serialized_exception, self.traceback_str) - def __str__(self): """Format a RayTaskError as a string.""" if self.traceback_str is None: @@ -99,16 +91,6 @@ class RayGetError(Exception): self.objectid = objectid self.task_error = task_error - @staticmethod - def deserialize(primitives): - """Create a RayGetError from a primitive object.""" - objectid, task_error = primitives - return RayGetError(objectid, RayTaskError.deserialize(task_error)) - - def serialize(self): - """Turn a RayGetError into a primitive object.""" - return (self.objectid, self.task_error.serialize()) - def __str__(self): """Format a RayGetError as a string.""" return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) @@ -132,16 +114,6 @@ class RayGetArgumentError(Exception): self.objectid = objectid self.task_error = task_error - @staticmethod - def deserialize(primitives): - """Create a RayGetArgumentError from a primitive object.""" - function_name, argument_index, objectid, task_error = primitives - return RayGetArgumentError(function_name, argument_index, objectid, RayTaskError.deserialize(task_error)) - - def serialize(self): - """Turn a RayGetArgumentError into a primitive object.""" - return (self.function_name, self.argument_index, self.objectid, self.task_error.serialize()) - def __str__(self): """Format a RayGetArgumentError as a string.""" return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) @@ -366,30 +338,27 @@ class Worker(object): objectid (raylib.ObjectID): The object ID of the value to be put. value (serializable object): The value to put in the object store. """ - try: - # We put the value into a list here because in arrow the concept of - # "serializing a single object" does not exits. - schema, size, serialized = libnumbuf.serialize_list([value]) - # TODO(pcm): Right now, metadata is serialized twice, change that in the future - # in the following line, the "8" is for storing the metadata size, - # the len(schema) is for storing the metadata and the 4096 is for storing - # the metadata in the batch (see INITIAL_METADATA_SIZE in arrow) - size = size + 8 + len(schema) + 4096 - buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size) - # write the metadata length - np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema) - # metadata buffer - metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema)) - # write the metadata - metadata[:] = schema - data = np.frombuffer(buff, dtype="byte")[8 + len(schema):] - metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data)) - raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset) - except: - # At the moment, custom object and objects that contain object IDs take this path - # TODO(pcm): Make sure that these are the only objects getting serialized to protobuf - object_capsule, contained_objectids = serialization.serialize(self.handle, value) # contained_objectids is a list of the objectids contained in object_capsule - raylib.put_object(self.handle, objectid, object_capsule, contained_objectids) + # We put the value into a list here because in arrow the concept of + # "serializing a single object" does not exits. + schema, size, serialized = numbuf_serialize(value) + global contained_objectids + raylib.add_contained_objectids(self.handle, objectid, contained_objectids) + contained_objectids = [] + # TODO(pcm): Right now, metadata is serialized twice, change that in the future + # in the following line, the "8" is for storing the metadata size, + # the len(schema) is for storing the metadata and the 8192 is for storing + # the metadata in the batch (see INITIAL_METADATA_SIZE in arrow) + size = size + 8 + len(schema) + 4096 + buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size) + # write the metadata length + np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema) + # metadata buffer + metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema)) + # write the metadata + metadata[:] = schema + data = np.frombuffer(buff, dtype="byte")[8 + len(schema):] + metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data)) + raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset) def get_object(self, objectid): """Get the value in the local object store associated with objectid. @@ -400,32 +369,25 @@ class Worker(object): Args: objectid (raylib.ObjectID): The object ID of the value to retrieve. """ - if raylib.is_arrow(self.handle, objectid): - ## this is the new codepath - buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid) - metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0] - metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size) - data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:] - serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset) - # If there is currently no ObjectFixture for this ObjectID, then create a - # new one. The object_fixtures object is a WeakValueDictionary, so entries - # will be discarded when there are no strong references to their values. - # We create object_fixture outside of the assignment because if we created - # it inside the assignement it would immediately go out of scope. - object_fixture = None - if objectid.id not in object_fixtures: - object_fixture = ObjectFixture(objectid, segmentid, self.handle) - object_fixtures[objectid.id] = object_fixture - deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id]) - # Unwrap the object from the list (it was wrapped put_object) - assert len(deserialized) == 1 - result = deserialized[0] - ## this is the old codepath - # result, segmentid = raylib.get_arrow(self.handle, objectid) - else: - object_capsule, segmentid = raylib.get_object(self.handle, objectid) - result = serialization.deserialize(self.handle, object_capsule) - + assert raylib.is_arrow(self.handle, objectid), "All objects should be serialized using Arrow." + buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid) + metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0] + metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size) + data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:] + serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset) + # If there is currently no ObjectFixture for this ObjectID, then create a + # new one. The object_fixtures object is a WeakValueDictionary, so entries + # will be discarded when there are no strong references to their values. + # We create object_fixture outside of the assignment because if we created + # it inside the assignement it would immediately go out of scope. + object_fixture = None + if objectid.id not in object_fixtures: + object_fixture = ObjectFixture(objectid, segmentid, self.handle) + object_fixtures[objectid.id] = object_fixture + deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id]) + # Unwrap the object from the list (it was wrapped put_object) + assert len(deserialized) == 1 + result = deserialized[0] return result def alias_objectids(self, alias_objectid, target_objectid): @@ -445,7 +407,10 @@ class Worker(object): be object IDs or they can be values. If they are values, they must be serializable objecs. """ - task_capsule = serialization.serialize_task(self.handle, func_name, args) + # Convert all of the argumens to object IDs. It is a little strange that we + # are calling put, which is external to this class. + args = [arg if isinstance(arg, raylib.ObjectID) else put(arg, worker=self) for arg in args] + task_capsule = raylib.serialize_task(self.handle, func_name, args) objectids = raylib.submit_task(self.handle, task_capsule) return objectids @@ -461,10 +426,13 @@ class Worker(object): not take any arguments. If it returns anything, its return values will not be used. """ + if self.mode not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]: + raise Exception("run_function_on_all_workers can only be called on a driver.") # First run the function on the driver. function(self) # Then run the function on all of the workers. - raylib.run_function_on_all_workers(self.handle, pickling.dumps(function)) + if self.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + raylib.run_function_on_all_workers(self.handle, pickling.dumps(function)) global_worker = Worker() """Worker: The global Worker object for this worker process. @@ -568,6 +536,28 @@ def task_info(worker=global_worker): check_connected(worker) return raylib.task_info(worker.handle) +def initialize_numbuf(worker=global_worker): + """Initialize the serialization library. + + This defines a custom serializer for object IDs and also tells numbuf to + serialize several exception classes that we define for error handling. + """ + # Define a custom serializer and deserializer for handling Object IDs. + def objectid_custom_serializer(obj): + class_identifier = serialization.class_identifier(type(obj)) + contained_objectids.append(obj) + return raylib.serialize_objectid(worker.handle, obj) + def objectid_custom_deserializer(serialized_obj): + return raylib.deserialize_objectid(worker.handle, serialized_obj) + serialization.add_class_to_whitelist(raylib.ObjectID, pickle=False, custom_serializer=objectid_custom_serializer, custom_deserializer=objectid_custom_deserializer) + + if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + # These should only be called on the driver because register_class will + # export the class to all of the workers. + register_class(RayTaskError) + register_class(RayGetError) + register_class(RayGetArgumentError) + def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE): """Either connect to an existing Ray cluster or start one and connect to it. @@ -735,14 +725,16 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl # the same. script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) current_directory = os.path.abspath(os.path.curdir) - worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, script_directory)) - worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, current_directory)) + worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory)) + worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory)) # Export cached remote functions to the workers. for function_name, function_to_export in worker.cached_remote_functions: raylib.export_remote_function(worker.handle, function_name, function_to_export) # Export cached reusable variables to the workers. for name, reusable_variable in reusables._cached_reusables: _export_reusable_variable(name, reusable_variable) + # Initialize the serialization library. + initialize_numbuf() worker.cached_remote_functions = None reusables._cached_reusables = None @@ -757,6 +749,30 @@ def disconnect(worker=global_worker): worker.cached_remote_functions = [] reusables._cached_reusables = [] +def register_class(cls, pickle=False, worker=global_worker): + """Enable workers to serialize or deserialize objects of a particular class. + + This method runs the register_class function defined below on every worker, + which will enable libnumbuf to properly serialize and deserialize objects of + this class. + + Args: + cls (type): The class that libnumbuf should serialize. + pickle (bool): If False then objects of this class will be serialized by + turning their __dict__ fields into a dictionary. If True, then objects + of this class will be serialized using pickle. + + Raises: + Exception: An exception is raised if pickle=False and the class cannot be + efficiently serialized by Ray. + """ + # Raise an exception if cls cannot be serialized efficiently by Ray. + if not pickle: + serialization.check_serializable(cls) + def register_class_for_serialization(worker): + serialization.add_class_to_whitelist(cls, pickle=pickle) + worker.run_function_on_all_workers(register_class_for_serialization) + def get(objectid, worker=global_worker): """Get a remote object or a list of remote objects from the object store. @@ -915,7 +931,7 @@ def main_loop(worker=global_worker): After the task executes, the worker resets any reusable variables that were accessed by the task. """ - function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task) + function_name, args, return_objectids = task try: arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore outputs = worker.functions[function_name].executor(arguments) # execute the function diff --git a/protos/graph.proto b/protos/graph.proto index 1ca6e7530..a1044899b 100644 --- a/protos/graph.proto +++ b/protos/graph.proto @@ -1,10 +1,8 @@ syntax = "proto3"; -import "types.proto"; - message Task { string name = 1; // Name of the function call. Must not be empty. - repeated Value arg = 2; // List of arguments, can be either object IDs or protobuf descriptions of object passed by value + repeated uint64 arg = 2; // List of object IDs of the arguments to the function. repeated uint64 result = 3; // Object IDs for result } diff --git a/protos/types.proto b/protos/types.proto index 093c92a9d..ecb16a321 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -106,11 +106,6 @@ message Dict { repeated DictEntry elem = 1; } -message Value { - uint64 id = 1; // For pass by object ID - Obj obj = 2; // For pass by value -} - message Array { repeated uint64 shape = 1; sint64 dtype = 2; diff --git a/src/raylib.cc b/src/raylib.cc index c7cf1b217..2604175bf 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -37,6 +37,7 @@ static void PyObjectID_dealloc(PyObjectID *self) { PyObjectToWorker(self->worker_capsule, &worker); std::vector objectids; objectids.push_back(self->id); + RAY_LOG(RAY_REFCOUNT, "In PyObjectID_dealloc, calling decrement_reference_count for objectid " << self->id); worker->decrement_reference_count(objectids); Py_DECREF(self->worker_capsule); // The corresponding increment happens in PyObjectID_init. self->ob_type->tp_free((PyObject*) self); @@ -476,28 +477,24 @@ static PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vect } } -// This returns the serialized object and a list of the object references contained in that object. -static PyObject* serialize_object(PyObject* self, PyObject* args) { - Obj* obj = new Obj(); // TODO: to be freed in capsul destructor - PyObject* worker_capsule; - PyObject* pyval; - if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &pyval)) { - return NULL; - } - std::vector objectids; - if (serialize(worker_capsule, pyval, obj, objectids) != 0) { - return NULL; - } +// This converts an Python ObjectID to an Python integer. +static PyObject* serialize_objectid(PyObject* self, PyObject* args) { Worker* worker; - PyObjectToWorker(worker_capsule, &worker); - PyObject* contained_objectids = PyList_New(objectids.size()); - for (int i = 0; i < objectids.size(); ++i) { - PyList_SetItem(contained_objectids, i, make_pyobjectid(worker_capsule, objectids[i])); + ObjectID objectid; + if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjectID, &objectid)) { + return NULL; } - PyObject* t = PyTuple_New(2); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple. - PyTuple_SetItem(t, 0, PyCapsule_New(static_cast(obj), "obj", &ObjCapsule_Destructor)); - PyTuple_SetItem(t, 1, contained_objectids); - return t; + return PyInt_FromLong(objectid); +} + +// This converts a Python integer to a Python ObjectID. +static PyObject* deserialize_objectid(PyObject* self, PyObject* args) { + PyObject* worker_capsule; + int objectid; + if (!PyArg_ParseTuple(args, "Oi", &worker_capsule, &objectid)) { + return NULL; + } + return make_pyobjectid(worker_capsule, static_cast(objectid)); } static PyObject* allocate_buffer(PyObject* self, PyObject* args) { @@ -567,17 +564,6 @@ static PyObject* unmap_object(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* deserialize_object(PyObject* self, PyObject* args) { - PyObject* worker_capsule; - Obj* obj; - if (!PyArg_ParseTuple(args, "OO&", &worker_capsule, &PyObjectToObj, &obj)) { - return NULL; - } - std::vector objectids; // This is a vector of all the objectids that are serialized in this task, including objectids that are contained in Python objects that are passed by value. - return deserialize(worker_capsule, *obj, objectids); - // TODO(rkn): Should we do anything with objectids? -} - static PyObject* serialize_task(PyObject* self, PyObject* args) { PyObject* worker_capsule; Task* task = new Task(); // TODO: to be freed in capsule destructor @@ -592,14 +578,9 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) { if (PyList_Check(arguments)) { for (size_t i = 0, size = PyList_Size(arguments); i < size; ++i) { PyObject* element = PyList_GetItem(arguments, i); - if (PyObject_IsInstance(element, (PyObject*)&PyObjectIDType)) { - ObjectID objectid = ((PyObjectID*) element)->id; - task->add_arg()->set_id(objectid); - objectids.push_back(objectid); - } else { - Obj* arg = task->add_arg()->mutable_obj(); - serialize(worker_capsule, PyList_GetItem(arguments, i), arg, objectids); - } + ObjectID objectid = ((PyObjectID*) element)->id; + task->add_arg(objectid); + objectids.push_back(objectid); } } else { PyErr_SetString(RayError, "serialize_task: second argument needs to be a list"); @@ -634,13 +615,8 @@ static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) { int argsize = task.arg_size(); PyObject* arglist = PyList_New(argsize); for (int i = 0; i < argsize; ++i) { - const Value& val = task.arg(i); - if (!val.has_obj()) { - PyList_SetItem(arglist, i, make_pyobjectid(worker_capsule, val.id())); - objectids.push_back(val.id()); - } else { - PyList_SetItem(arglist, i, deserialize(worker_capsule, val.obj(), objectids)); - } + PyList_SetItem(arglist, i, make_pyobjectid(worker_capsule, task.arg(i))); + objectids.push_back(task.arg(i)); } Worker* worker; PyObjectToWorker(worker_capsule, &worker); @@ -869,12 +845,11 @@ static PyObject* get_objectid(PyObject* self, PyObject* args) { return make_pyobjectid(worker_capsule, objectid); } -static PyObject* put_object(PyObject* self, PyObject* args) { +static PyObject* add_contained_objectids(PyObject* self, PyObject* args) { Worker* worker; ObjectID objectid; - Obj* obj; PyObject* contained_objectids; - if (!PyArg_ParseTuple(args, "O&O&O&O", &PyObjectToWorker, &worker, &PyObjectToObjectID, &objectid, &PyObjectToObj, &obj, &contained_objectids)) { + if (!PyArg_ParseTuple(args, "O&O&O", &PyObjectToWorker, &worker, &PyObjectToObjectID, &objectid, &contained_objectids)) { return NULL; } RAY_CHECK(PyList_Check(contained_objectids), "The contained_objectids argument must be a list.") @@ -885,7 +860,7 @@ static PyObject* put_object(PyObject* self, PyObject* args) { PyObjectToObjectID(PyList_GetItem(contained_objectids, i), &contained_objectid); vec_contained_objectids.push_back(contained_objectid); } - worker->put_object(objectid, obj, vec_contained_objectids); + worker->add_contained_objectids(objectid, vec_contained_objectids); Py_RETURN_NONE; } @@ -1088,8 +1063,8 @@ static PyObject* kill_workers(PyObject* self, PyObject* args) { } static PyMethodDef RayLibMethods[] = { - { "serialize_object", serialize_object, METH_VARARGS, "serialize an object to protocol buffers" }, - { "deserialize_object", deserialize_object, METH_VARARGS, "deserialize an object from protocol buffers" }, + { "serialize_objectid", serialize_objectid, METH_VARARGS, "serialize an object id" }, + { "deserialize_objectid", deserialize_objectid, METH_VARARGS, "deserialize an object id" }, { "allocate_buffer", allocate_buffer, METH_VARARGS, "Allocates and returns buffer for objectid."}, { "finish_buffer", finish_buffer, METH_VARARGS, "Makes the buffer immutable and closes memory segment of objectid."}, { "get_buffer", get_buffer, METH_VARARGS, "Gets buffer for objectid"}, @@ -1101,7 +1076,7 @@ static PyMethodDef RayLibMethods[] = { { "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" }, { "register_remote_function", register_remote_function, METH_VARARGS, "register a function with the scheduler" }, { "notify_failure", notify_failure, METH_VARARGS, "notify the scheduler of a failure" }, - { "put_object", put_object, METH_VARARGS, "put a protocol buffer object (given as a capsule) on the local object store" }, + { "add_contained_objectids", add_contained_objectids, METH_VARARGS, "notify the scheduler about the object IDs contained in a remote object" }, { "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" }, { "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" }, { "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index 9e5a3cfd1..b25630abf 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -673,15 +673,13 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c AckReply reply; RAY_LOG(RAY_INFO, "starting to send arguments"); for (size_t i = 0; i < task.arg_size(); ++i) { - if (!task.arg(i).has_obj()) { - ObjectID objectid = task.arg(i).id(); - ObjectID canonical_objectid = get_canonical_objectid(objectid); - // Notify the relevant objstore about potential aliasing when it's ready - GET(alias_notification_queue_)->push_back(std::make_pair(objstoreid, std::make_pair(objectid, canonical_objectid))); - attempt_notify_alias(objstoreid, objectid, canonical_objectid); - RAY_LOG(RAY_DEBUG, "task contains object ref " << canonical_objectid); - deliver_object_async_if_necessary(canonical_objectid, pick_objstore(canonical_objectid), objstoreid); - } + ObjectID objectid = task.arg(i); + ObjectID canonical_objectid = get_canonical_objectid(objectid); + // Notify the relevant objstore about potential aliasing when it's ready + GET(alias_notification_queue_)->push_back(std::make_pair(objstoreid, std::make_pair(objectid, canonical_objectid))); + attempt_notify_alias(objstoreid, objectid, canonical_objectid); + RAY_LOG(RAY_DEBUG, "task contains object ref " << canonical_objectid); + deliver_object_async_if_necessary(canonical_objectid, pick_objstore(canonical_objectid), objstoreid); } { auto workers = GET(workers_); @@ -694,15 +692,13 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c bool SchedulerService::can_run(const Task& task) { auto objtable = GET(objtable_); for (int i = 0; i < task.arg_size(); ++i) { - if (!task.arg(i).has_obj()) { - ObjectID objectid = task.arg(i).id(); - if (!has_canonical_objectid(objectid)) { - return false; - } - ObjectID canonical_objectid = get_canonical_objectid(objectid); - if (canonical_objectid >= objtable->size() || (*objtable)[canonical_objectid].size() == 0) { - return false; - } + ObjectID objectid = task.arg(i); + if (!has_canonical_objectid(objectid)) { + return false; + } + ObjectID canonical_objectid = get_canonical_objectid(objectid); + if (canonical_objectid >= objtable->size() || (*objtable)[canonical_objectid].size() == 0) { + return false; } } return true; @@ -939,16 +935,14 @@ void SchedulerService::schedule_tasks_location_aware() { // determine how many objects would need to be shipped size_t num_shipped_objects = 0; for (int j = 0; j < task.arg_size(); ++j) { - if (!task.arg(j).has_obj()) { - ObjectID objectid = task.arg(j).id(); - RAY_CHECK(has_canonical_objectid(objectid), "no canonical object ref found even though task is ready; that should not be possible!"); - ObjectID canonical_objectid = get_canonical_objectid(objectid); - { - // check if the object is already in the local object store - auto objtable = GET(objtable_); - if (!std::binary_search((*objtable)[canonical_objectid].begin(), (*objtable)[canonical_objectid].end(), objstoreid)) { - num_shipped_objects += 1; - } + ObjectID objectid = task.arg(j); + RAY_CHECK(has_canonical_objectid(objectid), "no canonical object ref found even though task is ready; that should not be possible!"); + ObjectID canonical_objectid = get_canonical_objectid(objectid); + { + // check if the object is already in the local object store + auto objtable = GET(objtable_); + if (!std::binary_search((*objtable)[canonical_objectid].begin(), (*objtable)[canonical_objectid].end(), objstoreid)) { + num_shipped_objects += 1; } } } @@ -1059,6 +1053,9 @@ void SchedulerService::deallocate_object(ObjectID canonical_objectid, const MySy } locations.clear(); } + // Decrement the reference count for all of the object IDs contained in this + // object. The corresponding increments happen in add_contained_objectids in + // worker.cc. decrement_ref_count((*contained_objectids)[canonical_objectid], reference_counts, contained_objectids); } diff --git a/src/worker.cc b/src/worker.cc index 917d18ef3..3fb3b40b2 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -231,42 +231,24 @@ slice Worker::get_object(ObjectID objectid) { return slice; } -// TODO(pcm): More error handling -// contained_objectids is a vector of all the objectids contained in obj -void Worker::put_object(ObjectID objectid, const Obj* obj, std::vector &contained_objectids) { - RAY_CHECK(connected_, "Attempted to perform put_object but failed."); - std::string data; - obj->SerializeToString(&data); // TODO(pcm): get rid of this serialization - ObjRequest request; - request.workerid = workerid_; - request.type = ObjRequestType::ALLOC; - request.objectid = objectid; - request.size = data.size(); - RAY_CHECK(request_obj_queue_.send(&request), "error sending over IPC"); +void Worker::add_contained_objectids(ObjectID objectid, std::vector &contained_objectids) { + RAY_CHECK(connected_, "Attempted to perform add_contained_objectids but failed."); if (contained_objectids.size() > 0) { - RAY_LOG(RAY_REFCOUNT, "In put_object, calling increment_reference_count for contained objectids"); - increment_reference_count(contained_objectids); // Notify the scheduler that some object references are serialized in the objstore. + RAY_LOG(RAY_REFCOUNT, "In add_contained_objectids, calling increment_reference_count for contained objectids"); + // Notify the scheduler that some object references are serialized in the + // objstore. The corresponding decrement happens when the object + // corresponding to objectid is deallocated. + increment_reference_count(contained_objectids); + // Notify the scheduler about the objectids that we are serializing in the objstore. + AddContainedObjectIDsRequest contained_objectids_request; + contained_objectids_request.set_objectid(objectid); + for (int i = 0; i < contained_objectids.size(); ++i) { + contained_objectids_request.add_contained_objectid(contained_objectids[i]); // TODO(rkn): The naming here is bad + } + AckReply reply; + ClientContext context; + RAY_CHECK_GRPC(scheduler_stub_->AddContainedObjectIDs(&context, contained_objectids_request, &reply)); } - ObjHandle result; - RAY_CHECK(receive_obj_queue_.receive(&result), "error receiving over IPC"); - uint8_t* target = segmentpool_->get_address(result); - std::memcpy(target, data.data(), data.size()); - // We immediately unmap here; if the object is going to be accessed again, it will be mapped again; - // This is reqired because we do not have a mechanism to unmap the object later. - segmentpool_->unmap_segment(result.segmentid()); - request.type = ObjRequestType::WORKER_DONE; - request.metadata_offset = 0; - RAY_CHECK(request_obj_queue_.send(&request), "Failed to send request from the worker to the object store because the message queue was full."); - - // Notify the scheduler about the objectids that we are serializing in the objstore. - AddContainedObjectIDsRequest contained_objectids_request; - contained_objectids_request.set_objectid(objectid); - for (int i = 0; i < contained_objectids.size(); ++i) { - contained_objectids_request.add_contained_objectid(contained_objectids[i]); // TODO(rkn): The naming here is bad - } - AckReply reply; - ClientContext context; - RAY_CHECK_GRPC(scheduler_stub_->AddContainedObjectIDs(&context, contained_objectids_request, &reply)); } #define CHECK_ARROW_STATUS(s, msg) \ diff --git a/src/worker.h b/src/worker.h index b7a86141a..3e2366e5b 100644 --- a/src/worker.h +++ b/src/worker.h @@ -62,8 +62,8 @@ class Worker { ObjectID get_objectid(); // request an object to be delivered to the local object store void request_object(ObjectID objectid); - // stores an object to the local object store - void put_object(ObjectID objectid, const Obj* obj, std::vector &contained_objectids); + // Notify the scheduler about the object IDs contained within a remote object. + void add_contained_objectids(ObjectID objectid, std::vector &contained_objectids); // retrieve serialized object from local object store slice get_object(ObjectID objectid); // Allocates buffer for objectid with size of size diff --git a/test/array_test.py b/test/array_test.py index b4889e513..c33579d02 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -2,8 +2,6 @@ import unittest import ray import numpy as np import time -import subprocess32 as subprocess -import os from numpy.testing import assert_equal, assert_almost_equal import ray.array.remote as ra @@ -45,23 +43,11 @@ class RemoteArrayTest(unittest.TestCase): class DistributedArrayTest(unittest.TestCase): - def testSerialization(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: - reload(module) - ray.init(start_ray_local=True, num_workers=0) - - x = da.DistArray([2, 3, 4], np.array([[[ray.put(0)]]])) - capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x) - y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule) - self.assertEqual(x.shape, y.shape) - self.assertEqual(x.objectids[0, 0, 0].id, y.objectids[0, 0, 0].id) - - ray.worker.cleanup() - def testAssemble(self): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) ray.init(start_ray_local=True, num_workers=1) + ray.register_class(da.DistArray) a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) @@ -74,6 +60,7 @@ class DistributedArrayTest(unittest.TestCase): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) ray.init(start_ray_local=True, num_objstores=2, num_workers=10) + ray.register_class(da.DistArray) x = da.zeros.remote([9, 25, 51], "float") assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) diff --git a/test/failure_test.py b/test/failure_test.py index 347af0024..78e6c6227 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -5,7 +5,6 @@ import time import test_functions class FailureTest(unittest.TestCase): - def testUnknownSerialization(self): reload(test_functions) ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE) @@ -18,6 +17,27 @@ class FailureTest(unittest.TestCase): ray.worker.cleanup() +class TaskSerializationTest(unittest.TestCase): + def testReturnAndPassUnknownType(self): + ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE) + + class Foo(object): + pass + # Check that returning an unknown type from a remote function raises an + # exception. + @ray.remote + def f(): + return Foo() + self.assertRaises(Exception, lambda : ray.get(f.remote())) + # Check that passing an unknown type into a remote function raises an + # exception. + @ray.remote + def g(x): + return 1 + self.assertRaises(Exception, lambda : g.remote(Foo())) + + ray.worker.cleanup() + class TaskStatusTest(unittest.TestCase): def testFailedTask(self): reload(test_functions) diff --git a/test/microbenchmarks.py b/test/microbenchmarks.py index 89cba5f6c..44251e53e 100644 --- a/test/microbenchmarks.py +++ b/test/microbenchmarks.py @@ -1,7 +1,6 @@ import unittest import ray import time -import os import numpy as np import test_functions diff --git a/test/runtest.py b/test/runtest.py index d0fc80dd5..f58ffb3c7 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2,81 +2,105 @@ import unittest import ray import numpy as np import time -import subprocess32 as subprocess -import os +import string import sys -from numpy.testing import assert_equal +from collections import namedtuple +import libnumbuf import test_functions import ray.array.remote as ra import ray.array.distributed as da -RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, 43L, "hello world", 42.0, 1L << 62, - (1.0, "hi"), None, (None, None), ("hello", None), - True, False, (True, False), u"\u262F", - {True: "hello", False: "world"}, - {"hello" : "world", 1: 42, 1.0: 45}, {}, {(): ()}, - {(1, 2): 1}, {(): [1, 2, "hi"]}, (), [], [()], ((),), - np.int8(3), np.int32(4), np.int64(5), - np.uint8(3), np.uint32(4), np.uint64(5), - np.float32(1.0), np.float64(1.0)] +def assert_equal(obj1, obj2): + if type(obj1).__module__ == np.__name__ or type(obj2).__module__ == np.__name__: + if (hasattr(obj1, "shape") and obj1.shape == ()) or (hasattr(obj2, "shape") and obj2.shape == ()): + # This is a special case because currently np.testing.assert_equal fails + # because we do not properly handle different numerical types. + assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2) + else: + np.testing.assert_equal(obj1, obj2) + elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): + special_keys = ["_pytype_"] + assert set(obj1.__dict__.keys() + special_keys) == set(obj2.__dict__.keys() + special_keys), "Objects {} and {} are different.".format(obj1, obj2) + for key in obj1.__dict__.keys(): + if key not in special_keys: + assert_equal(obj1.__dict__[key], obj2.__dict__[key]) + elif type(obj1) is dict or type(obj2) is dict: + assert_equal(obj1.keys(), obj2.keys()) + for key in obj1.keys(): + assert_equal(obj1[key], obj2[key]) + elif type(obj1) is list or type(obj2) is list: + assert len(obj1) == len(obj2), "Objects {} and {} are lists with different lengths.".format(obj1, obj2) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + elif type(obj1) is tuple or type(obj2) is tuple: + assert len(obj1) == len(obj2), "Objects {} and {} are tuples with different lengths.".format(obj1, obj2) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + else: + assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2) -class UserDefinedType(object): +PRIMITIVE_OBJECTS = [0, 0.0, 0L, 1L << 62, "a", string.printable, "\u262F", + u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True, + False, [], (), {}, np.int8(3), np.int32(4), np.int64(5), + np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.0), + np.float64(1.0), np.zeros([100, 100]), + np.random.normal(size=[100, 100]), np.array(["hi", 3]), + np.array(["hi", 3], dtype=object), + np.array([["hi", u"hi"], [1.0, 1L]])] + +COMPLEX_OBJECTS = [#[[[[[[[[[[[[]]]]]]]]]]]], + {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)}, + #{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}}, + #((((((((((),),),),),),),),),), + #{"a": {"b": {"c": {"d": {}}}}} + ] + +class Foo(object): def __init__(self): pass - @staticmethod - def deserialize(primitives): - return "user defined type" +class Bar(object): + def __init__(self): + for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS): + setattr(self, "field{}".format(i), val) - def serialize(self): - return "user defined type" +class Baz(object): + def __init__(self): + self.foo = Foo() + self.bar = Bar() + def method(self, arg): + pass -class SerializationTest(unittest.TestCase): +class Qux(object): + def __init__(self): + self.objs = [Foo(), Bar(), Baz()] - def roundTripTest(self, data): - serialized, _ = ray.serialization.serialize(ray.worker.global_worker.handle, data) - result = ray.serialization.deserialize(ray.worker.global_worker.handle, serialized) - assert_equal(data, result) +class SubQux(Qux): + def __init__(self): + Qux.__init__(self) - def numpyTypeTest(self, typ): - self.roundTripTest(np.random.randint(0, 10, size=(100, 100)).astype(typ)) - self.roundTripTest(np.array(0).astype(typ)) - self.roundTripTest(np.empty((0,)).astype(typ)) +class CustomError(Exception): + pass - def testSerialize(self): - ray.init(start_ray_local=True, num_workers=0) +Point = namedtuple("Point", ["x", "y"]) +NamedTupleExample = namedtuple("Example", "field1, field2, field3, field4, field5") - for val in RAY_TEST_OBJECTS: - self.roundTripTest(val) +CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22), + Foo(), Bar(), Baz(), # Qux(), SubQux(), + NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])] - self.roundTripTest(np.zeros((100, 100))) +BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS - self.numpyTypeTest("int8") - self.numpyTypeTest("uint8") - self.numpyTypeTest("int16") - self.numpyTypeTest("uint16") - self.numpyTypeTest("int32") - self.numpyTypeTest("uint32") - self.numpyTypeTest("float32") - self.numpyTypeTest("float64") +LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS] +TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS] +# The check that type(obj).__module__ != "numpy" should be unnecessary, but +# otherwise this seems to fail on Mac OS X on Travis. +DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS if obj.__hash__ is not None and type(obj).__module__ != "numpy"] + +# DICT_OBJECTS = ([{obj: obj} for obj in BASE_OBJECTS if obj.__hash__ is not None] + + [{0: obj} for obj in BASE_OBJECTS]) - ref0 = ray.put(0) - ref1 = ray.put(0) - ref2 = ray.put(0) - ref3 = ray.put(0) - - a = np.array([[ref0, ref1], [ref2, ref3]]) - capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a) - result = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule) - self.assertTrue((a == result).all()) - - self.roundTripTest(ref0) - self.roundTripTest([ref0, ref1, ref2, ref3]) - self.roundTripTest({"0": ref0, "1": ref1, "2": ref2, "3": ref3}) - self.roundTripTest((ref0, 1)) - - ray.worker.cleanup() +RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS class ObjStoreTest(unittest.TestCase): @@ -93,36 +117,23 @@ class ObjStoreTest(unittest.TestCase): ray.reusables._cached_reusables = [] # This is a hack to make the test run. ray.connect(node_ip_address, scheduler_address, objstore_address=objstore_addresses[1], mode=ray.SCRIPT_MODE, worker=w2) + for cls in [Foo, Bar, Baz, Qux, SubQux, Exception, CustomError, Point, NamedTupleExample]: + ray.register_class(cls) + # putting and getting an object shouldn't change it for data in RAY_TEST_OBJECTS: objectid = ray.put(data, w1) result = ray.get(objectid, w1) - self.assertEqual(result, data) + assert_equal(result, data) # putting an object, shipping it to another worker, and getting it shouldn't change it for data in RAY_TEST_OBJECTS: - objectid = ray.put(data, w1) - result = ray.get(objectid, w2) - self.assertEqual(result, data) - - # putting an object, shipping it to another worker, and getting it shouldn't change it - for data in RAY_TEST_OBJECTS: - objectid = ray.put(data, w2) - result = ray.get(objectid, w1) - self.assertEqual(result, data) - - ARRAY_TEST_OBJECTS = [np.zeros([10, 20]), np.random.normal(size=[45, 25]), - ("a", np.random.normal(size=[10, 10])), - ["a", np.random.normal(size=[10, 10])]] - - # putting an array, shipping it to another worker, and getting it shouldn't change it - for data in ARRAY_TEST_OBJECTS: objectid = ray.put(data, w1) result = ray.get(objectid, w2) assert_equal(result, data) - # putting an array, shipping it to another worker, and getting it shouldn't change it - for data in ARRAY_TEST_OBJECTS: + # putting an object, shipping it to another worker, and getting it shouldn't change it + for data in RAY_TEST_OBJECTS: objectid = ray.put(data, w2) result = ray.get(objectid, w1) assert_equal(result, data) @@ -182,6 +193,23 @@ class WorkerTest(unittest.TestCase): class APITest(unittest.TestCase): + def testRegisterClass(self): + ray.init(start_ray_local=True, num_workers=0) + + # Check that putting an object of a class that has not been registered + # throws an exception. + class TempClass(object): + pass + self.assertRaises(Exception, lambda : ray.put(Foo)) + # Check that registering a class that Ray cannot serialize efficiently + # raises an exception. + self.assertRaises(Exception, lambda : ray.register_class(type(True))) + # Check that registering the same class with pickle works. + ray.register_class(type(float), pickle=True) + self.assertEqual(ray.get(ray.put(float)), float) + + ray.worker.cleanup() + def testKeywordArgs(self): reload(test_functions) ray.init(start_ray_local=True, num_workers=1) @@ -379,41 +407,60 @@ class ReferenceCountingTest(unittest.TestCase): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) ray.init(start_ray_local=True, num_workers=1) + ray.register_class(da.DistArray) + + def check_not_deallocated(object_ids): + reference_counts = ray.scheduler_info()["reference_counts"] + for object_id in object_ids: + self.assertGreater(reference_counts[object_id.id], 0) + + def check_everything_deallocated(): + reference_counts = ray.scheduler_info()["reference_counts"] + self.assertEqual(reference_counts, len(reference_counts) * [-1]) z = da.zeros.remote([da.BLOCK_SIZE, 2 * da.BLOCK_SIZE]) time.sleep(0.1) objectid_val = z.id - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [1, 1, 1]) - + time.sleep(0.1) + check_not_deallocated([z]) del z time.sleep(0.1) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, -1, -1]) + check_everything_deallocated() x = ra.zeros.remote([10, 10]) y = ra.zeros.remote([10, 10]) z = ra.dot.remote(x, y) objectid_val = x.id time.sleep(0.1) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [1, 1, 1]) - + check_not_deallocated([x, y, z]) del x time.sleep(0.1) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, 1, 1]) + check_not_deallocated([y, z]) del y time.sleep(0.1) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, -1, 1]) + check_not_deallocated([z]) del z time.sleep(0.1) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, -1, -1]) + check_everything_deallocated() + + z = da.zeros.remote([4 * da.BLOCK_SIZE]) + time.sleep(0.1) + check_not_deallocated(ray.get(z).objectids.tolist()) + del z + time.sleep(0.1) + check_everything_deallocated() ray.worker.cleanup() def testGet(self): ray.init(start_ray_local=True, num_workers=3) + for cls in [Foo, Bar, Baz, Qux, SubQux, Exception, CustomError, Point, NamedTupleExample]: + ray.register_class(cls) + # Remote objects should be deallocated when the corresponding ObjectID goes # out of scope, and all results of ray.get called on the ID go out of scope. - for val in RAY_TEST_OBJECTS + [np.zeros((2, 2)), UserDefinedType()]: + for val in RAY_TEST_OBJECTS: x = ray.put(val) objectid = x.id xval = ray.get(x) diff --git a/thirdparty/numbuf b/thirdparty/numbuf index c4c33bd08..5ac2df432 160000 --- a/thirdparty/numbuf +++ b/thirdparty/numbuf @@ -1 +1 @@ -Subproject commit c4c33bd087f9b4aed47f518d4e99fdd401abccbe +Subproject commit 5ac2df4329d2dc3039a503151ab985067e28c733