Remove register_class from API. (#550)

* Perform ray.register_class under the hood.

* Fix bug.

* Release worker lock when waiting for imports to arrive in get.

* Remove calls to register_class from examples and tests.

* Clear serialization state between tests.

* Fix bug and add test for multiple custom classes with same name.

* Fix failure test.

* Fix linting and cleanups to python code.

* Fixes to documentation.

* Implement recursion depth for recursively registering classes.

* Fix linting.

* Push warning to user if waiting for class for too long.

* Fix typos.

* Don't export FunctionToRun if pickling the function fails.

* Don't broadcast class definition when pickling class.
This commit is contained in:
Robert Nishihara
2017-05-16 18:38:52 -07:00
committed by Philipp Moritz
parent 3ebfd850e1
commit ec2534422b
11 changed files with 378 additions and 304 deletions
+85 -64
View File
@@ -2,12 +2,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray.numbuf
import ray.pickling as pickling
class RaySerializationException(Exception):
def __init__(self, message, example_object):
Exception.__init__(self, message)
self.example_object = example_object
class RayDeserializationException(Exception):
def __init__(self, message, class_id):
Exception.__init__(self, message)
self.class_id = class_id
class RayNotDictionarySerializable(Exception):
pass
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
@@ -22,42 +36,38 @@ def check_serializable(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
raise Exception("The class {} does not have a '__new__' attribute, and is "
"probably an old-style class. We do not support this. "
"Please either make it a new-style class by inheriting "
"from 'object', or use "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
print("The class {} does not have a '__new__' attribute and is probably "
"an old-stye class. Please make it a new-style class by inheriting "
"from 'object'.")
raise RayNotDictionarySerializable("The class {} does not have a "
"'__new__' attribute and is probably "
"an old-style class. We do not support "
"this. Please make it a new-style "
"class by inheriting from 'object'."
.format(cls))
try:
obj = cls.__new__(cls)
except:
raise Exception("The class {} has overridden '__new__', so Ray may not be "
"able to serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
", so Ray may not be able to serialize "
"it efficiently.".format(cls))
if not hasattr(obj, "__dict__"):
raise Exception("Objects of the class {} do not have a `__dict__` "
"attribute, so Ray cannot serialize it efficiently. Try "
"using 'ray.register_class(cls, pickle=True)'. However, "
"note that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
"'__dict__' attribute, so Ray cannot "
"serialize it efficiently.".format(cls))
if hasattr(obj, "__slots__"):
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
"serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
"may not be able to serialize it "
"efficiently.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize.
whitelisted_classes = {}
type_to_class_id = dict()
whitelisted_classes = dict()
classes_to_pickle = set()
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
"""Return a string that identifies this type."""
return "{}.{}".format(typ.__module__, typ.__name__)
custom_serializers = dict()
custom_deserializers = dict()
def is_named_tuple(cls):
@@ -71,12 +81,13 @@ def is_named_tuple(cls):
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
Args:
cls (type): The class that we can serialize.
class_id: A string of bytes used to identify the class.
pickle (bool): True if the serialization should be done with pickle. False
if it should be done efficiently with Ray.
custom_serializer: This argument is optional, but can be provided to
@@ -84,7 +95,7 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
class_id = class_identifier(cls)
type_to_class_id[cls] = class_id
whitelisted_classes[class_id] = cls
if pickle:
classes_to_pickle.add(class_id)
@@ -93,21 +104,6 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
add_class_to_whitelist(np.ndarray, pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
def serialize(obj):
"""This is the callback that will be used by numbuf.
@@ -120,14 +116,16 @@ def serialize(obj):
A dictionary that has the key "_pyttype_" to identify the class, and
contains all information needed to reconstruct the object.
"""
class_id = class_identifier(type(obj))
if class_id not in whitelisted_classes:
raise Exception("Ray does not know how to serialize objects of type {}. "
"To fix this, call 'ray.register_class' with this class."
.format(type(obj)))
if type(obj) not in type_to_class_id:
raise RaySerializationException("Ray does not know how to serialize "
"objects of type {}.".format(type(obj)),
obj)
class_id = type_to_class_id[type(obj)]
if class_id in classes_to_pickle:
serialized_obj = {"data": pickling.dumps(obj)}
elif class_id in custom_serializers.keys():
serialized_obj = {"data": pickling.dumps(obj),
"pickle": True}
elif class_id in custom_serializers:
serialized_obj = {"data": custom_serializers[class_id](obj)}
else:
# Handle the namedtuple case.
@@ -137,8 +135,8 @@ def serialize(obj):
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise Exception("We do not know how to serialize the object '{}'"
.format(obj))
raise RaySerializationException("We do not know how to serialize the "
"object '{}'".format(obj), obj)
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
@@ -154,21 +152,36 @@ def deserialize(serialized_obj):
Returns:
A Python object.
Raises:
An exception is raised if we do not know how to deserialize the object.
"""
class_id = serialized_obj["_pytype_"]
cls = whitelisted_classes[class_id]
if class_id in classes_to_pickle:
if "pickle" in serialized_obj:
# The object was pickled, so unpickle it.
obj = pickling.loads(serialized_obj["data"])
elif class_id in custom_deserializers.keys():
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
assert class_id not in classes_to_pickle
if class_id not in whitelisted_classes:
# If this happens, that means that the call to _register_class, which
# should have added the class to the list of whitelisted classes, has not
# yet propagated to this worker. It should happen if we wait a little
# longer.
raise RayDeserializationException("The class {} is not one of the "
"whitelisted classes."
.format(class_id), class_id)
cls = whitelisted_classes[class_id]
if class_id in custom_deserializers:
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
@@ -181,3 +194,11 @@ def set_callbacks():
callback.
"""
ray.numbuf.register_callbacks(serialize, deserialize)
def clear_state():
type_to_class_id.clear()
whitelisted_classes.clear()
classes_to_pickle.clear()
custom_serializers.clear()
custom_deserializers.clear()