mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
Rebase Ray on latest arrow (remove numbuf from Ray). (#910)
* remove some stuff * put get roundtrip working * fixes * more fixes * cleanup * fix tests * latest arrow * fixes * fix tests * fix linting * rebase * fixes * fix bug * bring back libgcc error * fix linting * use official arrow repo * fixes
This commit is contained in:
committed by
Robert Nishihara
parent
a2814567e1
commit
7030ef366f
@@ -10,6 +10,36 @@ pyarrow_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
|
||||
"pyarrow_files")
|
||||
sys.path.insert(0, pyarrow_path)
|
||||
|
||||
# See https://github.com/ray-project/ray/issues/131.
|
||||
helpful_message = """
|
||||
|
||||
If you are using Anaconda, try fixing this problem by running:
|
||||
|
||||
conda install libgcc
|
||||
"""
|
||||
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
except ImportError as e:
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
|
||||
("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
e.args = ()
|
||||
elif not isinstance(e.args, tuple):
|
||||
e.args = (e.args,)
|
||||
e.args += (helpful_message,)
|
||||
raise
|
||||
|
||||
from ray.worker import (register_class, error_info, init, connect, disconnect,
|
||||
get, put, wait, remote, log_event, log_span,
|
||||
flush_log, get_gpu_ids) # noqa: E402
|
||||
|
||||
@@ -188,8 +188,8 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
# Check that the first object dependency was evicted.
|
||||
object1 = self.plasma_client.get([pa.plasma.ObjectID(object_id1.id())],
|
||||
timeout_ms=0)
|
||||
object1 = self.plasma_client.get_buffers(
|
||||
[pa.plasma.ObjectID(object_id1.id())], timeout_ms=0)
|
||||
self.assertEqual(object1, [None])
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# See https://github.com/ray-project/ray/issues/131.
|
||||
helpful_message = """
|
||||
|
||||
If you are using Anaconda, try fixing this problem by running:
|
||||
|
||||
conda install libgcc
|
||||
"""
|
||||
|
||||
__all__ = ["deserialize_list", "numbuf_error",
|
||||
"numbuf_plasma_object_exists_error", "read_from_buffer",
|
||||
"register_callbacks", "retrieve_list", "serialize_list",
|
||||
"store_list", "write_to_buffer"]
|
||||
|
||||
try:
|
||||
from ray.core.src.numbuf.libnumbuf import (
|
||||
deserialize_list, numbuf_error, numbuf_plasma_object_exists_error,
|
||||
read_from_buffer, register_callbacks, retrieve_list, serialize_list,
|
||||
store_list, write_to_buffer)
|
||||
except ImportError as e:
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
|
||||
("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
e.args = ()
|
||||
elif not isinstance(e.args, tuple):
|
||||
e.args = (e.args,)
|
||||
e.args += (helpful_message,)
|
||||
raise
|
||||
@@ -32,8 +32,8 @@ def random_name():
|
||||
|
||||
def assert_get_object_equal(unit_test, client1, client2, object_id,
|
||||
memory_buffer=None, metadata=None):
|
||||
client1_buff = client1.get([object_id])[0]
|
||||
client2_buff = client2.get([object_id])[0]
|
||||
client1_buff = client1.get_buffers([object_id])[0]
|
||||
client2_buff = client2.get_buffers([object_id])[0]
|
||||
client1_metadata = client1.get_metadata([object_id])[0]
|
||||
client2_metadata = client2.get_metadata([object_id])[0]
|
||||
unit_test.assertEqual(len(client1_buff), len(client2_buff))
|
||||
@@ -371,7 +371,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# trying until the object appears on the second Plasma store.
|
||||
for i in range(num_attempts):
|
||||
self.client1.transfer("127.0.0.1", self.port2, object_id1)
|
||||
buff = self.client2.get([object_id1], timeout_ms=100)[0]
|
||||
buff = self.client2.get_buffers(
|
||||
[object_id1], timeout_ms=100)[0]
|
||||
if buff is not None:
|
||||
break
|
||||
self.assertNotEqual(buff, None)
|
||||
@@ -397,7 +398,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# trying until the object appears on the second Plasma store.
|
||||
for i in range(num_attempts):
|
||||
self.client2.transfer("127.0.0.1", self.port1, object_id2)
|
||||
buff = self.client1.get([object_id2], timeout_ms=100)[0]
|
||||
buff = self.client1.get_buffers(
|
||||
[object_id2], timeout_ms=100)[0]
|
||||
if buff is not None:
|
||||
break
|
||||
self.assertNotEqual(buff, None)
|
||||
|
||||
@@ -2,22 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import cloudpickle as pickle
|
||||
|
||||
import ray.numbuf
|
||||
|
||||
|
||||
class RaySerializationException(Exception):
|
||||
def __init__(self, message, example_object):
|
||||
Exception.__init__(self, message)
|
||||
self.example_object = example_object
|
||||
|
||||
|
||||
class RayDeserializationException(Exception):
|
||||
def __init__(self, message, class_id):
|
||||
Exception.__init__(self, message)
|
||||
self.class_id = class_id
|
||||
|
||||
|
||||
class RayNotDictionarySerializable(Exception):
|
||||
pass
|
||||
@@ -65,15 +49,6 @@ def check_serializable(cls):
|
||||
"it efficiently.".format(cls))
|
||||
|
||||
|
||||
# This field keeps track of a whitelisted set of classes that Ray will
|
||||
# serialize.
|
||||
type_to_class_id = dict()
|
||||
whitelisted_classes = dict()
|
||||
classes_to_pickle = set()
|
||||
custom_serializers = dict()
|
||||
custom_deserializers = dict()
|
||||
|
||||
|
||||
def is_named_tuple(cls):
|
||||
"""Return True if cls is a namedtuple and False otherwise."""
|
||||
b = cls.__bases__
|
||||
@@ -83,128 +58,3 @@ def is_named_tuple(cls):
|
||||
if not isinstance(f, tuple):
|
||||
return False
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
|
||||
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
|
||||
custom_deserializer=None):
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
|
||||
Args:
|
||||
cls (type): The class that we can serialize.
|
||||
class_id: A string of bytes used to identify the class.
|
||||
pickle (bool): True if the serialization should be done with pickle.
|
||||
False if it should be done efficiently with Ray.
|
||||
custom_serializer: This argument is optional, but can be provided to
|
||||
serialize objects of the class in a particular way.
|
||||
custom_deserializer: This argument is optional, but can be provided to
|
||||
deserialize objects of the class in a particular way.
|
||||
"""
|
||||
type_to_class_id[cls] = class_id
|
||||
whitelisted_classes[class_id] = cls
|
||||
if pickle:
|
||||
classes_to_pickle.add(class_id)
|
||||
if custom_serializer is not None:
|
||||
custom_serializers[class_id] = custom_serializer
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
|
||||
|
||||
def serialize(obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
If numbuf does not know how to serialize an object, it will call this
|
||||
method.
|
||||
|
||||
Args:
|
||||
obj (object): A Python object.
|
||||
|
||||
Returns:
|
||||
A dictionary that has the key "_pytype_" to identify the class, and
|
||||
contains all information needed to reconstruct the object.
|
||||
"""
|
||||
if type(obj) not in type_to_class_id:
|
||||
raise RaySerializationException("Ray does not know how to serialize "
|
||||
"objects of type {}."
|
||||
.format(type(obj)),
|
||||
obj)
|
||||
class_id = type_to_class_id[type(obj)]
|
||||
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickle.dumps(obj),
|
||||
"pickle": True}
|
||||
elif class_id in custom_serializers:
|
||||
serialized_obj = {"data": custom_serializers[class_id](obj)}
|
||||
else:
|
||||
# Handle the namedtuple case.
|
||||
if is_named_tuple(type(obj)):
|
||||
serialized_obj = {}
|
||||
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
|
||||
elif hasattr(obj, "__dict__"):
|
||||
serialized_obj = obj.__dict__
|
||||
else:
|
||||
raise RaySerializationException("We do not know how to serialize "
|
||||
"the object '{}'".format(obj), obj)
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
|
||||
|
||||
def deserialize(serialized_obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
If numbuf encounters a dictionary that contains the key "_pytype_" during
|
||||
deserialization, it will ask this callback to deserialize the object.
|
||||
|
||||
Args:
|
||||
serialized_obj (object): A dictionary that contains the key "_pytype_".
|
||||
|
||||
Returns:
|
||||
A Python object.
|
||||
|
||||
Raises:
|
||||
An exception is raised if we do not know how to deserialize the object.
|
||||
"""
|
||||
class_id = serialized_obj["_pytype_"]
|
||||
|
||||
if "pickle" in serialized_obj:
|
||||
# The object was pickled, so unpickle it.
|
||||
obj = pickle.loads(serialized_obj["data"])
|
||||
else:
|
||||
assert class_id not in classes_to_pickle
|
||||
if class_id not in whitelisted_classes:
|
||||
# If this happens, that means that the call to _register_class,
|
||||
# which should have added the class to the list of whitelisted
|
||||
# classes, has not yet propagated to this worker. It should happen
|
||||
# if we wait a little longer.
|
||||
raise RayDeserializationException("The class {} is not one of the "
|
||||
"whitelisted classes."
|
||||
.format(class_id), class_id)
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in custom_deserializers:
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
|
||||
|
||||
def set_callbacks():
|
||||
"""Register the custom callbacks with numbuf.
|
||||
|
||||
The serialize callback is used to serialize objects that numbuf does not
|
||||
know how to serialize (for example custom Python classes). The deserialize
|
||||
callback is used to deserialize objects that were serialized by the serialize
|
||||
callback.
|
||||
"""
|
||||
ray.numbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
|
||||
def clear_state():
|
||||
type_to_class_id.clear()
|
||||
whitelisted_classes.clear()
|
||||
classes_to_pickle.clear()
|
||||
custom_serializers.clear()
|
||||
custom_deserializers.clear()
|
||||
|
||||
+37
-63
@@ -20,12 +20,12 @@ import time
|
||||
import traceback
|
||||
|
||||
# Ray modules
|
||||
import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
import ray.experimental.state as state
|
||||
import ray.serialization as serialization
|
||||
import ray.services as services
|
||||
import ray.signature as signature
|
||||
import ray.numbuf
|
||||
import ray.local_scheduler
|
||||
import ray.plasma
|
||||
from ray.utils import FunctionProperties, random_string, binary_to_hex
|
||||
@@ -70,27 +70,6 @@ class FunctionID(object):
|
||||
return self.function_id
|
||||
|
||||
|
||||
contained_objectids = []
|
||||
|
||||
|
||||
def numbuf_serialize(value):
|
||||
"""This serializes a value and tracks the object IDs inside the value.
|
||||
|
||||
We also define a custom ObjectID serializer which also closes over the
|
||||
global variable contained_objectids, and whenever the custom serializer is
|
||||
called, it adds the relevant ObjectID to the list contained_objectids. The
|
||||
list contained_objectids should be reset between calls to numbuf_serialize.
|
||||
|
||||
Args:
|
||||
value: A Python object that will be serialized.
|
||||
|
||||
Returns:
|
||||
The serialized object.
|
||||
"""
|
||||
assert len(contained_objectids) == 0, "This should be unreachable."
|
||||
return ray.numbuf.serialize_list([value])
|
||||
|
||||
|
||||
class RayTaskError(Exception):
|
||||
"""An object used internally to represent a task that threw an exception.
|
||||
|
||||
@@ -300,11 +279,10 @@ class Worker(object):
|
||||
"type {}.".format(type(value)))
|
||||
counter += 1
|
||||
try:
|
||||
ray.numbuf.store_list(object_id.id(),
|
||||
self.plasma_client.to_capsule(),
|
||||
[value])
|
||||
self.plasma_client.put(value, pyarrow.plasma.ObjectID(
|
||||
object_id.id()), self.serialization_context)
|
||||
break
|
||||
except serialization.RaySerializationException as e:
|
||||
except pyarrow.SerializationCallbackError as e:
|
||||
try:
|
||||
_register_class(type(e.example_object))
|
||||
warning_message = ("WARNING: Serializing objects of type "
|
||||
@@ -349,7 +327,7 @@ class Worker(object):
|
||||
# Serialize and put the object in the object store.
|
||||
try:
|
||||
self.store_and_register(object_id, value)
|
||||
except ray.numbuf.numbuf_plasma_object_exists_error as e:
|
||||
except pyarrow.PlasmaObjectExists as e:
|
||||
# The object already exists in the object store, so there is no
|
||||
# need to add it again. TODO(rkn): We need to compare the hashes
|
||||
# and make sure that the objects are in fact the same. We also
|
||||
@@ -357,10 +335,6 @@ class Worker(object):
|
||||
# message.
|
||||
print("This object already exists in the object store.")
|
||||
|
||||
global contained_objectids
|
||||
# Optionally do something with the contained_objectids here.
|
||||
contained_objectids = []
|
||||
|
||||
def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
|
||||
start_time = time.time()
|
||||
# Only send the warning once.
|
||||
@@ -374,12 +348,12 @@ class Worker(object):
|
||||
results = []
|
||||
get_request_size = 10000
|
||||
for i in range(0, len(object_ids), get_request_size):
|
||||
results += ray.numbuf.retrieve_list(
|
||||
results += self.plasma_client.get(
|
||||
object_ids[i:(i + get_request_size)],
|
||||
self.plasma_client.to_capsule(),
|
||||
timeout)
|
||||
timeout,
|
||||
self.serialization_context)
|
||||
return results
|
||||
except serialization.RayDeserializationException as e:
|
||||
except pyarrow.DeserializationCallbackError as e:
|
||||
# Wait a little bit for the import thread to import the class.
|
||||
# If we currently have the worker lock, we need to release it
|
||||
# so that the import thread can acquire it.
|
||||
@@ -428,12 +402,12 @@ class Worker(object):
|
||||
plain_object_ids[i:(i + fetch_request_size)])
|
||||
|
||||
# Get the objects. We initially try to get the objects immediately.
|
||||
final_results = self.retrieve_and_deserialize(
|
||||
[object_id.id() for object_id in object_ids], 0)
|
||||
final_results = self.retrieve_and_deserialize(plain_object_ids, 0)
|
||||
# Construct a dictionary mapping object IDs that we haven't gotten yet
|
||||
# to their original index in the object_ids argument.
|
||||
unready_ids = dict((object_id, i) for (i, (object_id, val)) in
|
||||
enumerate(final_results) if val is None)
|
||||
unready_ids = dict((plain_object_ids[i].binary(), i) for (i, val) in
|
||||
enumerate(final_results)
|
||||
if val is plasma.ObjectNotAvailable)
|
||||
was_blocked = (len(unready_ids) > 0)
|
||||
# Try reconstructing any objects we haven't gotten yet. Try to get them
|
||||
# until at least GET_TIMEOUT_MILLISECONDS milliseconds passes, then
|
||||
@@ -451,14 +425,15 @@ class Worker(object):
|
||||
self.plasma_client.fetch(
|
||||
object_ids_to_fetch[i:(i + fetch_request_size)])
|
||||
results = self.retrieve_and_deserialize(
|
||||
list(unready_ids.keys()),
|
||||
object_ids_to_fetch,
|
||||
max([GET_TIMEOUT_MILLISECONDS, int(0.01 * len(unready_ids))]))
|
||||
# Remove any entries for objects we received during this iteration
|
||||
# so we don't retrieve the same object twice.
|
||||
for object_id, val in results:
|
||||
if val is not None:
|
||||
for i, val in enumerate(results):
|
||||
if val is not plasma.ObjectNotAvailable:
|
||||
object_id = object_ids_to_fetch[i].binary()
|
||||
index = unready_ids[object_id]
|
||||
final_results[index] = (object_id, val)
|
||||
final_results[index] = val
|
||||
unready_ids.pop(object_id)
|
||||
|
||||
# If there were objects that we weren't able to get locally, let the
|
||||
@@ -466,11 +441,8 @@ class Worker(object):
|
||||
if was_blocked:
|
||||
self.local_scheduler_client.notify_unblocked()
|
||||
|
||||
# Unwrap the object from the list (it was wrapped in put_object).
|
||||
assert len(final_results) == len(object_ids)
|
||||
for i in range(len(final_results)):
|
||||
assert final_results[i][0] == object_ids[i].id()
|
||||
return [result[1][0] for result in final_results]
|
||||
return final_results
|
||||
|
||||
def submit_task(self, function_id, args, actor_id=None):
|
||||
"""Submit a remote task to the scheduler.
|
||||
@@ -556,7 +528,7 @@ class Worker(object):
|
||||
# counter starts at 0.
|
||||
counter = self.redis_client.hincrby(self.node_ip_address,
|
||||
key, 1) - 1
|
||||
function({"counter": counter})
|
||||
function({"counter": counter, "worker": self})
|
||||
# Run the function on all workers.
|
||||
self.redis_client.hmset(key,
|
||||
{"driver_id": self.task_driver_id.id(),
|
||||
@@ -991,23 +963,22 @@ def error_info(worker=global_worker):
|
||||
return errors
|
||||
|
||||
|
||||
def initialize_numbuf(worker=global_worker):
|
||||
def _initialize_serialization(worker=global_worker):
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells numbuf to
|
||||
This defines a custom serializer for object IDs and also tells ray to
|
||||
serialize several exception classes that we define for error handling.
|
||||
"""
|
||||
ray.serialization.set_callbacks()
|
||||
worker.serialization_context = pyarrow.SerializationContext()
|
||||
|
||||
# Define a custom serializer and deserializer for handling Object IDs.
|
||||
def objectid_custom_serializer(obj):
|
||||
contained_objectids.append(obj)
|
||||
return obj.id()
|
||||
|
||||
def objectid_custom_deserializer(serialized_obj):
|
||||
return ray.local_scheduler.ObjectID(serialized_obj)
|
||||
|
||||
serialization.add_class_to_whitelist(
|
||||
worker.serialization_context.register_type(
|
||||
ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False,
|
||||
custom_serializer=objectid_custom_serializer,
|
||||
custom_deserializer=objectid_custom_deserializer)
|
||||
@@ -1020,7 +991,7 @@ def initialize_numbuf(worker=global_worker):
|
||||
def array_custom_deserializer(serialized_obj):
|
||||
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
|
||||
|
||||
serialization.add_class_to_whitelist(
|
||||
worker.serialization_context.register_type(
|
||||
np.ndarray, 20 * b"\x01", pickle=False,
|
||||
custom_serializer=array_custom_serializer,
|
||||
custom_deserializer=array_custom_deserializer)
|
||||
@@ -1503,7 +1474,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
|
||||
# Deserialize the function.
|
||||
function = pickle.loads(serialized_function)
|
||||
# Run the function.
|
||||
function({"counter": counter})
|
||||
function({"counter": counter, "worker": worker})
|
||||
except:
|
||||
# If an exception was thrown when the function was run, we record the
|
||||
# traceback and notify the scheduler of the failure.
|
||||
@@ -1770,6 +1741,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
class_id = worker.redis_client.hget(actor_key, "class_id")
|
||||
worker.class_id = class_id
|
||||
|
||||
# Initialize the serialization library. This registers some classes, and so
|
||||
# it must be run before we export all of the cached remote functions.
|
||||
_initialize_serialization()
|
||||
|
||||
# Start a thread to import exports from the driver or from other workers.
|
||||
# Note that the driver also has an import thread, which is used only to
|
||||
# import custom class definitions from calls to _register_class that happen
|
||||
@@ -1791,9 +1766,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
# exits.
|
||||
t.daemon = True
|
||||
t.start()
|
||||
# Initialize the serialization library. This registers some classes, and so
|
||||
# it must be run before we export all of the cached remote functions.
|
||||
initialize_numbuf()
|
||||
|
||||
if mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
# Add the directory containing the script that is running to the Python
|
||||
# paths of the workers. Also add the current directory. Note that this
|
||||
@@ -1835,7 +1808,7 @@ def disconnect(worker=global_worker):
|
||||
worker.connected = False
|
||||
worker.cached_functions_to_run = []
|
||||
worker.cached_remote_functions = []
|
||||
serialization.clear_state()
|
||||
worker.serialization_context = pyarrow.SerializationContext()
|
||||
|
||||
|
||||
def register_class(cls, pickle=False, worker=global_worker):
|
||||
@@ -1847,11 +1820,11 @@ def _register_class(cls, pickle=False, worker=global_worker):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
This method runs the register_class function defined below on every worker,
|
||||
which will enable numbuf to properly serialize and deserialize objects of
|
||||
which will enable ray to properly serialize and deserialize objects of
|
||||
this class.
|
||||
|
||||
Args:
|
||||
cls (type): The class that numbuf should serialize.
|
||||
cls (type): The class that ray should serialize.
|
||||
pickle (bool): If False then objects of this class will be serialized
|
||||
by turning their __dict__ fields into a dictionary. If True, then
|
||||
objects of this class will be serialized using pickle.
|
||||
@@ -1863,7 +1836,8 @@ def _register_class(cls, pickle=False, worker=global_worker):
|
||||
class_id = random_string()
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
serialization.add_class_to_whitelist(cls, class_id, pickle=pickle)
|
||||
worker_info["worker"].serialization_context.register_type(
|
||||
cls, class_id, pickle=pickle)
|
||||
|
||||
if not pickle:
|
||||
# Raise an exception if cls cannot be serialized efficiently by Ray.
|
||||
@@ -1872,7 +1846,7 @@ def _register_class(cls, pickle=False, worker=global_worker):
|
||||
else:
|
||||
# Since we are pickling objects of this class, we don't actually need
|
||||
# to ship the class definition.
|
||||
register_class_for_serialization({})
|
||||
register_class_for_serialization({"worker": worker})
|
||||
|
||||
|
||||
class RayLogSpan(object):
|
||||
|
||||
@@ -22,7 +22,6 @@ ray_files = [
|
||||
"ray/core/src/plasma/plasma_manager",
|
||||
"ray/core/src/local_scheduler/local_scheduler",
|
||||
"ray/core/src/local_scheduler/liblocal_scheduler_library.so",
|
||||
"ray/core/src/numbuf/libnumbuf.so",
|
||||
"ray/core/src/global_scheduler/global_scheduler",
|
||||
"ray/WebUI.ipynb"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user