Allow users to serialize custom classes. (#393)

* Allow serialization of custom classes.

* Add documentation and test cases, also fix pickle case.

* Don't allow old-style classes.
This commit is contained in:
Robert Nishihara
2016-09-06 13:28:24 -07:00
committed by Philipp Moritz
parent d5cb3ac090
commit 11a8914684
22 changed files with 497 additions and 403 deletions
+1 -1
View File
@@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"):
import config
import serialization
from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local
from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local
from worker import Reusable, reusables
from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
from libraylib import ObjectID
-8
View File
@@ -16,14 +16,6 @@ class DistArray(object):
if self.num_blocks != list(self.objectids.shape):
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
@staticmethod
def deserialize(primitives):
(shape, objectids) = primitives
return DistArray(shape, objectids)
def serialize(self):
return (self.shape, self.objectids)
@staticmethod
def compute_block_lower(index, shape):
if len(index) != len(shape):
+131 -35
View File
@@ -1,42 +1,138 @@
import importlib
import numpy as np
import pickling
import libraylib as raylib
import libnumbuf
def to_primitive(obj):
if hasattr(obj, "serialize"):
primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize())
else:
primitive_obj = ("primitive", obj)
return primitive_obj
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
def from_primitive(primitive_obj):
if primitive_obj[0] == "primitive":
obj = primitive_obj[1]
Args:
cls (type): The class to be serialized.
Raises:
Exception: An exception is raised if Ray cannot serialize this class
efficiently.
"""
if is_named_tuple(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
try:
obj = cls.__new__(cls)
except:
raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
if not hasattr(obj, "__dict__"):
raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
if hasattr(obj, "__slots__"):
raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize.
whitelisted_classes = {}
classes_to_pickle = set()
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
"""Return a string that identifies this type."""
return "{}.{}".format(typ.__module__, typ.__name__)
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
Args:
cls (type): The class that we can serialize.
pickle (bool): True if the serialization should be done with pickle. False
if it should be done efficiently with Ray.
custom_serializer: This argument is optional, but can be provided to
serialize objects of the class in a particular way.
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
class_id = class_identifier(cls)
whitelisted_classes[class_id] = cls
if pickle:
classes_to_pickle.add(class_id)
if custom_serializer is not None:
custom_serializers[class_id] = custom_serializer
custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer)
def serialize(obj):
"""This is the callback that will be used by numbuf.
If numbuf does not know how to serialize an object, it will call this method.
Args:
obj (object): A Python object.
Returns:
A dictionary that has the key "_pyttype_" to identify the class, and
contains all information needed to reconstruct the object.
"""
class_id = class_identifier(type(obj))
if class_id not in whitelisted_classes:
raise Exception("Ray does not know how to serialize the object {}. To fix this, call 'ray.register_class' on the class of the object.".format(obj))
if class_id in classes_to_pickle:
serialized_obj = {"data": pickling.dumps(obj)}
elif class_id in custom_serializers.keys():
serialized_obj = {"data": custom_serializers[class_id](obj)}
else:
# This code assumes that the type module.__dict__[type_name] knows how to deserialize itself
type_module, type_name = primitive_obj[0]
module = importlib.import_module(type_module)
obj = module.__dict__[type_name].deserialize(primitive_obj[1])
if not hasattr(obj, "__dict__"):
raise Exception("We do not know how to serialize the object '{}'".format(obj))
serialized_obj = obj.__dict__
if is_named_tuple(type(obj)):
# Handle the namedtuple case.
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
def deserialize(serialized_obj):
"""This is the callback that will be used by numbuf.
If numbuf encounters a dictionary that contains the key "_pytype_" during
deserialization, it will ask this callback to deserialize the object.
Args:
serialized_obj (object): A dictionary that contains the key "_pytype_".
Returns:
A Python object.
"""
class_id = serialized_obj["_pytype_"]
cls = whitelisted_classes[class_id]
if class_id in classes_to_pickle:
obj = pickling.loads(serialized_obj["data"])
elif class_id in custom_deserializers.keys():
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
serialized_obj.pop("_ray_getnewargs_")
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
def is_arrow_serializable(value):
return isinstance(value, np.ndarray) and value.dtype.name in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64"]
def serialize(worker_capsule, obj):
primitive_obj = to_primitive(obj)
obj_capsule, contained_objectids = raylib.serialize_object(worker_capsule, primitive_obj) # contained_objectids is a list of the objectids contained in obj
return obj_capsule, contained_objectids
def deserialize(worker_capsule, capsule):
primitive_obj = raylib.deserialize_object(worker_capsule, capsule)
return from_primitive(primitive_obj)
def serialize_task(worker_capsule, func_name, args):
primitive_args = [(arg if isinstance(arg, raylib.ObjectID) else to_primitive(arg)) for arg in args]
return raylib.serialize_task(worker_capsule, func_name, primitive_args)
def deserialize_task(worker_capsule, task):
func_name, primitive_args, return_objectids = task
args = [(arg if isinstance(arg, raylib.ObjectID) else from_primitive(arg)) for arg in primitive_args]
return func_name, args, return_objectids
# Register the callbacks with numbuf.
libnumbuf.register_callbacks(serialize, deserialize)
+118 -102
View File
@@ -22,13 +22,31 @@ import services
import libnumbuf
import libraylib as raylib
contained_objectids = []
def numbuf_serialize(value):
"""This serializes a value and tracks the object IDs inside the value.
We also define a custom ObjectID serializer which also closes over the global
variable contained_objectids, and whenever the custom serializer is called, it
adds the releevant ObjectID to the list contained_objectids. The list
contained_objectids should be reset between calls to numbuf_serialize.
Args:
value: A Python object that will be serialized.
Returns:
The serialized object.
"""
assert len(contained_objectids) == 0, "This should be unreachable."
return libnumbuf.serialize_list([value])
class RayTaskError(Exception):
"""An object used internally to represent a task that threw an exception.
If a task throws an exception during execution, a RayTaskError is stored in
the object store for each of the task's outputs. When an object is retrieved
from the object store, the Python method that retrieved it checks to see if
the object is a RayTaskError and if it is then an exceptionis thrown
the object is a RayTaskError and if it is then an exception is thrown
propagating the error message.
Currently, we either use the exception attribute or the traceback attribute
@@ -50,32 +68,6 @@ class RayTaskError(Exception):
self.exception = None
self.traceback_str = traceback_str
@staticmethod
def deserialize(primitives):
"""Create a RayTaskError from a primitive object."""
function_name, exception, traceback_str = primitives
if exception[0] == "RayGetError":
exception = RayGetError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentError":
exception = RayGetArgumentError.deserialize(exception[1])
elif exception[0] == "None":
exception = None
else:
assert False, "This code should be unreachable."
return RayTaskError(function_name, exception, traceback_str)
def serialize(self):
"""Turn a RayTaskError into a primitive object."""
if isinstance(self.exception, RayGetError):
serialized_exception = ("RayGetError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentError):
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
elif self.exception is None:
serialized_exception = ("None",)
else:
assert False, "This code should be unreachable."
return (self.function_name, serialized_exception, self.traceback_str)
def __str__(self):
"""Format a RayTaskError as a string."""
if self.traceback_str is None:
@@ -99,16 +91,6 @@ class RayGetError(Exception):
self.objectid = objectid
self.task_error = task_error
@staticmethod
def deserialize(primitives):
"""Create a RayGetError from a primitive object."""
objectid, task_error = primitives
return RayGetError(objectid, RayTaskError.deserialize(task_error))
def serialize(self):
"""Turn a RayGetError into a primitive object."""
return (self.objectid, self.task_error.serialize())
def __str__(self):
"""Format a RayGetError as a string."""
return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
@@ -132,16 +114,6 @@ class RayGetArgumentError(Exception):
self.objectid = objectid
self.task_error = task_error
@staticmethod
def deserialize(primitives):
"""Create a RayGetArgumentError from a primitive object."""
function_name, argument_index, objectid, task_error = primitives
return RayGetArgumentError(function_name, argument_index, objectid, RayTaskError.deserialize(task_error))
def serialize(self):
"""Turn a RayGetArgumentError into a primitive object."""
return (self.function_name, self.argument_index, self.objectid, self.task_error.serialize())
def __str__(self):
"""Format a RayGetArgumentError as a string."""
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
@@ -366,30 +338,27 @@ class Worker(object):
objectid (raylib.ObjectID): The object ID of the value to be put.
value (serializable object): The value to put in the object store.
"""
try:
# We put the value into a list here because in arrow the concept of
# "serializing a single object" does not exits.
schema, size, serialized = libnumbuf.serialize_list([value])
# TODO(pcm): Right now, metadata is serialized twice, change that in the future
# in the following line, the "8" is for storing the metadata size,
# the len(schema) is for storing the metadata and the 4096 is for storing
# the metadata in the batch (see INITIAL_METADATA_SIZE in arrow)
size = size + 8 + len(schema) + 4096
buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size)
# write the metadata length
np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
# metadata buffer
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema))
# write the metadata
metadata[:] = schema
data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
except:
# At the moment, custom object and objects that contain object IDs take this path
# TODO(pcm): Make sure that these are the only objects getting serialized to protobuf
object_capsule, contained_objectids = serialization.serialize(self.handle, value) # contained_objectids is a list of the objectids contained in object_capsule
raylib.put_object(self.handle, objectid, object_capsule, contained_objectids)
# We put the value into a list here because in arrow the concept of
# "serializing a single object" does not exits.
schema, size, serialized = numbuf_serialize(value)
global contained_objectids
raylib.add_contained_objectids(self.handle, objectid, contained_objectids)
contained_objectids = []
# TODO(pcm): Right now, metadata is serialized twice, change that in the future
# in the following line, the "8" is for storing the metadata size,
# the len(schema) is for storing the metadata and the 8192 is for storing
# the metadata in the batch (see INITIAL_METADATA_SIZE in arrow)
size = size + 8 + len(schema) + 4096
buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size)
# write the metadata length
np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
# metadata buffer
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema))
# write the metadata
metadata[:] = schema
data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
def get_object(self, objectid):
"""Get the value in the local object store associated with objectid.
@@ -400,32 +369,25 @@ class Worker(object):
Args:
objectid (raylib.ObjectID): The object ID of the value to retrieve.
"""
if raylib.is_arrow(self.handle, objectid):
## this is the new codepath
buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid)
metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
# If there is currently no ObjectFixture for this ObjectID, then create a
# new one. The object_fixtures object is a WeakValueDictionary, so entries
# will be discarded when there are no strong references to their values.
# We create object_fixture outside of the assignment because if we created
# it inside the assignement it would immediately go out of scope.
object_fixture = None
if objectid.id not in object_fixtures:
object_fixture = ObjectFixture(objectid, segmentid, self.handle)
object_fixtures[objectid.id] = object_fixture
deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id])
# Unwrap the object from the list (it was wrapped put_object)
assert len(deserialized) == 1
result = deserialized[0]
## this is the old codepath
# result, segmentid = raylib.get_arrow(self.handle, objectid)
else:
object_capsule, segmentid = raylib.get_object(self.handle, objectid)
result = serialization.deserialize(self.handle, object_capsule)
assert raylib.is_arrow(self.handle, objectid), "All objects should be serialized using Arrow."
buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid)
metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
# If there is currently no ObjectFixture for this ObjectID, then create a
# new one. The object_fixtures object is a WeakValueDictionary, so entries
# will be discarded when there are no strong references to their values.
# We create object_fixture outside of the assignment because if we created
# it inside the assignement it would immediately go out of scope.
object_fixture = None
if objectid.id not in object_fixtures:
object_fixture = ObjectFixture(objectid, segmentid, self.handle)
object_fixtures[objectid.id] = object_fixture
deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id])
# Unwrap the object from the list (it was wrapped put_object)
assert len(deserialized) == 1
result = deserialized[0]
return result
def alias_objectids(self, alias_objectid, target_objectid):
@@ -445,7 +407,10 @@ class Worker(object):
be object IDs or they can be values. If they are values, they
must be serializable objecs.
"""
task_capsule = serialization.serialize_task(self.handle, func_name, args)
# Convert all of the argumens to object IDs. It is a little strange that we
# are calling put, which is external to this class.
args = [arg if isinstance(arg, raylib.ObjectID) else put(arg, worker=self) for arg in args]
task_capsule = raylib.serialize_task(self.handle, func_name, args)
objectids = raylib.submit_task(self.handle, task_capsule)
return objectids
@@ -461,10 +426,13 @@ class Worker(object):
not take any arguments. If it returns anything, its return values will
not be used.
"""
if self.mode not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
raise Exception("run_function_on_all_workers can only be called on a driver.")
# First run the function on the driver.
function(self)
# Then run the function on all of the workers.
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
if self.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
@@ -568,6 +536,28 @@ def task_info(worker=global_worker):
check_connected(worker)
return raylib.task_info(worker.handle)
def initialize_numbuf(worker=global_worker):
"""Initialize the serialization library.
This defines a custom serializer for object IDs and also tells numbuf to
serialize several exception classes that we define for error handling.
"""
# Define a custom serializer and deserializer for handling Object IDs.
def objectid_custom_serializer(obj):
class_identifier = serialization.class_identifier(type(obj))
contained_objectids.append(obj)
return raylib.serialize_objectid(worker.handle, obj)
def objectid_custom_deserializer(serialized_obj):
return raylib.deserialize_objectid(worker.handle, serialized_obj)
serialization.add_class_to_whitelist(raylib.ObjectID, pickle=False, custom_serializer=objectid_custom_serializer, custom_deserializer=objectid_custom_deserializer)
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
# These should only be called on the driver because register_class will
# export the class to all of the workers.
register_class(RayTaskError)
register_class(RayGetError)
register_class(RayGetArgumentError)
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE):
"""Either connect to an existing Ray cluster or start one and connect to it.
@@ -735,14 +725,16 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
# the same.
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, current_directory))
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
# Export cached remote functions to the workers.
for function_name, function_to_export in worker.cached_remote_functions:
raylib.export_remote_function(worker.handle, function_name, function_to_export)
# Export cached reusable variables to the workers.
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
# Initialize the serialization library.
initialize_numbuf()
worker.cached_remote_functions = None
reusables._cached_reusables = None
@@ -757,6 +749,30 @@ def disconnect(worker=global_worker):
worker.cached_remote_functions = []
reusables._cached_reusables = []
def register_class(cls, pickle=False, worker=global_worker):
"""Enable workers to serialize or deserialize objects of a particular class.
This method runs the register_class function defined below on every worker,
which will enable libnumbuf to properly serialize and deserialize objects of
this class.
Args:
cls (type): The class that libnumbuf should serialize.
pickle (bool): If False then objects of this class will be serialized by
turning their __dict__ fields into a dictionary. If True, then objects
of this class will be serialized using pickle.
Raises:
Exception: An exception is raised if pickle=False and the class cannot be
efficiently serialized by Ray.
"""
# Raise an exception if cls cannot be serialized efficiently by Ray.
if not pickle:
serialization.check_serializable(cls)
def register_class_for_serialization(worker):
serialization.add_class_to_whitelist(cls, pickle=pickle)
worker.run_function_on_all_workers(register_class_for_serialization)
def get(objectid, worker=global_worker):
"""Get a remote object or a list of remote objects from the object store.
@@ -915,7 +931,7 @@ def main_loop(worker=global_worker):
After the task executes, the worker resets any reusable variables that were
accessed by the task.
"""
function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
function_name, args, return_objectids = task
try:
arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
outputs = worker.functions[function_name].executor(arguments) # execute the function