Remove register_class from API. (#550)

* Perform ray.register_class under the hood.

* Fix bug.

* Release worker lock when waiting for imports to arrive in get.

* Remove calls to register_class from examples and tests.

* Clear serialization state between tests.

* Fix bug and add test for multiple custom classes with same name.

* Fix failure test.

* Fix linting and cleanups to python code.

* Fixes to documentation.

* Implement recursion depth for recursively registering classes.

* Fix linting.

* Push warning to user if waiting for class for too long.

* Fix typos.

* Don't export FunctionToRun if pickling the function fails.

* Don't broadcast class definition when pickling class.
This commit is contained in:
Robert Nishihara
2017-05-16 18:38:52 -07:00
committed by Philipp Moritz
parent 3ebfd850e1
commit ec2534422b
11 changed files with 378 additions and 304 deletions
@@ -72,10 +72,6 @@ class DistArray(object):
return a[sliced]
# Register the DistArray class with Ray so that it knows how to serialize it.
ray.register_class(DistArray)
@ray.remote
def assemble(a):
    """Remote task that returns the result of calling a.assemble()."""
    assembled = a.assemble()
    return assembled
+85 -64
View File
@@ -2,12 +2,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray.numbuf
import ray.pickling as pickling
class RaySerializationException(Exception):
    """Raised when Ray does not know how to serialize an object.

    Attributes:
        example_object: The object that could not be serialized.
    """

    def __init__(self, message, example_object):
        super(RaySerializationException, self).__init__(message)
        # Keep the offending object so the handler can look at its type.
        self.example_object = example_object
class RayDeserializationException(Exception):
    """Raised when an object's class is not yet known to the deserializer.

    Attributes:
        class_id: The identifier of the class that could not be resolved.
    """

    def __init__(self, message, class_id):
        super(RayDeserializationException, self).__init__(message)
        # Remember which class identifier was missing.
        self.class_id = class_id
class RayNotDictionarySerializable(Exception):
    """Raised when a class cannot be serialized as a dictionary of fields."""
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
@@ -22,42 +36,38 @@ def check_serializable(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
raise Exception("The class {} does not have a '__new__' attribute, and is "
"probably an old-style class. We do not support this. "
"Please either make it a new-style class by inheriting "
"from 'object', or use "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
print("The class {} does not have a '__new__' attribute and is probably "
"an old-stye class. Please make it a new-style class by inheriting "
"from 'object'.")
raise RayNotDictionarySerializable("The class {} does not have a "
"'__new__' attribute and is probably "
"an old-style class. We do not support "
"this. Please make it a new-style "
"class by inheriting from 'object'."
.format(cls))
try:
obj = cls.__new__(cls)
except:
raise Exception("The class {} has overridden '__new__', so Ray may not be "
"able to serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
", so Ray may not be able to serialize "
"it efficiently.".format(cls))
if not hasattr(obj, "__dict__"):
raise Exception("Objects of the class {} do not have a `__dict__` "
"attribute, so Ray cannot serialize it efficiently. Try "
"using 'ray.register_class(cls, pickle=True)'. However, "
"note that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
"'__dict__' attribute, so Ray cannot "
"serialize it efficiently.".format(cls))
if hasattr(obj, "__slots__"):
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
"serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
"may not be able to serialize it "
"efficiently.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize.
whitelisted_classes = {}
type_to_class_id = dict()
whitelisted_classes = dict()
classes_to_pickle = set()
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
"""Return a string that identifies this type."""
return "{}.{}".format(typ.__module__, typ.__name__)
custom_serializers = dict()
custom_deserializers = dict()
def is_named_tuple(cls):
@@ -71,12 +81,13 @@ def is_named_tuple(cls):
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
Args:
cls (type): The class that we can serialize.
class_id: A string of bytes used to identify the class.
pickle (bool): True if the serialization should be done with pickle. False
if it should be done efficiently with Ray.
custom_serializer: This argument is optional, but can be provided to
@@ -84,7 +95,7 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
class_id = class_identifier(cls)
type_to_class_id[cls] = class_id
whitelisted_classes[class_id] = cls
if pickle:
classes_to_pickle.add(class_id)
@@ -93,21 +104,6 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
add_class_to_whitelist(np.ndarray, pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
def serialize(obj):
"""This is the callback that will be used by numbuf.
@@ -120,14 +116,16 @@ def serialize(obj):
A dictionary that has the key "_pytype_" to identify the class, and
contains all information needed to reconstruct the object.
"""
class_id = class_identifier(type(obj))
if class_id not in whitelisted_classes:
raise Exception("Ray does not know how to serialize objects of type {}. "
"To fix this, call 'ray.register_class' with this class."
.format(type(obj)))
if type(obj) not in type_to_class_id:
raise RaySerializationException("Ray does not know how to serialize "
"objects of type {}.".format(type(obj)),
obj)
class_id = type_to_class_id[type(obj)]
if class_id in classes_to_pickle:
serialized_obj = {"data": pickling.dumps(obj)}
elif class_id in custom_serializers.keys():
serialized_obj = {"data": pickling.dumps(obj),
"pickle": True}
elif class_id in custom_serializers:
serialized_obj = {"data": custom_serializers[class_id](obj)}
else:
# Handle the namedtuple case.
@@ -137,8 +135,8 @@ def serialize(obj):
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise Exception("We do not know how to serialize the object '{}'"
.format(obj))
raise RaySerializationException("We do not know how to serialize the "
"object '{}'".format(obj), obj)
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
@@ -154,21 +152,36 @@ def deserialize(serialized_obj):
Returns:
A Python object.
Raises:
An exception is raised if we do not know how to deserialize the object.
"""
class_id = serialized_obj["_pytype_"]
cls = whitelisted_classes[class_id]
if class_id in classes_to_pickle:
if "pickle" in serialized_obj:
# The object was pickled, so unpickle it.
obj = pickling.loads(serialized_obj["data"])
elif class_id in custom_deserializers.keys():
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
assert class_id not in classes_to_pickle
if class_id not in whitelisted_classes:
# If this happens, that means that the call to _register_class, which
# should have added the class to the list of whitelisted classes, has not
# yet propagated to this worker. It should happen if we wait a little
# longer.
raise RayDeserializationException("The class {} is not one of the "
"whitelisted classes."
.format(class_id), class_id)
cls = whitelisted_classes[class_id]
if class_id in custom_deserializers:
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
@@ -181,3 +194,11 @@ def set_callbacks():
callback.
"""
ray.numbuf.register_callbacks(serialize, deserialize)
def clear_state():
    """Empty every module-level serialization registry.

    Each container is cleared in place, so references to the containers held
    elsewhere observe the reset as well.
    """
    registries = (
        type_to_class_id,
        whitelisted_classes,
        classes_to_pickle,
        custom_serializers,
        custom_deserializers,
    )
    for registry in registries:
        registry.clear()
-10
View File
@@ -107,13 +107,3 @@ def python_mode_g(x):
@ray.remote
def no_op():
    """Remote task that takes no arguments and does nothing."""
    return None
class TestClass(object):
def __init__(self):
self.a = 5
@ray.remote
def test_unknown_type():
return TestClass()
+160 -45
View File
@@ -479,27 +479,70 @@ class Worker(object):
self.mode = mode
colorama.init()
def put_object(self, objectid, value):
def store_and_register(self, object_id, value, depth=100):
    """Store an object and attempt to register its class if needed.

    Each time storing fails with a RaySerializationException, the class of
    the offending object is registered and the store is retried. Registration
    first tries dictionary-of-fields serialization and falls back to pickle
    if the class raises RayNotDictionarySerializable.

    Args:
        object_id: The ID of the object to store.
        value: The value to put in the object store.
        depth: The maximum number of classes to recursively register. A
            value of zero or less fails immediately.

    Raises:
        Exception: An exception is raised if the attempt to store the object
            fails. This can happen if there is already an object with the
            same ID in the object store or if the object store is full. An
            exception is also raised once depth store attempts have failed.
    """
    attempts = 0
    while True:
        # Compare with ">=" rather than "==" so that a negative depth fails
        # right away instead of silently disabling the limit.
        if attempts >= depth:
            raise Exception("Ray exceeded the maximum number of classes that "
                            "it will recursively serialize when attempting "
                            "to serialize an object of type {}."
                            .format(type(value)))
        attempts += 1
        try:
            ray.numbuf.store_list(object_id.id(), self.plasma_client.conn,
                                  [value])
            return
        except serialization.RaySerializationException as e:
            # The serializer hit a type it has no entry for; register that
            # type and go around the loop again.
            unregistered_type = type(e.example_object)
            try:
                _register_class(unregistered_type)
                warning_message = ("WARNING: Serializing objects of type {} "
                                   "by expanding them as dictionaries of "
                                   "their fields. This behavior may be "
                                   "incorrect in some cases."
                                   .format(unregistered_type))
            except serialization.RayNotDictionarySerializable:
                _register_class(unregistered_type, pickle=True)
                warning_message = ("WARNING: Falling back to serializing "
                                   "objects of type {} by using pickle. This "
                                   "may be inefficient."
                                   .format(unregistered_type))
            print(warning_message)
def put_object(self, object_id, value):
"""Put value in the local object store with object id objectid.
This assumes that the value for objectid has not yet been placed in the
local object store.
Args:
objectid (object_id.ObjectID): The object ID of the value to be put.
object_id (object_id.ObjectID): The object ID of the value to be put.
value: The value to put in the object store.
Raises:
Exception: An exception is raised if the attempt to store the object
fails. This can happen if there is already an object with the same ID
in the object store or if the object store is full.
"""
# Make sure that the value is not an object ID.
if isinstance(value, ray.local_scheduler.ObjectID):
raise Exception("Calling `put` on an ObjectID is not allowed "
raise Exception("Calling 'put' on an ObjectID is not allowed "
"(similarly, returning an ObjectID from a remote "
"function is not allowed). If you really want to do "
"this, you can wrap the ObjectID in a list and call "
"`put` on it (or return it).")
"'put' on it (or return it).")
# Serialize and put the object in the object store.
try:
ray.numbuf.store_list(objectid.id(), self.plasma_client.conn, [value])
self.store_and_register(object_id, value)
except ray.numbuf.numbuf_plasma_object_exists_error as e:
# The object already exists in the object store, so there is no need to
# add it again. TODO(rkn): We need to compare the hashes and make sure
@@ -511,6 +554,38 @@ class Worker(object):
# Optionally do something with the contained_objectids here.
contained_objectids = []
def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
    """Retrieve objects from the object store, retrying on unknown classes.

    The call to ray.numbuf.retrieve_list is repeated for as long as it raises
    RayDeserializationException. Between attempts the method sleeps briefly
    so that the import thread can import the missing class definition.

    Args:
        object_ids: The object IDs handed to ray.numbuf.retrieve_list.
        timeout: The timeout handed to ray.numbuf.retrieve_list (presumably
            milliseconds, since one caller passes GET_TIMEOUT_MILLISECONDS;
            confirm against numbuf).
        error_timeout: The number of seconds to keep retrying before a
            warning is pushed to the driver. The warning is pushed at most
            once; retrying continues afterwards.

    Returns:
        The value returned by ray.numbuf.retrieve_list.
    """
    start_time = time.time()
    # Only send the warning once.
    warning_sent = False
    while True:
        try:
            return ray.numbuf.retrieve_list(object_ids,
                                            self.plasma_client.conn,
                                            timeout)
        except serialization.RayDeserializationException:
            # Wait a little bit for the import thread to import the class.
            # If we currently have the worker lock, we need to release it so
            # that the import thread can acquire it.
            release_lock = self.mode == WORKER_MODE
            if release_lock:
                self.lock.release()
            try:
                time.sleep(0.01)
            finally:
                # Reacquire even if the sleep is interrupted, so the caller
                # still holds the lock it believes it holds.
                if release_lock:
                    self.lock.acquire()

            if not warning_sent and time.time() - start_time > error_timeout:
                warning_message = ("This worker or driver is waiting to "
                                   "receive a class definition so that it "
                                   "can deserialize an object from the "
                                   "object store. This may be fine, or it "
                                   "may be a bug.")
                self.push_error_to_driver(self.task_driver_id.id(),
                                          "wait_for_class",
                                          warning_message)
                warning_sent = True
def get_object(self, object_ids):
"""Get the value or values in the object store associated with object_ids.
@@ -531,10 +606,8 @@ class Worker(object):
self.plasma_client.fetch([object_id.id() for object_id in object_ids])
# Get the objects. We initially try to get the objects immediately.
final_results = ray.numbuf.retrieve_list(
[object_id.id() for object_id in object_ids],
self.plasma_client.conn,
0)
final_results = self.retrieve_and_deserialize(
[object_id.id() for object_id in object_ids], 0)
# Construct a dictionary mapping object IDs that we haven't gotten yet to
# their original index in the object_ids argument.
unready_ids = dict((object_id, i) for (i, (object_id, val)) in
@@ -548,9 +621,8 @@ class Worker(object):
# Do another fetch for objects that aren't available locally yet, in case
# they were evicted since the last fetch.
self.plasma_client.fetch(list(unready_ids.keys()))
results = ray.numbuf.retrieve_list(list(unready_ids.keys()),
self.plasma_client.conn,
GET_TIMEOUT_MILLISECONDS)
results = self.retrieve_and_deserialize(list(unready_ids.keys()),
GET_TIMEOUT_MILLISECONDS)
# Remove any entries for objects we received during this iteration so we
# don't retrieve the same object twice.
for object_id, val in results:
@@ -634,14 +706,16 @@ class Worker(object):
not be used.
"""
check_main_thread()
if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
raise Exception("run_function_on_all_workers can only be called on a "
"driver.")
# If ray.init has not been called yet, then cache the function and export
# it when connect is called. Otherwise, run the function on all workers.
if self.mode is None:
self.cached_functions_to_run.append(function)
else:
# Attempt to pickle the function before we need it. This could fail, and
# it is more convenient if the failure happens before we actually run the
# function locally.
pickled_function = pickling.dumps(function)
function_to_run_id = random_string()
key = "FunctionsToRun:{}".format(function_to_run_id)
# First run the function on the driver. Pass in the number of workers on
@@ -652,7 +726,7 @@ class Worker(object):
# Run the function on all workers.
self.redis_client.hmset(key, {"driver_id": self.task_driver_id.id(),
"function_id": function_to_run_id,
"function": pickling.dumps(function)})
"function": pickled_function})
self.redis_client.rpush("Exports", key)
def push_error_to_driver(self, driver_id, error_type, message, data=None):
@@ -808,23 +882,37 @@ def initialize_numbuf(worker=global_worker):
def objectid_custom_deserializer(serialized_obj):
return ray.local_scheduler.ObjectID(serialized_obj)
serialization.add_class_to_whitelist(
ray.local_scheduler.ObjectID, pickle=False,
ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False,
custom_serializer=objectid_custom_serializer,
custom_deserializer=objectid_custom_deserializer)
# Define a custom serializer and deserializer for handling numpy arrays that
# contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
serialization.add_class_to_whitelist(
np.ndarray, 20 * b"\x01", pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
# These should only be called on the driver because register_class will
# These should only be called on the driver because _register_class will
# export the class to all of the workers.
register_class(RayTaskError)
register_class(RayGetError)
register_class(RayGetArgumentError)
_register_class(RayTaskError)
_register_class(RayGetError)
_register_class(RayGetArgumentError)
# Tell Ray to serialize lambdas with pickle.
register_class(type(lambda: 0), pickle=True)
_register_class(type(lambda: 0), pickle=True)
# Tell Ray to serialize sets with pickle.
register_class(type(set()), pickle=True)
_register_class(type(set()), pickle=True)
# Tell Ray to serialize types with pickle.
register_class(type(int), pickle=True)
_register_class(type(int), pickle=True)
def get_address_info_from_redis_helper(redis_address, node_ip_address):
@@ -1279,7 +1367,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
data={"name": name})
def import_thread(worker):
def import_thread(worker, mode):
worker.import_pubsub_client = worker.redis_client.pubsub()
# Exports that are published after the call to
# import_pubsub_client.psubscribe and before the call to
@@ -1292,6 +1380,16 @@ def import_thread(worker):
with worker.lock:
export_keys = worker.redis_client.lrange("Exports", 0, -1)
for key in export_keys:
num_imported += 1
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the driver
# should import.
continue
if key.startswith(b"RemoteFunction"):
fetch_and_register_remote_function(key, worker=worker)
elif key.startswith(b"EnvironmentVariables"):
@@ -1306,7 +1404,6 @@ def import_thread(worker):
worker.fetch_and_register_actor(key, worker)
else:
raise Exception("This code should be unreachable.")
num_imported += 1
try:
for msg in worker.import_pubsub_client.listen():
@@ -1317,7 +1414,18 @@ def import_thread(worker):
num_imports = worker.redis_client.llen("Exports")
assert num_imports >= num_imported
for i in range(num_imported, num_imports):
num_imported += 1
key = worker.redis_client.lindex("Exports", i)
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with log_span("ray:import_function_to_run", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the
# driver should import.
continue
if key.startswith(b"RemoteFunction"):
with log_span("ray:import_remote_function", worker=worker):
fetch_and_register_remote_function(key, worker=worker)
@@ -1335,7 +1443,6 @@ def import_thread(worker):
worker.fetch_and_register["Actor"](key, worker)
else:
raise Exception("This code should be unreachable.")
num_imported += 1
except redis.ConnectionError:
# When Redis terminates the listen call will throw a ConnectionError, which
# we catch here.
@@ -1486,12 +1593,14 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
class_id = worker.redis_client.hget(actor_key, "class_id")
worker.class_id = class_id
# If this is a worker, then start a thread to import exports from the driver.
if mode == WORKER_MODE:
t = threading.Thread(target=import_thread, args=(worker,))
# Making the thread a daemon causes it to exit when the main thread exits.
t.daemon = True
t.start()
# Start a thread to import exports from the driver or from other workers.
# Note that the driver also has an import thread, which is used only to
# import custom class definitions from calls to _register_class that happen
# under the hood on workers.
t = threading.Thread(target=import_thread, args=(worker, mode))
# Making the thread a daemon causes it to exit when the main thread exits.
t.daemon = True
t.start()
# If this is a driver running in SCRIPT_MODE, start a thread to print error
# messages asynchronously in the background. Ideally the scheduler would push
@@ -1554,9 +1663,15 @@ def disconnect(worker=global_worker):
worker.cached_functions_to_run = []
worker.cached_remote_functions = []
env._cached_environment_variables = []
serialization.clear_state()
def register_class(cls, pickle=False, worker=global_worker):
    """Deprecated entry point that always raises.

    The arguments are accepted only so that existing call sites still match
    the signature; none of them is used.

    Raises:
        Exception: Always, with a message saying the function is deprecated.
    """
    deprecation_message = ("The function ray.register_class is deprecated. "
                           "It should be safe to remove any calls to this "
                           "function.")
    raise Exception(deprecation_message)
def _register_class(cls, pickle=False, worker=global_worker):
"""Enable workers to serialize or deserialize objects of a particular class.
This method runs the register_class function defined below on every worker,
@@ -1573,18 +1688,19 @@ def register_class(cls, pickle=False, worker=global_worker):
Exception: An exception is raised if pickle=False and the class cannot be
efficiently serialized by Ray.
"""
# If the worker is not a driver, then return. We do this so that Python
# modules can register classes and these modules can be imported on workers
# without any trouble.
if worker.mode == WORKER_MODE:
return
# Raise an exception if cls cannot be serialized efficiently by Ray.
if not pickle:
serialization.check_serializable(cls)
class_id = random_string()
def register_class_for_serialization(worker_info):
serialization.add_class_to_whitelist(cls, pickle=pickle)
worker.run_function_on_all_workers(register_class_for_serialization)
serialization.add_class_to_whitelist(cls, class_id, pickle=pickle)
if not pickle:
# Raise an exception if cls cannot be serialized efficiently by Ray.
serialization.check_serializable(cls)
worker.run_function_on_all_workers(register_class_for_serialization)
else:
# Since we are pickling objects of this class, we don't actually need to
# ship the class definition.
register_class_for_serialization({})
class RayLogSpan(object):
@@ -1778,7 +1894,6 @@ def wait_for_function(function_id, driver_id, timeout=10,
start_time = time.time()
# Only send the warning once.
warning_sent = False
num_warnings_sent = 0
while True:
with worker.lock:
if worker.actor_id == NIL_ACTOR_ID and (function_id.id() in
@@ -1787,7 +1902,7 @@ def wait_for_function(function_id, driver_id, timeout=10,
elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in
worker.actors):
break
if time.time() - start_time > timeout * (num_warnings_sent + 1):
if time.time() - start_time > timeout:
warning_message = ("This worker was asked to execute a function that "
"it does not have registered. You may have to "
"restart Ray.")