mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 04:27:27 +08:00
Remove register_class from API. (#550)
* Perform ray.register_class under the hood.
* Fix bug.
* Release worker lock when waiting for imports to arrive in get.
* Remove calls to register_class from examples and tests.
* Clear serialization state between tests.
* Fix bug and add test for multiple custom classes with same name.
* Fix failure test.
* Fix linting and cleanups to python code.
* Fixes to documentation.
* Implement recursion depth for recursively registering classes.
* Fix linting.
* Push warning to user if waiting for class for too long.
* Fix typos.
* Don't export FunctionToRun if pickling the function fails.
* Don't broadcast class definition when pickling class.
This commit is contained in:
committed by
Philipp Moritz
parent
3ebfd850e1
commit
ec2534422b
+85
-64
@@ -2,12 +2,26 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray.numbuf
|
||||
import ray.pickling as pickling
|
||||
|
||||
|
||||
class RaySerializationException(Exception):
    """Error carrying the object that could not be serialized.

    Attributes:
        example_object: The object supplied by the raiser together with the
            message, kept so that a handler can inspect it.
    """

    def __init__(self, message, example_object):
        # Keep a handle on the offending object before delegating the
        # message to the base class.
        self.example_object = example_object
        Exception.__init__(self, message)
|
||||
|
||||
|
||||
class RayDeserializationException(Exception):
    """Error carrying the identifier of a class that could not be rebuilt.

    Attributes:
        class_id: The class identifier supplied by the raiser together with
            the message.
    """

    def __init__(self, message, class_id):
        # Record the identifier before delegating the message to the base
        # class.
        self.class_id = class_id
        Exception.__init__(self, message)
|
||||
|
||||
|
||||
class RayNotDictionarySerializable(Exception):
    """Marker exception with no behavior beyond that of ``Exception``."""
|
||||
|
||||
|
||||
def check_serializable(cls):
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
|
||||
@@ -22,42 +36,38 @@ def check_serializable(cls):
|
||||
# This case works.
|
||||
return
|
||||
if not hasattr(cls, "__new__"):
|
||||
raise Exception("The class {} does not have a '__new__' attribute, and is "
|
||||
"probably an old-style class. We do not support this. "
|
||||
"Please either make it a new-style class by inheriting "
|
||||
"from 'object', or use "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
print("The class {} does not have a '__new__' attribute and is probably "
|
||||
"an old-stye class. Please make it a new-style class by inheriting "
|
||||
"from 'object'.")
|
||||
raise RayNotDictionarySerializable("The class {} does not have a "
|
||||
"'__new__' attribute and is probably "
|
||||
"an old-style class. We do not support "
|
||||
"this. Please make it a new-style "
|
||||
"class by inheriting from 'object'."
|
||||
.format(cls))
|
||||
try:
|
||||
obj = cls.__new__(cls)
|
||||
except:
|
||||
raise Exception("The class {} has overridden '__new__', so Ray may not be "
|
||||
"able to serialize it efficiently. Try using "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
|
||||
", so Ray may not be able to serialize "
|
||||
"it efficiently.".format(cls))
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise Exception("Objects of the class {} do not have a `__dict__` "
|
||||
"attribute, so Ray cannot serialize it efficiently. Try "
|
||||
"using 'ray.register_class(cls, pickle=True)'. However, "
|
||||
"note that pickle is inefficient.".format(cls))
|
||||
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
|
||||
"'__dict__' attribute, so Ray cannot "
|
||||
"serialize it efficiently.".format(cls))
|
||||
if hasattr(obj, "__slots__"):
|
||||
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
|
||||
"serialize it efficiently. Try using "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
|
||||
"may not be able to serialize it "
|
||||
"efficiently.".format(cls))
|
||||
|
||||
|
||||
# This field keeps track of a whitelisted set of classes that Ray will
|
||||
# serialize.
|
||||
whitelisted_classes = {}
|
||||
type_to_class_id = dict()
|
||||
whitelisted_classes = dict()
|
||||
classes_to_pickle = set()
|
||||
custom_serializers = {}
|
||||
custom_deserializers = {}
|
||||
|
||||
|
||||
def class_identifier(typ):
    """Build the dotted ``module.name`` string used to identify ``typ``.

    Args:
        typ: A type exposing ``__module__`` and ``__name__``.

    Returns:
        The module and class name joined by a single dot.
    """
    module_name = typ.__module__
    type_name = typ.__name__
    return "{}.{}".format(module_name, type_name)
|
||||
custom_serializers = dict()
|
||||
custom_deserializers = dict()
|
||||
|
||||
|
||||
def is_named_tuple(cls):
|
||||
@@ -71,12 +81,13 @@ def is_named_tuple(cls):
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
|
||||
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
|
||||
custom_deserializer=None):
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
|
||||
Args:
|
||||
cls (type): The class that we can serialize.
|
||||
class_id: A string of bytes used to identify the class.
|
||||
pickle (bool): True if the serialization should be done with pickle. False
|
||||
if it should be done efficiently with Ray.
|
||||
custom_serializer: This argument is optional, but can be provided to
|
||||
@@ -84,7 +95,7 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
custom_deserializer: This argument is optional, but can be provided to
|
||||
deserialize objects of the class in a particular way.
|
||||
"""
|
||||
class_id = class_identifier(cls)
|
||||
type_to_class_id[cls] = class_id
|
||||
whitelisted_classes[class_id] = cls
|
||||
if pickle:
|
||||
classes_to_pickle.add(class_id)
|
||||
@@ -93,21 +104,6 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
|
||||
|
||||
# Here we define a custom serializer and deserializer for handling numpy
|
||||
# arrays that contain objects.
|
||||
def array_custom_serializer(obj):
    """Split an array into a nested list and its dtype string.

    Args:
        obj: An object providing ``tolist()`` and ``dtype.str``.

    Returns:
        A ``(values, dtype_str)`` tuple.
    """
    values = obj.tolist()
    dtype_str = obj.dtype.str
    return values, dtype_str
|
||||
|
||||
|
||||
def array_custom_deserializer(serialized_obj):
    """Rebuild an array from a ``(values, dtype_str)`` pair.

    Args:
        serialized_obj: An indexable whose element 0 holds the values and
            whose element 1 holds a dtype specifier accepted by ``np.dtype``.

    Returns:
        The array built by ``np.array`` from the values with that dtype.
    """
    values = serialized_obj[0]
    dtype = np.dtype(serialized_obj[1])
    return np.array(values, dtype=dtype)
|
||||
|
||||
|
||||
# Register np.ndarray with the custom (non-pickle) serializer pair.
add_class_to_whitelist(
    np.ndarray,
    pickle=False,
    custom_serializer=array_custom_serializer,
    custom_deserializer=array_custom_deserializer,
)
|
||||
|
||||
|
||||
def serialize(obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
@@ -120,14 +116,16 @@ def serialize(obj):
|
||||
A dictionary that has the key "_pytype_" to identify the class, and
|
||||
contains all information needed to reconstruct the object.
|
||||
"""
|
||||
class_id = class_identifier(type(obj))
|
||||
if class_id not in whitelisted_classes:
|
||||
raise Exception("Ray does not know how to serialize objects of type {}. "
|
||||
"To fix this, call 'ray.register_class' with this class."
|
||||
.format(type(obj)))
|
||||
if type(obj) not in type_to_class_id:
|
||||
raise RaySerializationException("Ray does not know how to serialize "
|
||||
"objects of type {}.".format(type(obj)),
|
||||
obj)
|
||||
class_id = type_to_class_id[type(obj)]
|
||||
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickling.dumps(obj)}
|
||||
elif class_id in custom_serializers.keys():
|
||||
serialized_obj = {"data": pickling.dumps(obj),
|
||||
"pickle": True}
|
||||
elif class_id in custom_serializers:
|
||||
serialized_obj = {"data": custom_serializers[class_id](obj)}
|
||||
else:
|
||||
# Handle the namedtuple case.
|
||||
@@ -137,8 +135,8 @@ def serialize(obj):
|
||||
elif hasattr(obj, "__dict__"):
|
||||
serialized_obj = obj.__dict__
|
||||
else:
|
||||
raise Exception("We do not know how to serialize the object '{}'"
|
||||
.format(obj))
|
||||
raise RaySerializationException("We do not know how to serialize the "
|
||||
"object '{}'".format(obj), obj)
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
|
||||
@@ -154,21 +152,36 @@ def deserialize(serialized_obj):
|
||||
|
||||
Returns:
|
||||
A Python object.
|
||||
|
||||
Raises:
|
||||
An exception is raised if we do not know how to deserialize the object.
|
||||
"""
|
||||
class_id = serialized_obj["_pytype_"]
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in classes_to_pickle:
|
||||
|
||||
if "pickle" in serialized_obj:
|
||||
# The object was pickled, so unpickle it.
|
||||
obj = pickling.loads(serialized_obj["data"])
|
||||
elif class_id in custom_deserializers.keys():
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
assert class_id not in classes_to_pickle
|
||||
if class_id not in whitelisted_classes:
|
||||
# If this happens, that means that the call to _register_class, which
|
||||
# should have added the class to the list of whitelisted classes, has not
|
||||
# yet propagated to this worker. It should happen if we wait a little
|
||||
# longer.
|
||||
raise RayDeserializationException("The class {} is not one of the "
|
||||
"whitelisted classes."
|
||||
.format(class_id), class_id)
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in custom_deserializers:
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -181,3 +194,11 @@ def set_callbacks():
|
||||
callback.
|
||||
"""
|
||||
ray.numbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
|
||||
def clear_state():
    """Empty every module-level serialization registry in place."""
    registries = (
        type_to_class_id,
        whitelisted_classes,
        classes_to_pickle,
        custom_serializers,
        custom_deserializers,
    )
    for registry in registries:
        registry.clear()
|
||||
|
||||
Reference in New Issue
Block a user