Remove register_class from API. (#550)

* Perform ray.register_class under the hood.

* Fix bug.

* Release worker lock when waiting for imports to arrive in get.

* Remove calls to register_class from examples and tests.

* Clear serialization state between tests.

* Fix bug and add test for multiple custom classes with same name.

* Fix failure test.

* Fix linting and cleanups to python code.

* Fixes to documentation.

* Implement recursion depth for recursively registering classes.

* Fix linting.

* Push warning to user if waiting for class for too long.

* Fix typos.

* Don't export FunctionToRun if pickling the function fails.

* Don't broadcast class definition when pickling class.
Robert Nishihara
2017-05-16 18:38:52 -07:00
committed by Philipp Moritz
parent 3ebfd850e1
commit ec2534422b
11 changed files with 378 additions and 304 deletions
-2
@@ -220,8 +220,6 @@ We can put this all together as follows.
# Load the MNIST dataset and tell Ray how to serialize the custom classes.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
ray.register_class(type(mnist))
ray.register_class(type(mnist.train))
# Create the actor.
nn = NeuralNetOnGPU.remote(mnist)
+11 -83
@@ -9,16 +9,15 @@ store.
1. The return values of a remote function.
2. The value ``x`` in a call to ``ray.put(x)``.
3. Large objects or objects other than simple primitive types that are passed
as arguments into remote functions.
3. Arguments to remote functions (except for simple arguments like ints or
floats).
A Python object may have an arbitrary number of pointers with arbitrarily deep
nesting. To place an object in the object store or send it between processes,
it must first be converted to a contiguous string of bytes. This process is
known as serialization. The process of converting the string of bytes back into a
Python object is known as deserialization. Serialization and deserialization
are often bottlenecks in distributed computing if the time needed to compute
on the data is relatively low.
are often bottlenecks in distributed computing.
Pickle is one example of a library for serialization and deserialization in
Python.
@@ -40,9 +39,8 @@ overheads, even when all processes are read-only and could easily share memory.
In Ray, we optimize for numpy arrays by using the `Apache Arrow`_ data format.
When we deserialize a list of numpy arrays from the object store, we still
create a Python list of numpy array objects. However, rather than copy each
numpy array over again, each numpy array object holds a pointer to the relevant
array held in shared memory. There are some advantages to this form of
serialization.
numpy array, each numpy array object holds a pointer to the relevant array held
in shared memory. There are some advantages to this form of serialization.
- Deserialization can be very fast.
- Memory is shared between processes so worker processes can all read the same
@@ -54,84 +52,17 @@ What Objects Does Ray Handle
----------------------------
Ray does not currently support serialization of arbitrary Python objects. The
set of Python objects that Ray can serialize includes the following.
set of Python objects that Ray can serialize using Arrow includes the following.
1. Primitive types: ints, floats, longs, bools, strings, unicode, and numpy
arrays.
2. Any list, dictionary, or tuple whose elements can be serialized by Ray.
3. Objects whose classes can be registered with ``ray.register_class``. This
point is described below.
Registering Custom Classes
--------------------------
We currently support serializing a limited subset of custom classes. For
example, suppose you define a new class ``Foo`` as follows.
.. code-block:: python
class Foo(object):
def __init__(self, a, b):
self.a = a
self.b = b
Simply calling ``ray.put(Foo(1, 2))`` will fail with a message like
.. code-block:: python
Ray does not know how to serialize the object <__main__.Foo object at 0x1077d7c50>.
This can be addressed by calling ``ray.register_class(Foo)``.
.. code-block:: python
import ray
ray.init()
# Define a custom class.
class Foo(object):
def __init__(self, a, b):
self.a = a
self.b = b
# Calling ray.register_class(Foo) ships the class definition to all of the
# workers so that workers know how to construct new Foo objects.
ray.register_class(Foo)
# Create a Foo object, place it in the object store, and retrieve it.
f = Foo(1, 2)
f_id = ray.put(f)
ray.get(f_id) # prints <__main__.Foo at 0x1078128d0>
Under the hood, ``ray.put`` places ``f.__dict__``, the dictionary of attributes
of ``f``, into the object store instead of ``f`` itself. In this case, this is
the dictionary, ``{"a": 1, "b": 2}``. Then during deserialization, ``ray.get``
constructs a new ``Foo`` object from the dictionary of fields.
This naive substitution won't work in all cases. For example, this scheme does
not support Python objects of type ``function`` (e.g., ``f = lambda x: x +
1``). In these cases, the call to ``ray.register_class`` will give an error
message, and you should fall back to pickle.
.. code-block:: python
# This call tells Ray to fall back to using pickle when it encounters objects
# of type function (we actually already do this under the hood).
f = lambda x: x + 1
ray.register_class(type(f), pickle=True)
f_new = ray.get(ray.put(f))
f_new(0) # prints 1
However, it's best to avoid using pickle for the efficiency reasons described
above. If you find yourself needing to pickle certain objects, consider trying
to use more efficient data structures like arrays.
**Note:** Another setting where the naive replacement of an object with its
``__dict__`` attribute fails is recursion, e.g., an object contains itself or
multiple objects contain each other. To see more examples of this, see the
section `Notes and Limitations`_.
For a more general object, Ray will first attempt to serialize the object by
unpacking the object as a dictionary of its fields. This behavior is not
correct in all cases. If Ray cannot serialize the object as a dictionary of its
fields, Ray will fall back to using pickle. However, using pickle will likely
be inefficient.
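The dictionary-of-fields approach with a pickle fallback can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Ray's implementation: `NotDictSerializable`, `to_fields`, `serialize`, and `deserialize` are hypothetical names, and the standard `pickle` module stands in for `ray.pickling`.

```python
import pickle


class NotDictSerializable(Exception):
    """Hypothetical stand-in for RayNotDictionarySerializable."""


def to_fields(obj):
    # Only objects with a __dict__ and no __slots__ can be unpacked this way.
    if not hasattr(obj, "__dict__") or hasattr(obj, "__slots__"):
        raise NotDictSerializable(type(obj))
    return dict(obj.__dict__)


def serialize(obj):
    # First try the dictionary of fields, then fall back to pickle.
    try:
        return {"fields": to_fields(obj)}
    except NotDictSerializable:
        return {"data": pickle.dumps(obj), "pickle": True}


def deserialize(cls, payload):
    if "pickle" in payload:
        # The object was pickled, so unpickle it.
        return pickle.loads(payload["data"])
    # Rebuild the object without calling __init__, then restore its fields.
    obj = cls.__new__(cls)
    obj.__dict__.update(payload["fields"])
    return obj


class Foo(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b


foo_copy = deserialize(Foo, serialize(Foo(1, 2)))
set_copy = deserialize(set, serialize({1, 2}))
```

In this sketch `Foo` takes the dictionary path, while a `set`, which has no `__dict__`, takes the pickle path.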
Notes and limitations
---------------------
@@ -167,9 +98,6 @@ Notes and limitations
This object exceeds the maximum recursion depth. It may contain itself recursively.
- If you need to pass a custom class into a remote function, you should call
``ray.register_class`` on the class **before defining the remote function**.
- Whenever possible, use numpy arrays for maximum performance.
Last Resort Workaround
@@ -154,10 +154,6 @@ if __name__ == "__main__":
ray.init(redis_address=args.redis_address,
num_workers=(0 if args.redis_address is None else None))
# Tell Ray to serialize Config and Result objects.
ray.register_class(Config)
ray.register_class(Result)
config = Config(l2coeff=0.005,
noise_stdev=0.02,
episodes_per_batch=10000,
@@ -33,10 +33,6 @@ if __name__ == "__main__":
ray.init(redis_address=args.redis_address)
ray.register_class(AtariRamPreprocessor)
ray.register_class(AtariPixelPreprocessor)
ray.register_class(NoPreprocessor)
mdp_name = args.environment
if args.environment == "Pong-v0":
preprocessor = AtariPixelPreprocessor()
@@ -72,10 +72,6 @@ class DistArray(object):
return a[sliced]
# Register the DistArray class with Ray so that it knows how to serialize it.
ray.register_class(DistArray)
@ray.remote
def assemble(a):
return a.assemble()
+85 -64
@@ -2,12 +2,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray.numbuf
import ray.pickling as pickling
class RaySerializationException(Exception):
def __init__(self, message, example_object):
Exception.__init__(self, message)
self.example_object = example_object
class RayDeserializationException(Exception):
def __init__(self, message, class_id):
Exception.__init__(self, message)
self.class_id = class_id
class RayNotDictionarySerializable(Exception):
pass
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
@@ -22,42 +36,38 @@ def check_serializable(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
raise Exception("The class {} does not have a '__new__' attribute, and is "
"probably an old-style class. We do not support this. "
"Please either make it a new-style class by inheriting "
"from 'object', or use "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
print("The class {} does not have a '__new__' attribute and is probably "
"an old-style class. Please make it a new-style class by inheriting "
"from 'object'.".format(cls))
raise RayNotDictionarySerializable("The class {} does not have a "
"'__new__' attribute and is probably "
"an old-style class. We do not support "
"this. Please make it a new-style "
"class by inheriting from 'object'."
.format(cls))
try:
obj = cls.__new__(cls)
except:
raise Exception("The class {} has overridden '__new__', so Ray may not be "
"able to serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
", so Ray may not be able to serialize "
"it efficiently.".format(cls))
if not hasattr(obj, "__dict__"):
raise Exception("Objects of the class {} do not have a `__dict__` "
"attribute, so Ray cannot serialize it efficiently. Try "
"using 'ray.register_class(cls, pickle=True)'. However, "
"note that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
"'__dict__' attribute, so Ray cannot "
"serialize it efficiently.".format(cls))
if hasattr(obj, "__slots__"):
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
"serialize it efficiently. Try using "
"'ray.register_class(cls, pickle=True)'. However, note "
"that pickle is inefficient.".format(cls))
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
"may not be able to serialize it "
"efficiently.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
# serialize.
whitelisted_classes = {}
type_to_class_id = dict()
whitelisted_classes = dict()
classes_to_pickle = set()
custom_serializers = {}
custom_deserializers = {}
def class_identifier(typ):
"""Return a string that identifies this type."""
return "{}.{}".format(typ.__module__, typ.__name__)
custom_serializers = dict()
custom_deserializers = dict()
def is_named_tuple(cls):
@@ -71,12 +81,13 @@ def is_named_tuple(cls):
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
Args:
cls (type): The class that we can serialize.
class_id: A string of bytes used to identify the class.
pickle (bool): True if the serialization should be done with pickle. False
if it should be done efficiently with Ray.
custom_serializer: This argument is optional, but can be provided to
@@ -84,7 +95,7 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
class_id = class_identifier(cls)
type_to_class_id[cls] = class_id
whitelisted_classes[class_id] = cls
if pickle:
classes_to_pickle.add(class_id)
@@ -93,21 +104,6 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
custom_deserializers[class_id] = custom_deserializer
# Here we define a custom serializer and deserializer for handling numpy
# arrays that contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
add_class_to_whitelist(np.ndarray, pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
def serialize(obj):
"""This is the callback that will be used by numbuf.
@@ -120,14 +116,16 @@ def serialize(obj):
A dictionary that has the key "_pytype_" to identify the class, and
contains all information needed to reconstruct the object.
"""
class_id = class_identifier(type(obj))
if class_id not in whitelisted_classes:
raise Exception("Ray does not know how to serialize objects of type {}. "
"To fix this, call 'ray.register_class' with this class."
.format(type(obj)))
if type(obj) not in type_to_class_id:
raise RaySerializationException("Ray does not know how to serialize "
"objects of type {}.".format(type(obj)),
obj)
class_id = type_to_class_id[type(obj)]
if class_id in classes_to_pickle:
serialized_obj = {"data": pickling.dumps(obj)}
elif class_id in custom_serializers.keys():
serialized_obj = {"data": pickling.dumps(obj),
"pickle": True}
elif class_id in custom_serializers:
serialized_obj = {"data": custom_serializers[class_id](obj)}
else:
# Handle the namedtuple case.
@@ -137,8 +135,8 @@ def serialize(obj):
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise Exception("We do not know how to serialize the object '{}'"
.format(obj))
raise RaySerializationException("We do not know how to serialize the "
"object '{}'".format(obj), obj)
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
@@ -154,21 +152,36 @@ def deserialize(serialized_obj):
Returns:
A Python object.
Raises:
An exception is raised if we do not know how to deserialize the object.
"""
class_id = serialized_obj["_pytype_"]
cls = whitelisted_classes[class_id]
if class_id in classes_to_pickle:
if "pickle" in serialized_obj:
# The object was pickled, so unpickle it.
obj = pickling.loads(serialized_obj["data"])
elif class_id in custom_deserializers.keys():
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
assert class_id not in classes_to_pickle
if class_id not in whitelisted_classes:
# If this happens, that means that the call to _register_class, which
# should have added the class to the list of whitelisted classes, has not
# yet propagated to this worker. It should happen if we wait a little
# longer.
raise RayDeserializationException("The class {} is not one of the "
"whitelisted classes."
.format(class_id), class_id)
cls = whitelisted_classes[class_id]
if class_id in custom_deserializers:
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
@@ -181,3 +194,11 @@ def set_callbacks():
callback.
"""
ray.numbuf.register_callbacks(serialize, deserialize)
def clear_state():
type_to_class_id.clear()
whitelisted_classes.clear()
classes_to_pickle.clear()
custom_serializers.clear()
custom_deserializers.clear()
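The move from name-based identifiers to explicit class IDs in this file can be illustrated with a small standalone sketch. It assumes, hypothetically, that 20 random bytes from `os.urandom` play the role of the `random_string()` helper used elsewhere in the commit; `name_identifier`, `register`, and `make_class` are illustrative names, not Ray functions.

```python
import os

type_to_class_id = {}
whitelisted_classes = {}


def name_identifier(typ):
    # Mirrors the removed class_identifier helper: module plus class name.
    return "{}.{}".format(typ.__module__, typ.__name__)


def register(cls):
    # Assumption: 20 random bytes stand in for Ray's random_string().
    class_id = os.urandom(20)
    type_to_class_id[cls] = class_id
    whitelisted_classes[class_id] = cls
    return class_id


def make_class(method_name):
    # Each call creates a distinct class object that is always named Class0.
    class Class0(object):
        pass
    setattr(Class0, method_name, lambda self: method_name)
    return Class0


first = make_class("method0")
second = make_class("method1")
id_first = register(first)
id_second = register(second)
```

Both classes share one name-based identifier, so a registry keyed by name would let the second definition overwrite the first. Keying by a per-registration ID keeps them apart.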
-10
@@ -107,13 +107,3 @@ def python_mode_g(x):
@ray.remote
def no_op():
pass
class TestClass(object):
def __init__(self):
self.a = 5
@ray.remote
def test_unknown_type():
return TestClass()
+160 -45
@@ -479,27 +479,70 @@ class Worker(object):
self.mode = mode
colorama.init()
def put_object(self, objectid, value):
def store_and_register(self, object_id, value, depth=100):
"""Store an object and attempt to register its class if needed.
Args:
object_id: The ID of the object to store.
value: The value to put in the object store.
depth: The maximum number of classes to recursively register.
Raises:
Exception: An exception is raised if the attempt to store the object
fails. This can happen if there is already an object with the same ID
in the object store or if the object store is full.
"""
counter = 0
while True:
if counter == depth:
raise Exception("Ray exceeded the maximum number of classes that it "
"will recursively serialize when attempting to "
"serialize an object of type {}.".format(type(value)))
counter += 1
try:
ray.numbuf.store_list(object_id.id(), self.plasma_client.conn, [value])
break
except serialization.RaySerializationException as e:
try:
_register_class(type(e.example_object))
warning_message = ("WARNING: Serializing objects of type {} by "
"expanding them as dictionaries of their fields. "
"This behavior may be incorrect in some cases."
.format(type(e.example_object)))
print(warning_message)
except serialization.RayNotDictionarySerializable:
_register_class(type(e.example_object), pickle=True)
warning_message = ("WARNING: Falling back to serializing objects of "
"type {} by using pickle. This may be "
"inefficient.".format(type(e.example_object)))
print(warning_message)
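The retry loop in store_and_register can be sketched without Ray as follows. This is a simplified illustration under stated assumptions: `UnknownType`, `known_types`, and `encode` are hypothetical stand-ins for `RaySerializationException`, the serialization whitelist, and `ray.numbuf.store_list`, and the pickle fallback branch is omitted.

```python
class UnknownType(Exception):
    """Hypothetical stand-in for RaySerializationException."""

    def __init__(self, example_object):
        Exception.__init__(self, "unknown type")
        self.example_object = example_object


# Toy whitelist; Ray keeps its registered classes in ray.serialization.
known_types = {int, str}


def encode(value):
    """Toy encoder that raises UnknownType at the first unregistered class."""
    if isinstance(value, list):
        return [encode(item) for item in value]
    if type(value) not in known_types:
        raise UnknownType(value)
    if hasattr(value, "__dict__"):
        return {key: encode(field) for key, field in value.__dict__.items()}
    return value


def store_and_register(value, depth=100):
    """Retry encoding, registering one newly seen class per failed attempt."""
    counter = 0
    while True:
        if counter == depth:
            raise Exception("Exceeded the maximum number of classes to "
                            "register for {}.".format(type(value)))
        counter += 1
        try:
            return encode(value)
        except UnknownType as e:
            # Register the offending class, then try the whole value again.
            known_types.add(type(e.example_object))


class Inner(object):
    def __init__(self):
        self.x = 1


class Outer(object):
    def __init__(self):
        self.inner = Inner()
        self.items = [1, "a"]


encoded = store_and_register(Outer())
```

In this sketch the first attempt fails on `Outer`, the second on `Inner`, and the third succeeds, so `depth` bounds the number of attempts and therefore the number of classes registered for a single value.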
def put_object(self, object_id, value):
"""Put value in the local object store with object id objectid.
This assumes that the value for objectid has not yet been placed in the
local object store.
Args:
objectid (object_id.ObjectID): The object ID of the value to be put.
object_id (object_id.ObjectID): The object ID of the value to be put.
value: The value to put in the object store.
Raises:
Exception: An exception is raised if the attempt to store the object
fails. This can happen if there is already an object with the same ID
in the object store or if the object store is full.
"""
# Make sure that the value is not an object ID.
if isinstance(value, ray.local_scheduler.ObjectID):
raise Exception("Calling `put` on an ObjectID is not allowed "
raise Exception("Calling 'put' on an ObjectID is not allowed "
"(similarly, returning an ObjectID from a remote "
"function is not allowed). If you really want to do "
"this, you can wrap the ObjectID in a list and call "
"`put` on it (or return it).")
"'put' on it (or return it).")
# Serialize and put the object in the object store.
try:
ray.numbuf.store_list(objectid.id(), self.plasma_client.conn, [value])
self.store_and_register(object_id, value)
except ray.numbuf.numbuf_plasma_object_exists_error as e:
# The object already exists in the object store, so there is no need to
# add it again. TODO(rkn): We need to compare the hashes and make sure
@@ -511,6 +554,38 @@ class Worker(object):
# Optionally do something with the contained_objectids here.
contained_objectids = []
def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
start_time = time.time()
# Only send the warning once.
warning_sent = False
while True:
try:
results = ray.numbuf.retrieve_list(
object_ids,
self.plasma_client.conn,
timeout)
return results
except serialization.RayDeserializationException as e:
# Wait a little bit for the import thread to import the class. If we
# currently have the worker lock, we need to release it so that the
# import thread can acquire it.
if self.mode == WORKER_MODE:
self.lock.release()
time.sleep(0.01)
if self.mode == WORKER_MODE:
self.lock.acquire()
if time.time() - start_time > error_timeout:
warning_message = ("This worker or driver is waiting to receive a "
"class definition so that it can deserialize an "
"object from the object store. This may be fine, "
"or it may be a bug.")
if not warning_sent:
self.push_error_to_driver(self.task_driver_id.id(),
"wait_for_class",
warning_message)
warning_sent = True
def get_object(self, object_ids):
"""Get the value or values in the object store associated with object_ids.
@@ -531,10 +606,8 @@ class Worker(object):
self.plasma_client.fetch([object_id.id() for object_id in object_ids])
# Get the objects. We initially try to get the objects immediately.
final_results = ray.numbuf.retrieve_list(
[object_id.id() for object_id in object_ids],
self.plasma_client.conn,
0)
final_results = self.retrieve_and_deserialize(
[object_id.id() for object_id in object_ids], 0)
# Construct a dictionary mapping object IDs that we haven't gotten yet to
# their original index in the object_ids argument.
unready_ids = dict((object_id, i) for (i, (object_id, val)) in
@@ -548,9 +621,8 @@ class Worker(object):
# Do another fetch for objects that aren't available locally yet, in case
# they were evicted since the last fetch.
self.plasma_client.fetch(list(unready_ids.keys()))
results = ray.numbuf.retrieve_list(list(unready_ids.keys()),
self.plasma_client.conn,
GET_TIMEOUT_MILLISECONDS)
results = self.retrieve_and_deserialize(list(unready_ids.keys()),
GET_TIMEOUT_MILLISECONDS)
# Remove any entries for objects we received during this iteration so we
# don't retrieve the same object twice.
for object_id, val in results:
@@ -634,14 +706,16 @@ class Worker(object):
not be used.
"""
check_main_thread()
if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
raise Exception("run_function_on_all_workers can only be called on a "
"driver.")
# If ray.init has not been called yet, then cache the function and export
# it when connect is called. Otherwise, run the function on all workers.
if self.mode is None:
self.cached_functions_to_run.append(function)
else:
# Attempt to pickle the function before we need it. This could fail, and
# it is more convenient if the failure happens before we actually run the
# function locally.
pickled_function = pickling.dumps(function)
function_to_run_id = random_string()
key = "FunctionsToRun:{}".format(function_to_run_id)
# First run the function on the driver. Pass in the number of workers on
@@ -652,7 +726,7 @@ class Worker(object):
# Run the function on all workers.
self.redis_client.hmset(key, {"driver_id": self.task_driver_id.id(),
"function_id": function_to_run_id,
"function": pickling.dumps(function)})
"function": pickled_function})
self.redis_client.rpush("Exports", key)
def push_error_to_driver(self, driver_id, error_type, message, data=None):
@@ -808,23 +882,37 @@ def initialize_numbuf(worker=global_worker):
def objectid_custom_deserializer(serialized_obj):
return ray.local_scheduler.ObjectID(serialized_obj)
serialization.add_class_to_whitelist(
ray.local_scheduler.ObjectID, pickle=False,
ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False,
custom_serializer=objectid_custom_serializer,
custom_deserializer=objectid_custom_deserializer)
# Define a custom serializer and deserializer for handling numpy arrays that
# contain objects.
def array_custom_serializer(obj):
return obj.tolist(), obj.dtype.str
def array_custom_deserializer(serialized_obj):
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
serialization.add_class_to_whitelist(
np.ndarray, 20 * b"\x01", pickle=False,
custom_serializer=array_custom_serializer,
custom_deserializer=array_custom_deserializer)
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
# These should only be called on the driver because register_class will
# These should only be called on the driver because _register_class will
# export the class to all of the workers.
register_class(RayTaskError)
register_class(RayGetError)
register_class(RayGetArgumentError)
_register_class(RayTaskError)
_register_class(RayGetError)
_register_class(RayGetArgumentError)
# Tell Ray to serialize lambdas with pickle.
register_class(type(lambda: 0), pickle=True)
_register_class(type(lambda: 0), pickle=True)
# Tell Ray to serialize sets with pickle.
register_class(type(set()), pickle=True)
_register_class(type(set()), pickle=True)
# Tell Ray to serialize types with pickle.
register_class(type(int), pickle=True)
_register_class(type(int), pickle=True)
def get_address_info_from_redis_helper(redis_address, node_ip_address):
@@ -1279,7 +1367,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
data={"name": name})
def import_thread(worker):
def import_thread(worker, mode):
worker.import_pubsub_client = worker.redis_client.pubsub()
# Exports that are published after the call to
# import_pubsub_client.psubscribe and before the call to
@@ -1292,6 +1380,16 @@ def import_thread(worker):
with worker.lock:
export_keys = worker.redis_client.lrange("Exports", 0, -1)
for key in export_keys:
num_imported += 1
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the driver
# should import.
continue
if key.startswith(b"RemoteFunction"):
fetch_and_register_remote_function(key, worker=worker)
elif key.startswith(b"EnvironmentVariables"):
@@ -1306,7 +1404,6 @@ def import_thread(worker):
worker.fetch_and_register_actor(key, worker)
else:
raise Exception("This code should be unreachable.")
num_imported += 1
try:
for msg in worker.import_pubsub_client.listen():
@@ -1317,7 +1414,18 @@ def import_thread(worker):
num_imports = worker.redis_client.llen("Exports")
assert num_imports >= num_imported
for i in range(num_imported, num_imports):
num_imported += 1
key = worker.redis_client.lindex("Exports", i)
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with log_span("ray:import_function_to_run", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the
# driver should import.
continue
if key.startswith(b"RemoteFunction"):
with log_span("ray:import_remote_function", worker=worker):
fetch_and_register_remote_function(key, worker=worker)
@@ -1335,7 +1443,6 @@ def import_thread(worker):
worker.fetch_and_register["Actor"](key, worker)
else:
raise Exception("This code should be unreachable.")
num_imported += 1
except redis.ConnectionError:
# When Redis terminates the listen call will throw a ConnectionError, which
# we catch here.
@@ -1486,12 +1593,14 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
class_id = worker.redis_client.hget(actor_key, "class_id")
worker.class_id = class_id
# If this is a worker, then start a thread to import exports from the driver.
if mode == WORKER_MODE:
t = threading.Thread(target=import_thread, args=(worker,))
# Making the thread a daemon causes it to exit when the main thread exits.
t.daemon = True
t.start()
# Start a thread to import exports from the driver or from other workers.
# Note that the driver also has an import thread, which is used only to
# import custom class definitions from calls to _register_class that happen
# under the hood on workers.
t = threading.Thread(target=import_thread, args=(worker, mode))
# Making the thread a daemon causes it to exit when the main thread exits.
t.daemon = True
t.start()
# If this is a driver running in SCRIPT_MODE, start a thread to print error
# messages asynchronously in the background. Ideally the scheduler would push
@@ -1554,9 +1663,15 @@ def disconnect(worker=global_worker):
worker.cached_functions_to_run = []
worker.cached_remote_functions = []
env._cached_environment_variables = []
serialization.clear_state()
def register_class(cls, pickle=False, worker=global_worker):
raise Exception("The function ray.register_class is deprecated. It should "
"be safe to remove any calls to this function.")
def _register_class(cls, pickle=False, worker=global_worker):
"""Enable workers to serialize or deserialize objects of a particular class.
This method runs the register_class function defined below on every worker,
@@ -1573,18 +1688,19 @@ def register_class(cls, pickle=False, worker=global_worker):
Exception: An exception is raised if pickle=False and the class cannot be
efficiently serialized by Ray.
"""
# If the worker is not a driver, then return. We do this so that Python
# modules can register classes and these modules can be imported on workers
# without any trouble.
if worker.mode == WORKER_MODE:
return
# Raise an exception if cls cannot be serialized efficiently by Ray.
if not pickle:
serialization.check_serializable(cls)
class_id = random_string()
def register_class_for_serialization(worker_info):
serialization.add_class_to_whitelist(cls, pickle=pickle)
worker.run_function_on_all_workers(register_class_for_serialization)
serialization.add_class_to_whitelist(cls, class_id, pickle=pickle)
if not pickle:
# Raise an exception if cls cannot be serialized efficiently by Ray.
serialization.check_serializable(cls)
worker.run_function_on_all_workers(register_class_for_serialization)
else:
# Since we are pickling objects of this class, we don't actually need to
# ship the class definition.
register_class_for_serialization({})
class RayLogSpan(object):
@@ -1778,7 +1894,6 @@ def wait_for_function(function_id, driver_id, timeout=10,
start_time = time.time()
# Only send the warning once.
warning_sent = False
num_warnings_sent = 0
while True:
with worker.lock:
if worker.actor_id == NIL_ACTOR_ID and (function_id.id() in
@@ -1787,7 +1902,7 @@ def wait_for_function(function_id, driver_id, timeout=10,
elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in
worker.actors):
break
if time.time() - start_time > timeout * (num_warnings_sent + 1):
if time.time() - start_time > timeout:
warning_message = ("This worker was asked to execute a function that "
"it does not have registered. You may have to "
"restart Ray.")
-1
@@ -141,7 +141,6 @@ class ActorAPI(unittest.TestCase):
class Foo(object):
def __init__(self, x):
self.x = x
ray.register_class(Foo)
@ray.remote
class Actor(object):
-36
@@ -28,42 +28,6 @@ def wait_for_errors(error_type, num_errors, timeout=10):
print("Timing out of wait.")
class FailureTest(unittest.TestCase):
def testUnknownSerialization(self):
reload(test_functions)
ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
test_functions.test_unknown_type.remote()
wait_for_errors(b"task", 1)
self.assertEqual(len(relevant_errors(b"task")), 1)
ray.worker.cleanup()
class TaskSerializationTest(unittest.TestCase):
def testReturnAndPassUnknownType(self):
ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
class Foo(object):
pass
# Check that returning an unknown type from a remote function raises an
# exception.
@ray.remote
def f():
return Foo()
self.assertRaises(Exception, lambda: ray.get(f.remote()))
# Check that passing an unknown type into a remote function raises an
# exception.
@ray.remote
def g(x):
return 1
self.assertRaises(Exception, lambda: g.remote(Foo()))
ray.worker.cleanup()
class TaskStatusTest(unittest.TestCase):
def testFailedTask(self):
reload(test_functions)
+122 -51
@@ -2,15 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import ray
from collections import defaultdict, namedtuple
import numpy as np
import time
import os
import ray
import re
import shutil
import string
import sys
from collections import defaultdict, namedtuple
import time
import unittest
import ray.test.test_functions as test_functions
@@ -169,8 +170,6 @@ class SerializationTest(unittest.TestCase):
class ClassA(object):
pass
ray.register_class(ClassA)
# Make a list that contains itself.
l = []
l.append(l)
@@ -201,14 +200,6 @@ class SerializationTest(unittest.TestCase):
def f(x):
return x
ray.register_class(Exception)
ray.register_class(CustomError)
ray.register_class(Point)
ray.register_class(Foo)
ray.register_class(Bar)
ray.register_class(Baz)
ray.register_class(NamedTupleExample)
# Check that we can pass arguments by value to remote functions and that
# they are uncorrupted.
for obj in RAY_TEST_OBJECTS:
@@ -303,20 +294,129 @@ class WorkerTest(unittest.TestCase):
class APITest(unittest.TestCase):
def testRegisterClass(self):
ray.init(num_workers=0)
ray.init(num_workers=2)
# Check that putting an object of a class that has not been registered
# throws an exception.
class TempClass(object):
pass
self.assertRaises(Exception, lambda: ray.put(TempClass()))
# Check that registering a class that Ray cannot serialize efficiently
# raises an exception.
self.assertRaises(Exception, lambda: ray.register_class(defaultdict))
# Check that registering the same class with pickle works.
ray.register_class(defaultdict, pickle=True)
ray.get(ray.put(TempClass()))
# Note that the below actually returns a dictionary and not a defaultdict.
# This is a bug (https://github.com/ray-project/ray/issues/512).
ray.get(ray.put(defaultdict(lambda: 0)))
# Test passing custom classes into remote functions from the driver.
@ray.remote
def f(x):
return x
foo = ray.get(f.remote(Foo(7)))
self.assertEqual(foo, Foo(7))
regex = re.compile(r"\d+\.\d*")
new_regex = ray.get(f.remote(regex))
self.assertEqual(regex, new_regex)
# Test returning custom classes created on workers.
@ray.remote
def g():
return SubQux(), Qux()
subqux, qux = ray.get(g.remote())
self.assertEqual(subqux.objs[2].foo.value, 0)
# Test exporting custom class definitions from one worker to another when
# the worker is blocked in a get.
class NewTempClass(object):
def __init__(self, value):
self.value = value
@ray.remote
def h1(x):
return NewTempClass(x)
@ray.remote
def h2(x):
return ray.get(h1.remote(x))
self.assertEqual(ray.get(h2.remote(10)).value, 10)
# Test registering multiple classes with the same name.
@ray.remote(num_return_vals=3)
def j():
class Class0(object):
def method0(self):
pass
c0 = Class0()
class Class0(object):
def method1(self):
pass
c1 = Class0()
class Class0(object):
def method2(self):
pass
c2 = Class0()
return c0, c1, c2
results = []
for _ in range(5):
results += j.remote()
for i in range(len(results) // 3):
c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))])
c0.method0()
c1.method1()
c2.method2()
self.assertFalse(hasattr(c0, "method1"))
self.assertFalse(hasattr(c0, "method2"))
self.assertFalse(hasattr(c1, "method0"))
self.assertFalse(hasattr(c1, "method2"))
self.assertFalse(hasattr(c2, "method0"))
self.assertFalse(hasattr(c2, "method1"))
@ray.remote
def k():
class Class0(object):
def method0(self):
pass
c0 = Class0()
class Class0(object):
def method1(self):
pass
c1 = Class0()
class Class0(object):
def method2(self):
pass
c2 = Class0()
return c0, c1, c2
results = ray.get([k.remote() for _ in range(5)])
for c0, c1, c2 in results:
c0.method0()
c1.method1()
c2.method2()
self.assertFalse(hasattr(c0, "method1"))
self.assertFalse(hasattr(c0, "method2"))
self.assertFalse(hasattr(c1, "method0"))
self.assertFalse(hasattr(c1, "method2"))
self.assertFalse(hasattr(c2, "method0"))
self.assertFalse(hasattr(c2, "method1"))
ray.worker.cleanup()
def testKeywordArgs(self):
@@ -666,35 +766,6 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
def testPassingInfoToAllWorkers(self):
ray.init(num_workers=10, num_cpus=10)
def f(worker_info):
sys.path.append(worker_info)
ray.worker.global_worker.run_function_on_all_workers(f)
@ray.remote
def get_path():
time.sleep(1)
return sys.path
# Retrieve the values that we stored in the worker paths.
paths = ray.get([get_path.remote() for _ in range(10)])
# Add the driver's path to the list.
paths.append(sys.path)
worker_infos = [path[-1] for path in paths]
for worker_info in worker_infos:
self.assertEqual(list(worker_info.keys()), ["counter"])
counters = [worker_info["counter"] for worker_info in worker_infos]
# We use range(11) because the driver also runs the function.
self.assertEqual(set(counters), set(range(11)))
# Clean up the worker paths.
def f(worker_info):
sys.path.pop(-1)
ray.worker.global_worker.run_function_on_all_workers(f)
ray.worker.cleanup()
def testLoggingAPI(self):
ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)