mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:48:31 +08:00
Remove register_class from API. (#550)
* Perform ray.register_class under the hood.
* Fix bug.
* Release worker lock when waiting for imports to arrive in get.
* Remove calls to register_class from examples and tests.
* Clear serialization state between tests.
* Fix bug and add test for multiple custom classes with same name.
* Fix failure test.
* Fix linting and cleanups to python code.
* Fixes to documentation.
* Implement recursion depth for recursively registering classes.
* Fix linting.
* Push warning to user if waiting for class for too long.
* Fix typos.
* Don't export FunctionToRun if pickling the function fails.
* Don't broadcast class definition when pickling class.
committed by
Philipp Moritz
parent
3ebfd850e1
commit
ec2534422b
@@ -220,8 +220,6 @@ We can put this all together as follows.
|
||||
|
||||
# Load the MNIST dataset and tell Ray how to serialize the custom classes.
|
||||
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
|
||||
ray.register_class(type(mnist))
|
||||
ray.register_class(type(mnist.train))
|
||||
|
||||
# Create the actor.
|
||||
nn = NeuralNetOnGPU.remote(mnist)
|
||||
|
||||
@@ -9,16 +9,15 @@ store.
|
||||
|
||||
1. The return values of a remote function.
|
||||
2. The value ``x`` in a call to ``ray.put(x)``.
|
||||
3. Large objects or objects other than simple primitive types that are passed
|
||||
as arguments into remote functions.
|
||||
3. Arguments to remote functions (except for simple arguments like ints or
|
||||
floats).
|
||||
|
||||
A Python object may have an arbitrary number of pointers with arbitrarily deep
|
||||
nesting. To place an object in the object store or send it between processes,
|
||||
it must first be converted to a contiguous string of bytes. This process is
|
||||
known as serialization. The process of converting the string of bytes back into a
|
||||
Python object is known as deserialization. Serialization and deserialization
|
||||
are often bottlenecks in distributed computing if the time needed to compute
|
||||
on the data is relatively low.
|
||||
are often bottlenecks in distributed computing.
|
||||
|
||||
Pickle is one example of a library for serialization and deserialization in
|
||||
Python.
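
As a small illustration of the serialization and deserialization steps described above, here is a sketch using the standard-library ``pickle`` module. The nested object is made-up example data, not something taken from Ray.

```python
import pickle

# A nested Python object with pointers to other objects (hypothetical data).
obj = {"weights": [1.0, 2.0, 3.0], "meta": ("layer0", {"trainable": True})}

# Serialization: convert the object into a contiguous string of bytes.
payload = pickle.dumps(obj)

# Deserialization: rebuild an equal, but distinct, Python object from the bytes.
restored = pickle.loads(payload)
```

The restored object compares equal to the original but is a separate copy, which is why every process that unpickles the same bytes pays for its own copy.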
|
||||
@@ -40,9 +39,8 @@ overheads, even when all processes are read-only and could easily share memory.
|
||||
In Ray, we optimize for numpy arrays by using the `Apache Arrow`_ data format.
|
||||
When we deserialize a list of numpy arrays from the object store, we still
|
||||
create a Python list of numpy array objects. However, rather than copy each
|
||||
numpy array over again, each numpy array object holds a pointer to the relevant
|
||||
array held in shared memory. There are some advantages to this form of
|
||||
serialization.
|
||||
numpy array, each numpy array object holds a pointer to the relevant array held
|
||||
in shared memory. There are some advantages to this form of serialization.
|
||||
|
||||
- Deserialization can be very fast.
|
||||
- Memory is shared between processes so worker processes can all read the same
|
||||
@@ -54,84 +52,17 @@ What Objects Does Ray Handle
|
||||
----------------------------
|
||||
|
||||
Ray does not currently support serialization of arbitrary Python objects. The
|
||||
set of Python objects that Ray can serialize includes the following.
|
||||
set of Python objects that Ray can serialize using Arrow includes the following.
|
||||
|
||||
1. Primitive types: ints, floats, longs, bools, strings, unicode, and numpy
|
||||
arrays.
|
||||
2. Any list, dictionary, or tuple whose elements can be serialized by Ray.
|
||||
3. Objects whose classes can be registered with ``ray.register_class``. This
|
||||
point is described below.
|
||||
|
||||
Registering Custom Classes
|
||||
--------------------------
|
||||
|
||||
We currently support serializing a limited subset of custom classes. For
|
||||
example, suppose you define a new class ``Foo`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
Simply calling ``ray.put(Foo(1, 2))`` will fail with a message like
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Ray does not know how to serialize the object <__main__.Foo object at 0x1077d7c50>.
|
||||
|
||||
This can be addressed by calling ``ray.register_class(Foo)``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
# Define a custom class.
|
||||
class Foo(object):
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Calling ray.register_class(Foo) ships the class definition to all of the
|
||||
# workers so that workers know how to construct new Foo objects.
|
||||
ray.register_class(Foo)
|
||||
|
||||
# Create a Foo object, place it in the object store, and retrieve it.
|
||||
f = Foo(1, 2)
|
||||
f_id = ray.put(f)
|
||||
ray.get(f_id) # prints <__main__.Foo at 0x1078128d0>
|
||||
|
||||
Under the hood, ``ray.put`` places ``f.__dict__``, the dictionary of attributes
|
||||
of ``f``, into the object store instead of ``f`` itself. In this case, this is
|
||||
the dictionary, ``{"a": 1, "b": 2}``. Then during deserialization, ``ray.get``
|
||||
constructs a new ``Foo`` object from the dictionary of fields.
|
||||
|
||||
This naive substitution won't work in all cases. For example, this scheme does
|
||||
not support Python objects of type ``function`` (e.g., ``f = lambda x: x +
|
||||
1``). In these cases, the call to ``ray.register_class`` will give an error
|
||||
message, and you should fall back to pickle.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# This call tells Ray to fall back to using pickle when it encounters objects
|
||||
# of type function (we actually already do this under the hood).
|
||||
f = lambda x: x + 1
|
||||
ray.register_class(type(f), pickle=True)
|
||||
|
||||
f_new = ray.get(ray.put(f))
|
||||
f_new(0) # prints 1
|
||||
|
||||
However, it's best to avoid using pickle for the efficiency reasons described
|
||||
above. If you find yourself needing to pickle certain objects, consider trying
|
||||
to use more efficient data structures like arrays.
|
||||
|
||||
**Note:** Another setting where the naive replacement of an object with its
|
||||
``__dict__`` attribute fails is recursion, e.g., an object contains itself or
|
||||
multiple objects contain each other. To see more examples of this, see the
|
||||
section `Notes and Limitations`_.
|
||||
For a more general object, Ray will first attempt to serialize the object by
|
||||
unpacking the object as a dictionary of its fields. This behavior is not
|
||||
correct in all cases. If Ray cannot serialize the object as a dictionary of its
|
||||
fields, Ray will fall back to using pickle. However, using pickle will likely
|
||||
be inefficient.
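
The fallback described above can be sketched roughly as follows. This is only an illustration of the idea, not Ray's implementation: ``to_payload`` and ``from_payload`` are hypothetical helpers, and the payload layout is an assumption made for this example.

```python
import pickle


class Foo(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b


def to_payload(obj):
    """Hypothetical helper: prefer the dictionary of fields, else pickle."""
    if hasattr(obj, "__dict__") and not hasattr(obj, "__slots__"):
        # Efficient path: ship only the dictionary of fields.
        return {"kind": "dict", "cls": type(obj), "data": dict(obj.__dict__)}
    # Fallback path: pickle handles the object but is likely slower.
    return {"kind": "pickle", "data": pickle.dumps(obj)}


def from_payload(payload):
    """Hypothetical helper: invert to_payload."""
    if payload["kind"] == "pickle":
        return pickle.loads(payload["data"])
    cls = payload["cls"]
    # Build an empty instance without running __init__, then restore fields.
    obj = cls.__new__(cls)
    obj.__dict__.update(payload["data"])
    return obj


foo = from_payload(to_payload(Foo(1, 2)))    # takes the dictionary path
nums = from_payload(to_payload({1, 2, 3}))   # sets have no __dict__: pickle
```

Because the dictionary path skips ``__init__``, it can be incorrect for classes whose constructor has side effects, which is one reason the behavior is described as not correct in all cases.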
|
||||
|
||||
Notes and limitations
|
||||
---------------------
|
||||
@@ -167,9 +98,6 @@ Notes and limitations
|
||||
|
||||
This object exceeds the maximum recursion depth. It may contain itself recursively.
|
||||
|
||||
- If you need to pass a custom class into a remote function, you should call
|
||||
``ray.register_class`` on the class **before defining the remote function**.
|
||||
|
||||
- Whenever possible, use numpy arrays for maximum performance.
|
||||
|
||||
Last Resort Workaround
|
||||
|
||||
@@ -154,10 +154,6 @@ if __name__ == "__main__":
|
||||
ray.init(redis_address=args.redis_address,
|
||||
num_workers=(0 if args.redis_address is None else None))
|
||||
|
||||
# Tell Ray to serialize Config and Result objects.
|
||||
ray.register_class(Config)
|
||||
ray.register_class(Result)
|
||||
|
||||
config = Config(l2coeff=0.005,
|
||||
noise_stdev=0.02,
|
||||
episodes_per_batch=10000,
|
||||
|
||||
@@ -33,10 +33,6 @@ if __name__ == "__main__":
|
||||
|
||||
ray.init(redis_address=args.redis_address)
|
||||
|
||||
ray.register_class(AtariRamPreprocessor)
|
||||
ray.register_class(AtariPixelPreprocessor)
|
||||
ray.register_class(NoPreprocessor)
|
||||
|
||||
mdp_name = args.environment
|
||||
if args.environment == "Pong-v0":
|
||||
preprocessor = AtariPixelPreprocessor()
|
||||
|
||||
@@ -72,10 +72,6 @@ class DistArray(object):
|
||||
return a[sliced]
|
||||
|
||||
|
||||
# Register the DistArray class with Ray so that it knows how to serialize it.
|
||||
ray.register_class(DistArray)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
|
||||
+85
-64
@@ -2,12 +2,26 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray.numbuf
|
||||
import ray.pickling as pickling
|
||||
|
||||
|
||||
class RaySerializationException(Exception):
|
||||
def __init__(self, message, example_object):
|
||||
Exception.__init__(self, message)
|
||||
self.example_object = example_object
|
||||
|
||||
|
||||
class RayDeserializationException(Exception):
|
||||
def __init__(self, message, class_id):
|
||||
Exception.__init__(self, message)
|
||||
self.class_id = class_id
|
||||
|
||||
|
||||
class RayNotDictionarySerializable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def check_serializable(cls):
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
|
||||
@@ -22,42 +36,38 @@ def check_serializable(cls):
|
||||
# This case works.
|
||||
return
|
||||
if not hasattr(cls, "__new__"):
|
||||
raise Exception("The class {} does not have a '__new__' attribute, and is "
|
||||
"probably an old-style class. We do not support this. "
|
||||
"Please either make it a new-style class by inheriting "
|
||||
"from 'object', or use "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
print("The class {} does not have a '__new__' attribute and is probably "
|
||||
"an old-style class. Please make it a new-style class by inheriting "
|
||||
"from 'object'.".format(cls))
|
||||
raise RayNotDictionarySerializable("The class {} does not have a "
|
||||
"'__new__' attribute and is probably "
|
||||
"an old-style class. We do not support "
|
||||
"this. Please make it a new-style "
|
||||
"class by inheriting from 'object'."
|
||||
.format(cls))
|
||||
try:
|
||||
obj = cls.__new__(cls)
|
||||
except Exception:
|
||||
raise Exception("The class {} has overridden '__new__', so Ray may not be "
|
||||
"able to serialize it efficiently. Try using "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
|
||||
", so Ray may not be able to serialize "
|
||||
"it efficiently.".format(cls))
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise Exception("Objects of the class {} do not have a `__dict__` "
|
||||
"attribute, so Ray cannot serialize it efficiently. Try "
|
||||
"using 'ray.register_class(cls, pickle=True)'. However, "
|
||||
"note that pickle is inefficient.".format(cls))
|
||||
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
|
||||
"'__dict__' attribute, so Ray cannot "
|
||||
"serialize it efficiently.".format(cls))
|
||||
if hasattr(obj, "__slots__"):
|
||||
raise Exception("The class {} uses '__slots__', so Ray may not be able to "
|
||||
"serialize it efficiently. Try using "
|
||||
"'ray.register_class(cls, pickle=True)'. However, note "
|
||||
"that pickle is inefficient.".format(cls))
|
||||
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
|
||||
"may not be able to serialize it "
|
||||
"efficiently.".format(cls))
|
||||
|
||||
|
||||
# This field keeps track of a whitelisted set of classes that Ray will
|
||||
# serialize.
|
||||
whitelisted_classes = {}
|
||||
type_to_class_id = dict()
|
||||
whitelisted_classes = dict()
|
||||
classes_to_pickle = set()
|
||||
custom_serializers = {}
|
||||
custom_deserializers = {}
|
||||
|
||||
|
||||
def class_identifier(typ):
|
||||
"""Return a string that identifies this type."""
|
||||
return "{}.{}".format(typ.__module__, typ.__name__)
|
||||
custom_serializers = dict()
|
||||
custom_deserializers = dict()
|
||||
|
||||
|
||||
def is_named_tuple(cls):
|
||||
@@ -71,12 +81,13 @@ def is_named_tuple(cls):
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
|
||||
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
|
||||
custom_deserializer=None):
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
|
||||
Args:
|
||||
cls (type): The class that we can serialize.
|
||||
class_id: A string of bytes used to identify the class.
|
||||
pickle (bool): True if the serialization should be done with pickle. False
|
||||
if it should be done efficiently with Ray.
|
||||
custom_serializer: This argument is optional, but can be provided to
|
||||
@@ -84,7 +95,7 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
custom_deserializer: This argument is optional, but can be provided to
|
||||
deserialize objects of the class in a particular way.
|
||||
"""
|
||||
class_id = class_identifier(cls)
|
||||
type_to_class_id[cls] = class_id
|
||||
whitelisted_classes[class_id] = cls
|
||||
if pickle:
|
||||
classes_to_pickle.add(class_id)
|
||||
@@ -93,21 +104,6 @@ def add_class_to_whitelist(cls, pickle=False, custom_serializer=None,
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
|
||||
|
||||
# Here we define a custom serializer and deserializer for handling numpy
|
||||
# arrays that contain objects.
|
||||
def array_custom_serializer(obj):
|
||||
return obj.tolist(), obj.dtype.str
|
||||
|
||||
|
||||
def array_custom_deserializer(serialized_obj):
|
||||
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
|
||||
|
||||
|
||||
add_class_to_whitelist(np.ndarray, pickle=False,
|
||||
custom_serializer=array_custom_serializer,
|
||||
custom_deserializer=array_custom_deserializer)
|
||||
|
||||
|
||||
def serialize(obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
@@ -120,14 +116,16 @@ def serialize(obj):
|
||||
A dictionary that has the key "_pytype_" to identify the class, and
|
||||
contains all information needed to reconstruct the object.
|
||||
"""
|
||||
class_id = class_identifier(type(obj))
|
||||
if class_id not in whitelisted_classes:
|
||||
raise Exception("Ray does not know how to serialize objects of type {}. "
|
||||
"To fix this, call 'ray.register_class' with this class."
|
||||
.format(type(obj)))
|
||||
if type(obj) not in type_to_class_id:
|
||||
raise RaySerializationException("Ray does not know how to serialize "
|
||||
"objects of type {}.".format(type(obj)),
|
||||
obj)
|
||||
class_id = type_to_class_id[type(obj)]
|
||||
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickling.dumps(obj)}
|
||||
elif class_id in custom_serializers.keys():
|
||||
serialized_obj = {"data": pickling.dumps(obj),
|
||||
"pickle": True}
|
||||
elif class_id in custom_serializers:
|
||||
serialized_obj = {"data": custom_serializers[class_id](obj)}
|
||||
else:
|
||||
# Handle the namedtuple case.
|
||||
@@ -137,8 +135,8 @@ def serialize(obj):
|
||||
elif hasattr(obj, "__dict__"):
|
||||
serialized_obj = obj.__dict__
|
||||
else:
|
||||
raise Exception("We do not know how to serialize the object '{}'"
|
||||
.format(obj))
|
||||
raise RaySerializationException("We do not know how to serialize the "
|
||||
"object '{}'".format(obj), obj)
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
|
||||
@@ -154,21 +152,36 @@ def deserialize(serialized_obj):
|
||||
|
||||
Returns:
|
||||
A Python object.
|
||||
|
||||
Raises:
|
||||
An exception is raised if we do not know how to deserialize the object.
|
||||
"""
|
||||
class_id = serialized_obj["_pytype_"]
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in classes_to_pickle:
|
||||
|
||||
if "pickle" in serialized_obj:
|
||||
# The object was pickled, so unpickle it.
|
||||
obj = pickling.loads(serialized_obj["data"])
|
||||
elif class_id in custom_deserializers.keys():
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
assert class_id not in classes_to_pickle
|
||||
if class_id not in whitelisted_classes:
|
||||
# If this happens, that means that the call to _register_class, which
|
||||
# should have added the class to the list of whitelisted classes, has not
|
||||
# yet propagated to this worker. It should happen if we wait a little
|
||||
# longer.
|
||||
raise RayDeserializationException("The class {} is not one of the "
|
||||
"whitelisted classes."
|
||||
.format(class_id), class_id)
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in custom_deserializers:
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -181,3 +194,11 @@ def set_callbacks():
|
||||
callback.
|
||||
"""
|
||||
ray.numbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
|
||||
def clear_state():
|
||||
type_to_class_id.clear()
|
||||
whitelisted_classes.clear()
|
||||
classes_to_pickle.clear()
|
||||
custom_serializers.clear()
|
||||
custom_deserializers.clear()
|
||||
|
||||
@@ -107,13 +107,3 @@ def python_mode_g(x):
|
||||
@ray.remote
|
||||
def no_op():
|
||||
pass
|
||||
|
||||
|
||||
class TestClass(object):
|
||||
def __init__(self):
|
||||
self.a = 5
|
||||
|
||||
|
||||
@ray.remote
|
||||
def test_unknown_type():
|
||||
return TestClass()
|
||||
|
||||
+160
-45
@@ -479,27 +479,70 @@ class Worker(object):
|
||||
self.mode = mode
|
||||
colorama.init()
|
||||
|
||||
def put_object(self, objectid, value):
|
||||
def store_and_register(self, object_id, value, depth=100):
|
||||
"""Store an object and attempt to register its class if needed.
|
||||
|
||||
Args:
|
||||
object_id: The ID of the object to store.
|
||||
value: The value to put in the object store.
|
||||
depth: The maximum number of classes to recursively register.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the attempt to store the object
|
||||
fails. This can happen if there is already an object with the same ID
|
||||
in the object store or if the object store is full.
|
||||
"""
|
||||
counter = 0
|
||||
while True:
|
||||
if counter == depth:
|
||||
raise Exception("Ray exceeded the maximum number of classes that it "
|
||||
"will recursively serialize when attempting to "
|
||||
"serialize an object of type {}.".format(type(value)))
|
||||
counter += 1
|
||||
try:
|
||||
ray.numbuf.store_list(object_id.id(), self.plasma_client.conn, [value])
|
||||
break
|
||||
except serialization.RaySerializationException as e:
|
||||
try:
|
||||
_register_class(type(e.example_object))
|
||||
warning_message = ("WARNING: Serializing objects of type {} by "
|
||||
"expanding them as dictionaries of their fields. "
|
||||
"This behavior may be incorrect in some cases."
|
||||
.format(type(e.example_object)))
|
||||
print(warning_message)
|
||||
except serialization.RayNotDictionarySerializable:
|
||||
_register_class(type(e.example_object), pickle=True)
|
||||
warning_message = ("WARNING: Falling back to serializing objects of "
|
||||
"type {} by using pickle. This may be "
|
||||
"inefficient.".format(type(e.example_object)))
|
||||
print(warning_message)
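
The retry loop in ``store_and_register`` can be illustrated in isolation with the sketch below. It does not use Ray: ``UnknownTypeError``, ``registered``, and ``store`` are hypothetical stand-ins for ``RaySerializationException``, the class whitelist, and ``ray.numbuf.store_list``, and the pickle fallback is left out for brevity.

```python
class UnknownTypeError(Exception):
    """Stand-in for RaySerializationException; carries the offending object."""

    def __init__(self, message, example_object):
        Exception.__init__(self, message)
        self.example_object = example_object


registered = set()  # hypothetical stand-in for the whitelist of classes


def store(value):
    """Hypothetical store: fails on the first unregistered custom type found."""
    stack = [value]
    while stack:
        item = stack.pop()
        if isinstance(item, (int, float, str, bytes, bool, type(None))):
            continue
        if isinstance(item, (list, tuple)):
            stack.extend(item)
        elif isinstance(item, dict):
            stack.extend(item.values())
        elif type(item) not in registered:
            raise UnknownTypeError("unknown type {}".format(type(item)), item)
        else:
            # Registered classes are expanded as dictionaries of their fields.
            stack.extend(vars(item).values())


def store_and_register(value, depth=100):
    """Retry the store, registering one newly discovered class per attempt."""
    for _ in range(depth):
        try:
            store(value)
            return
        except UnknownTypeError as e:
            registered.add(type(e.example_object))
    raise Exception("Exceeded the maximum number of classes to register.")


class Inner(object):
    def __init__(self):
        self.x = 1


class Outer(object):
    def __init__(self):
        self.inner = Inner()


store_and_register([Outer(), 3])
```

Each failed attempt reveals exactly one unknown class, so an object whose fields nest ``n`` distinct custom classes needs ``n + 1`` attempts. The ``depth`` bound keeps a pathological object from looping forever.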
|
||||
|
||||
def put_object(self, object_id, value):
|
||||
"""Put value in the local object store with object ID object_id.
|
||||
|
||||
This assumes that the value for object_id has not yet been placed in the
|
||||
local object store.
|
||||
|
||||
Args:
|
||||
objectid (object_id.ObjectID): The object ID of the value to be put.
|
||||
object_id (object_id.ObjectID): The object ID of the value to be put.
|
||||
value: The value to put in the object store.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the attempt to store the object
|
||||
fails. This can happen if there is already an object with the same ID
|
||||
in the object store or if the object store is full.
|
||||
"""
|
||||
# Make sure that the value is not an object ID.
|
||||
if isinstance(value, ray.local_scheduler.ObjectID):
|
||||
raise Exception("Calling `put` on an ObjectID is not allowed "
|
||||
raise Exception("Calling 'put' on an ObjectID is not allowed "
|
||||
"(similarly, returning an ObjectID from a remote "
|
||||
"function is not allowed). If you really want to do "
|
||||
"this, you can wrap the ObjectID in a list and call "
|
||||
"`put` on it (or return it).")
|
||||
"'put' on it (or return it).")
|
||||
|
||||
# Serialize and put the object in the object store.
|
||||
try:
|
||||
ray.numbuf.store_list(objectid.id(), self.plasma_client.conn, [value])
|
||||
self.store_and_register(object_id, value)
|
||||
except ray.numbuf.numbuf_plasma_object_exists_error as e:
|
||||
# The object already exists in the object store, so there is no need to
|
||||
# add it again. TODO(rkn): We need to compare the hashes and make sure
|
||||
@@ -511,6 +554,38 @@ class Worker(object):
|
||||
# Optionally do something with the contained_objectids here.
|
||||
contained_objectids = []
|
||||
|
||||
def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
|
||||
start_time = time.time()
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
while True:
|
||||
try:
|
||||
results = ray.numbuf.retrieve_list(
|
||||
object_ids,
|
||||
self.plasma_client.conn,
|
||||
timeout)
|
||||
return results
|
||||
except serialization.RayDeserializationException as e:
|
||||
# Wait a little bit for the import thread to import the class. If we
|
||||
# currently have the worker lock, we need to release it so that the
|
||||
# import thread can acquire it.
|
||||
if self.mode == WORKER_MODE:
|
||||
self.lock.release()
|
||||
time.sleep(0.01)
|
||||
if self.mode == WORKER_MODE:
|
||||
self.lock.acquire()
|
||||
|
||||
if time.time() - start_time > error_timeout:
|
||||
warning_message = ("This worker or driver is waiting to receive a "
|
||||
"class definition so that it can deserialize an "
|
||||
"object from the object store. This may be fine, "
|
||||
"or it may be a bug.")
|
||||
if not warning_sent:
|
||||
self.push_error_to_driver(self.task_driver_id.id(),
|
||||
"wait_for_class",
|
||||
warning_message)
|
||||
warning_sent = True
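
The reason for releasing the worker lock while waiting can be seen in the following standalone sketch. It is an assumption-laden illustration, not the Ray code path: ``known_classes``, ``importer``, and ``retrieve`` are hypothetical names, and the timeout warning is omitted.

```python
import threading
import time

lock = threading.Lock()
known_classes = {}  # hypothetical stand-in for the whitelist of classes


def importer():
    """Stand-in for the import thread: needs the lock to add a definition."""
    time.sleep(0.05)
    with lock:
        known_classes["Foo"] = object


def retrieve(class_id, poll_interval=0.01):
    """Poll until class_id is known, yielding the lock between attempts."""
    attempts = 0
    while class_id not in known_classes:
        attempts += 1
        # Release the lock so the importer can make progress, then retake it.
        lock.release()
        time.sleep(poll_interval)
        lock.acquire()
    return known_classes[class_id], attempts


thread = threading.Thread(target=importer)
thread.daemon = True
with lock:
    thread.start()
    cls, attempts = retrieve("Foo")
thread.join()
```

If ``retrieve`` slept without releasing the lock, the importer could never add the class definition and the caller would wait forever.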
|
||||
|
||||
def get_object(self, object_ids):
|
||||
"""Get the value or values in the object store associated with object_ids.
|
||||
|
||||
@@ -531,10 +606,8 @@ class Worker(object):
|
||||
self.plasma_client.fetch([object_id.id() for object_id in object_ids])
|
||||
|
||||
# Get the objects. We initially try to get the objects immediately.
|
||||
final_results = ray.numbuf.retrieve_list(
|
||||
[object_id.id() for object_id in object_ids],
|
||||
self.plasma_client.conn,
|
||||
0)
|
||||
final_results = self.retrieve_and_deserialize(
|
||||
[object_id.id() for object_id in object_ids], 0)
|
||||
# Construct a dictionary mapping object IDs that we haven't gotten yet to
|
||||
# their original index in the object_ids argument.
|
||||
unready_ids = dict((object_id, i) for (i, (object_id, val)) in
|
||||
@@ -548,9 +621,8 @@ class Worker(object):
|
||||
# Do another fetch for objects that aren't available locally yet, in case
|
||||
# they were evicted since the last fetch.
|
||||
self.plasma_client.fetch(list(unready_ids.keys()))
|
||||
results = ray.numbuf.retrieve_list(list(unready_ids.keys()),
|
||||
self.plasma_client.conn,
|
||||
GET_TIMEOUT_MILLISECONDS)
|
||||
results = self.retrieve_and_deserialize(list(unready_ids.keys()),
|
||||
GET_TIMEOUT_MILLISECONDS)
|
||||
# Remove any entries for objects we received during this iteration so we
|
||||
# don't retrieve the same object twice.
|
||||
for object_id, val in results:
|
||||
@@ -634,14 +706,16 @@ class Worker(object):
|
||||
not be used.
|
||||
"""
|
||||
check_main_thread()
|
||||
if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a "
|
||||
"driver.")
|
||||
# If ray.init has not been called yet, then cache the function and export
|
||||
# it when connect is called. Otherwise, run the function on all workers.
|
||||
if self.mode is None:
|
||||
self.cached_functions_to_run.append(function)
|
||||
else:
|
||||
# Attempt to pickle the function before we need it. This could fail, and
|
||||
# it is more convenient if the failure happens before we actually run the
|
||||
# function locally.
|
||||
pickled_function = pickling.dumps(function)
|
||||
|
||||
function_to_run_id = random_string()
|
||||
key = "FunctionsToRun:{}".format(function_to_run_id)
|
||||
# First run the function on the driver. Pass in the number of workers on
|
||||
@@ -652,7 +726,7 @@ class Worker(object):
|
||||
# Run the function on all workers.
|
||||
self.redis_client.hmset(key, {"driver_id": self.task_driver_id.id(),
|
||||
"function_id": function_to_run_id,
|
||||
"function": pickling.dumps(function)})
|
||||
"function": pickled_function})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
|
||||
def push_error_to_driver(self, driver_id, error_type, message, data=None):
|
||||
@@ -808,23 +882,37 @@ def initialize_numbuf(worker=global_worker):
|
||||
|
||||
def objectid_custom_deserializer(serialized_obj):
|
||||
return ray.local_scheduler.ObjectID(serialized_obj)
|
||||
|
||||
serialization.add_class_to_whitelist(
|
||||
ray.local_scheduler.ObjectID, pickle=False,
|
||||
ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False,
|
||||
custom_serializer=objectid_custom_serializer,
|
||||
custom_deserializer=objectid_custom_deserializer)
|
||||
|
||||
# Define a custom serializer and deserializer for handling numpy arrays that
|
||||
# contain objects.
|
||||
def array_custom_serializer(obj):
|
||||
return obj.tolist(), obj.dtype.str
|
||||
|
||||
def array_custom_deserializer(serialized_obj):
|
||||
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
|
||||
|
||||
serialization.add_class_to_whitelist(
|
||||
np.ndarray, 20 * b"\x01", pickle=False,
|
||||
custom_serializer=array_custom_serializer,
|
||||
custom_deserializer=array_custom_deserializer)
|
||||
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
# These should only be called on the driver because register_class will
|
||||
# These should only be called on the driver because _register_class will
|
||||
# export the class to all of the workers.
|
||||
register_class(RayTaskError)
|
||||
register_class(RayGetError)
|
||||
register_class(RayGetArgumentError)
|
||||
_register_class(RayTaskError)
|
||||
_register_class(RayGetError)
|
||||
_register_class(RayGetArgumentError)
|
||||
# Tell Ray to serialize lambdas with pickle.
|
||||
register_class(type(lambda: 0), pickle=True)
|
||||
_register_class(type(lambda: 0), pickle=True)
|
||||
# Tell Ray to serialize sets with pickle.
|
||||
register_class(type(set()), pickle=True)
|
||||
_register_class(type(set()), pickle=True)
|
||||
# Tell Ray to serialize types with pickle.
|
||||
register_class(type(int), pickle=True)
|
||||
_register_class(type(int), pickle=True)
|
||||
|
||||
|
||||
def get_address_info_from_redis_helper(redis_address, node_ip_address):
|
||||
@@ -1279,7 +1367,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
|
||||
data={"name": name})
|
||||
|
||||
|
||||
def import_thread(worker):
|
||||
def import_thread(worker, mode):
|
||||
worker.import_pubsub_client = worker.redis_client.pubsub()
|
||||
# Exports that are published after the call to
|
||||
# import_pubsub_client.psubscribe and before the call to
|
||||
@@ -1292,6 +1380,16 @@ def import_thread(worker):
|
||||
with worker.lock:
|
||||
export_keys = worker.redis_client.lrange("Exports", 0, -1)
|
||||
for key in export_keys:
|
||||
num_imported += 1
|
||||
|
||||
# Handle the driver case first.
|
||||
if mode != WORKER_MODE:
|
||||
if key.startswith(b"FunctionsToRun"):
|
||||
fetch_and_execute_function_to_run(key, worker=worker)
|
||||
# Continue because FunctionsToRun are the only things that the driver
|
||||
# should import.
|
||||
continue
|
||||
|
||||
if key.startswith(b"RemoteFunction"):
|
||||
fetch_and_register_remote_function(key, worker=worker)
|
||||
elif key.startswith(b"EnvironmentVariables"):
|
||||
@@ -1306,7 +1404,6 @@ def import_thread(worker):
|
||||
worker.fetch_and_register_actor(key, worker)
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
num_imported += 1
|
||||
|
||||
try:
|
||||
for msg in worker.import_pubsub_client.listen():
|
||||
@@ -1317,7 +1414,18 @@ def import_thread(worker):
|
||||
num_imports = worker.redis_client.llen("Exports")
|
||||
assert num_imports >= num_imported
|
||||
for i in range(num_imported, num_imports):
|
||||
num_imported += 1
|
||||
key = worker.redis_client.lindex("Exports", i)
|
||||
|
||||
# Handle the driver case first.
|
||||
if mode != WORKER_MODE:
|
||||
if key.startswith(b"FunctionsToRun"):
|
||||
with log_span("ray:import_function_to_run", worker=worker):
|
||||
fetch_and_execute_function_to_run(key, worker=worker)
|
||||
# Continue because FunctionsToRun are the only things that the
|
||||
# driver should import.
|
||||
continue
|
||||
|
||||
if key.startswith(b"RemoteFunction"):
|
||||
with log_span("ray:import_remote_function", worker=worker):
|
||||
fetch_and_register_remote_function(key, worker=worker)
|
||||
@@ -1335,7 +1443,6 @@ def import_thread(worker):
|
||||
worker.fetch_and_register["Actor"](key, worker)
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
num_imported += 1
|
||||
except redis.ConnectionError:
|
||||
# When Redis terminates the listen call will throw a ConnectionError, which
|
||||
# we catch here.
|
||||
@@ -1486,12 +1593,14 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
class_id = worker.redis_client.hget(actor_key, "class_id")
|
||||
worker.class_id = class_id
|
||||
|
||||
# If this is a worker, then start a thread to import exports from the driver.
|
||||
if mode == WORKER_MODE:
|
||||
t = threading.Thread(target=import_thread, args=(worker,))
|
||||
# Making the thread a daemon causes it to exit when the main thread exits.
|
||||
t.daemon = True
|
||||
t.start()
|
||||
# Start a thread to import exports from the driver or from other workers.
|
||||
# Note that the driver also has an import thread, which is used only to
|
||||
# import custom class definitions from calls to _register_class that happen
|
||||
# under the hood on workers.
|
||||
t = threading.Thread(target=import_thread, args=(worker, mode))
|
||||
# Making the thread a daemon causes it to exit when the main thread exits.
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
# If this is a driver running in SCRIPT_MODE, start a thread to print error
|
||||
# messages asynchronously in the background. Ideally the scheduler would push
|
||||
@@ -1554,9 +1663,15 @@ def disconnect(worker=global_worker):
   worker.cached_functions_to_run = []
   worker.cached_remote_functions = []
   env._cached_environment_variables = []
+  serialization.clear_state()


 def register_class(cls, pickle=False, worker=global_worker):
+  raise Exception("The function ray.register_class is deprecated. It should "
+                  "be safe to remove any calls to this function.")
+
+
+def _register_class(cls, pickle=False, worker=global_worker):
   """Enable workers to serialize or deserialize objects of a particular class.

   This method runs the register_class function defined below on every worker,
@@ -1573,18 +1688,19 @@ def register_class(cls, pickle=False, worker=global_worker):
     Exception: An exception is raised if pickle=False and the class cannot be
       efficiently serialized by Ray.
   """
-  # If the worker is not a driver, then return. We do this so that Python
-  # modules can register classes and these modules can be imported on workers
-  # without any trouble.
-  if worker.mode == WORKER_MODE:
-    return
-  # Raise an exception if cls cannot be serialized efficiently by Ray.
-  if not pickle:
-    serialization.check_serializable(cls)
+  class_id = random_string()

   def register_class_for_serialization(worker_info):
-    serialization.add_class_to_whitelist(cls, pickle=pickle)
-  worker.run_function_on_all_workers(register_class_for_serialization)
+    serialization.add_class_to_whitelist(cls, class_id, pickle=pickle)
+
+  if not pickle:
+    # Raise an exception if cls cannot be serialized efficiently by Ray.
+    serialization.check_serializable(cls)
+    worker.run_function_on_all_workers(register_class_for_serialization)
+  else:
+    # Since we are pickling objects of this class, we don't actually need to
+    # ship the class definition.
+    register_class_for_serialization({})


 class RayLogSpan(object):
@@ -1778,7 +1894,6 @@ def wait_for_function(function_id, driver_id, timeout=10,
   start_time = time.time()
   # Only send the warning once.
   warning_sent = False
-  num_warnings_sent = 0
   while True:
     with worker.lock:
       if worker.actor_id == NIL_ACTOR_ID and (function_id.id() in
@@ -1787,7 +1902,7 @@ def wait_for_function(function_id, driver_id, timeout=10,
       elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in
                                                 worker.actors):
         break
-      if time.time() - start_time > timeout * (num_warnings_sent + 1):
+      if time.time() - start_time > timeout:
         warning_message = ("This worker was asked to execute a function that "
                            "it does not have registered. You may have to "
                            "restart Ray.")

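Review note: the two `wait_for_function` hunks above leave a polling loop that waits for a registration to arrive and, per the `# Only send the warning once.` comment, warns a single time once `timeout` seconds have passed. The following is a minimal standalone sketch of that pattern, not Ray's code: `wait_until`, `condition`, `warn`, `clock`, and `sleep` are hypothetical names introduced here so the loop can be exercised without a cluster.

```python
import time


def wait_until(condition, timeout=10, poll_interval=0.001,
               warn=print, clock=time.time, sleep=time.sleep):
    """Poll condition() until it is true; warn at most once after timeout.

    Returns True if the warning was sent, False otherwise.
    All parameter names are hypothetical and only serve this sketch.
    """
    start_time = clock()
    # Only send the warning once.
    warning_sent = False
    while True:
        if condition():
            break
        if clock() - start_time > timeout and not warning_sent:
            warn("Still waiting after {} seconds.".format(timeout))
            warning_sent = True
        sleep(poll_interval)
    return warning_sent
```

Injecting `clock` and `sleep` is a design choice made for the sketch: it lets a test advance time deterministically instead of actually sleeping.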
@@ -141,7 +141,6 @@ class ActorAPI(unittest.TestCase):
     class Foo(object):
       def __init__(self, x):
         self.x = x
-    ray.register_class(Foo)

     @ray.remote
     class Actor(object):

@@ -28,42 +28,6 @@ def wait_for_errors(error_type, num_errors, timeout=10):
   print("Timing out of wait.")


-class FailureTest(unittest.TestCase):
-  def testUnknownSerialization(self):
-    reload(test_functions)
-    ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
-
-    test_functions.test_unknown_type.remote()
-    wait_for_errors(b"task", 1)
-    self.assertEqual(len(relevant_errors(b"task")), 1)
-
-    ray.worker.cleanup()
-
-
-class TaskSerializationTest(unittest.TestCase):
-  def testReturnAndPassUnknownType(self):
-    ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
-
-    class Foo(object):
-      pass
-
-    # Check that returning an unknown type from a remote function raises an
-    # exception.
-    @ray.remote
-    def f():
-      return Foo()
-    self.assertRaises(Exception, lambda: ray.get(f.remote()))
-
-    # Check that passing an unknown type into a remote function raises an
-    # exception.
-    @ray.remote
-    def g(x):
-      return 1
-    self.assertRaises(Exception, lambda: g.remote(Foo()))
-
-    ray.worker.cleanup()
-
-
 class TaskStatusTest(unittest.TestCase):
   def testFailedTask(self):
     reload(test_functions)

+122
-51
@@ -2,15 +2,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
-import unittest
-import ray
+from collections import defaultdict, namedtuple
 import numpy as np
-import time
+import os
+import ray
+import re
 import shutil
 import string
 import sys
-from collections import defaultdict, namedtuple
+import time
+import unittest

 import ray.test.test_functions as test_functions

@@ -169,8 +170,6 @@ class SerializationTest(unittest.TestCase):
     class ClassA(object):
       pass

-    ray.register_class(ClassA)
-
     # Make a list that contains itself.
     l = []
     l.append(l)
@@ -201,14 +200,6 @@ class SerializationTest(unittest.TestCase):
     def f(x):
       return x

-    ray.register_class(Exception)
-    ray.register_class(CustomError)
-    ray.register_class(Point)
-    ray.register_class(Foo)
-    ray.register_class(Bar)
-    ray.register_class(Baz)
-    ray.register_class(NamedTupleExample)
-
     # Check that we can pass arguments by value to remote functions and that
     # they are uncorrupted.
     for obj in RAY_TEST_OBJECTS:
@@ -303,20 +294,129 @@ class WorkerTest(unittest.TestCase):
 class APITest(unittest.TestCase):

   def testRegisterClass(self):
-    ray.init(num_workers=0)
+    ray.init(num_workers=2)

     # Check that putting an object of a class that has not been registered
     # throws an exception.
     class TempClass(object):
       pass
-    self.assertRaises(Exception, lambda: ray.put(TempClass()))
-    # Check that registering a class that Ray cannot serialize efficiently
-    # raises an exception.
-    self.assertRaises(Exception, lambda: ray.register_class(defaultdict))
-    # Check that registering the same class with pickle works.
-    ray.register_class(defaultdict, pickle=True)
+    ray.get(ray.put(TempClass()))
+
+    # Note that the below actually returns a dictionary and not a defaultdict.
+    # This is a bug (https://github.com/ray-project/ray/issues/512).
     ray.get(ray.put(defaultdict(lambda: 0)))

+    # Test passing custom classes into remote functions from the driver.
+    @ray.remote
+    def f(x):
+      return x
+
+    foo = ray.get(f.remote(Foo(7)))
+    self.assertEqual(foo, Foo(7))
+
+    regex = re.compile(r"\d+\.\d*")
+    new_regex = ray.get(f.remote(regex))
+    self.assertEqual(regex, new_regex)
+
+    # Test returning custom classes created on workers.
+    @ray.remote
+    def g():
+      return SubQux(), Qux()
+
+    subqux, qux = ray.get(g.remote())
+    self.assertEqual(subqux.objs[2].foo.value, 0)
+
+    # Test exporting custom class definitions from one worker to another when
+    # the worker is blocked in a get.
+    class NewTempClass(object):
+      def __init__(self, value):
+        self.value = value
+
+    @ray.remote
+    def h1(x):
+      return NewTempClass(x)
+
+    @ray.remote
+    def h2(x):
+      return ray.get(h1.remote(x))
+
+    self.assertEqual(ray.get(h2.remote(10)).value, 10)
+
+    # Test registering multiple classes with the same name.
+    @ray.remote(num_return_vals=3)
+    def j():
+      class Class0(object):
+        def method0(self):
+          pass
+
+      c0 = Class0()
+
+      class Class0(object):
+        def method1(self):
+          pass
+
+      c1 = Class0()
+
+      class Class0(object):
+        def method2(self):
+          pass
+
+      c2 = Class0()
+
+      return c0, c1, c2
+
+    results = []
+    for _ in range(5):
+      results += j.remote()
+    for i in range(len(results) // 3):
+      c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))])
+
+      c0.method0()
+      c1.method1()
+      c2.method2()
+
+      self.assertFalse(hasattr(c0, "method1"))
+      self.assertFalse(hasattr(c0, "method2"))
+      self.assertFalse(hasattr(c1, "method0"))
+      self.assertFalse(hasattr(c1, "method2"))
+      self.assertFalse(hasattr(c2, "method0"))
+      self.assertFalse(hasattr(c2, "method1"))
+
+    @ray.remote
+    def k():
+      class Class0(object):
+        def method0(self):
+          pass
+
+      c0 = Class0()
+
+      class Class0(object):
+        def method1(self):
+          pass
+
+      c1 = Class0()
+
+      class Class0(object):
+        def method2(self):
+          pass
+
+      c2 = Class0()
+
+      return c0, c1, c2
+
+    results = ray.get([k.remote() for _ in range(5)])
+    for c0, c1, c2 in results:
+      c0.method0()
+      c1.method1()
+      c2.method2()
+
+      self.assertFalse(hasattr(c0, "method1"))
+      self.assertFalse(hasattr(c0, "method2"))
+      self.assertFalse(hasattr(c1, "method0"))
+      self.assertFalse(hasattr(c1, "method2"))
+      self.assertFalse(hasattr(c2, "method0"))
+      self.assertFalse(hasattr(c2, "method1"))
+
     ray.worker.cleanup()

   def testKeywordArgs(self):
@@ -666,35 +766,6 @@ class APITest(unittest.TestCase):

     ray.worker.cleanup()

-  def testPassingInfoToAllWorkers(self):
-    ray.init(num_workers=10, num_cpus=10)
-
-    def f(worker_info):
-      sys.path.append(worker_info)
-    ray.worker.global_worker.run_function_on_all_workers(f)
-
-    @ray.remote
-    def get_path():
-      time.sleep(1)
-      return sys.path
-    # Retrieve the values that we stored in the worker paths.
-    paths = ray.get([get_path.remote() for _ in range(10)])
-    # Add the driver's path to the list.
-    paths.append(sys.path)
-    worker_infos = [path[-1] for path in paths]
-    for worker_info in worker_infos:
-      self.assertEqual(list(worker_info.keys()), ["counter"])
-    counters = [worker_info["counter"] for worker_info in worker_infos]
-    # We use range(11) because the driver also runs the function.
-    self.assertEqual(set(counters), set(range(11)))
-
-    # Clean up the worker paths.
-    def f(worker_info):
-      sys.path.pop(-1)
-    ray.worker.global_worker.run_function_on_all_workers(f)
-
-    ray.worker.cleanup()
-
   def testLoggingAPI(self):
     ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)

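Review note: the "multiple classes with the same name" test and the `class_id = random_string()` line in `_register_class` suggest that each registration is keyed by a fresh ID rather than by the class name. The following is a minimal sketch of why a name key would collide, assuming a hypothetical `registry` dict and `register` helper that are not part of Ray's API; `uuid` stands in for `random_string`.

```python
import uuid

# Hypothetical registry for illustration only: class_id -> class object.
registry = {}


def register(cls):
    """Store cls under a fresh random ID and return that ID."""
    class_id = uuid.uuid4().hex
    registry[class_id] = cls
    return class_id


def make_classes():
    """Define two different classes that share the name Class0."""
    class Class0(object):
        def method0(self):
            pass

    first = Class0

    class Class0(object):
        def method1(self):
            pass

    second = Class0
    return first, second


first, second = make_classes()
id0, id1 = register(first), register(second)

# Keying by name alone would let the second definition overwrite the first.
by_name = {}
by_name[first.__name__] = first
by_name[second.__name__] = second
```

With ID keys, `registry[id0]` still resolves to the definition that has `method0`, whereas `by_name` ends up holding only the last `Class0`.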