mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:32:36 +08:00
Allow users to serialize custom classes. (#393)
* Allow serialization of custom classes. * Add documentation and test cases, also fix pickle case. * Don't allow old-style classes.
This commit is contained in:
committed by
Philipp Moritz
parent
d5cb3ac090
commit
11a8914684
@@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"):
|
||||
|
||||
import config
|
||||
import serialization
|
||||
from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local
|
||||
from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local
|
||||
from worker import Reusable, reusables
|
||||
from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
|
||||
from libraylib import ObjectID
|
||||
|
||||
@@ -16,14 +16,6 @@ class DistArray(object):
|
||||
if self.num_blocks != list(self.objectids.shape):
|
||||
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
(shape, objectids) = primitives
|
||||
return DistArray(shape, objectids)
|
||||
|
||||
def serialize(self):
|
||||
return (self.shape, self.objectids)
|
||||
|
||||
@staticmethod
|
||||
def compute_block_lower(index, shape):
|
||||
if len(index) != len(shape):
|
||||
|
||||
+131
-35
@@ -1,42 +1,138 @@
|
||||
import importlib
|
||||
import numpy as np
|
||||
|
||||
import pickling
|
||||
import libraylib as raylib
|
||||
import libnumbuf
|
||||
|
||||
def to_primitive(obj):
|
||||
if hasattr(obj, "serialize"):
|
||||
primitive_obj = ((type(obj).__module__, type(obj).__name__), obj.serialize())
|
||||
else:
|
||||
primitive_obj = ("primitive", obj)
|
||||
return primitive_obj
|
||||
def check_serializable(cls):
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
|
||||
def from_primitive(primitive_obj):
|
||||
if primitive_obj[0] == "primitive":
|
||||
obj = primitive_obj[1]
|
||||
Args:
|
||||
cls (type): The class to be serialized.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if Ray cannot serialize this class
|
||||
efficiently.
|
||||
"""
|
||||
if is_named_tuple(cls):
|
||||
# This case works.
|
||||
return
|
||||
if not hasattr(cls, "__new__"):
|
||||
raise Exception("The class {} does not have a '__new__' attribute, and is probably an old-style class. We do not support this. Please either make it a new-style class by inheriting from 'object', or use 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
try:
|
||||
obj = cls.__new__(cls)
|
||||
except:
|
||||
raise Exception("The class {} has overridden '__new__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise Exception("Objects of the class {} do not have a `__dict__` attribute, so Ray cannot serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
if hasattr(obj, "__slots__"):
|
||||
raise Exception("The class {} uses '__slots__', so Ray may not be able to serialize it efficiently. Try using 'ray.register_class(cls, pickle=True)'. However, note that pickle is inefficient.".format(cls))
|
||||
|
||||
# This field keeps track of a whitelisted set of classes that Ray will
|
||||
# serialize.
|
||||
whitelisted_classes = {}
|
||||
classes_to_pickle = set()
|
||||
custom_serializers = {}
|
||||
custom_deserializers = {}
|
||||
|
||||
def class_identifier(typ):
|
||||
"""Return a string that identifies this type."""
|
||||
return "{}.{}".format(typ.__module__, typ.__name__)
|
||||
|
||||
def is_named_tuple(cls):
|
||||
"""Return True if cls is a namedtuple and False otherwise."""
|
||||
b = cls.__bases__
|
||||
if len(b) != 1 or b[0] != tuple:
|
||||
return False
|
||||
f = getattr(cls, "_fields", None)
|
||||
if not isinstance(f, tuple):
|
||||
return False
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
def add_class_to_whitelist(cls, pickle=False, custom_serializer=None, custom_deserializer=None):
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
|
||||
Args:
|
||||
cls (type): The class that we can serialize.
|
||||
pickle (bool): True if the serialization should be done with pickle. False
|
||||
if it should be done efficiently with Ray.
|
||||
custom_serializer: This argument is optional, but can be provided to
|
||||
serialize objects of the class in a particular way.
|
||||
custom_deserializer: This argument is optional, but can be provided to
|
||||
deserialize objects of the class in a particular way.
|
||||
"""
|
||||
class_id = class_identifier(cls)
|
||||
whitelisted_classes[class_id] = cls
|
||||
if pickle:
|
||||
classes_to_pickle.add(class_id)
|
||||
if custom_serializer is not None:
|
||||
custom_serializers[class_id] = custom_serializer
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
|
||||
# Here we define a custom serializer and deserializer for handling numpy
|
||||
# arrays that contain objects.
|
||||
def array_custom_serializer(obj):
|
||||
return obj.tolist(), obj.dtype.str
|
||||
def array_custom_deserializer(serialized_obj):
|
||||
return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
|
||||
add_class_to_whitelist(np.ndarray, pickle=False, custom_serializer=array_custom_serializer, custom_deserializer=array_custom_deserializer)
|
||||
|
||||
def serialize(obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
If numbuf does not know how to serialize an object, it will call this method.
|
||||
|
||||
Args:
|
||||
obj (object): A Python object.
|
||||
|
||||
Returns:
|
||||
A dictionary that has the key "_pytype_" to identify the class, and
|
||||
contains all information needed to reconstruct the object.
|
||||
"""
|
||||
class_id = class_identifier(type(obj))
|
||||
if class_id not in whitelisted_classes:
|
||||
raise Exception("Ray does not know how to serialize the object {}. To fix this, call 'ray.register_class' on the class of the object.".format(obj))
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickling.dumps(obj)}
|
||||
elif class_id in custom_serializers.keys():
|
||||
serialized_obj = {"data": custom_serializers[class_id](obj)}
|
||||
else:
|
||||
# This code assumes that the type module.__dict__[type_name] knows how to deserialize itself
|
||||
type_module, type_name = primitive_obj[0]
|
||||
module = importlib.import_module(type_module)
|
||||
obj = module.__dict__[type_name].deserialize(primitive_obj[1])
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise Exception("We do not know how to serialize the object '{}'".format(obj))
|
||||
serialized_obj = obj.__dict__
|
||||
if is_named_tuple(type(obj)):
|
||||
# Handle the namedtuple case.
|
||||
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
|
||||
def deserialize(serialized_obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
If numbuf encounters a dictionary that contains the key "_pytype_" during
|
||||
deserialization, it will ask this callback to deserialize the object.
|
||||
|
||||
Args:
|
||||
serialized_obj (object): A dictionary that contains the key "_pytype_".
|
||||
|
||||
Returns:
|
||||
A Python object.
|
||||
"""
|
||||
class_id = serialized_obj["_pytype_"]
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in classes_to_pickle:
|
||||
obj = pickling.loads(serialized_obj["data"])
|
||||
elif class_id in custom_deserializers.keys():
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
serialized_obj.pop("_ray_getnewargs_")
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
|
||||
def is_arrow_serializable(value):
|
||||
return isinstance(value, np.ndarray) and value.dtype.name in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64"]
|
||||
|
||||
def serialize(worker_capsule, obj):
|
||||
primitive_obj = to_primitive(obj)
|
||||
obj_capsule, contained_objectids = raylib.serialize_object(worker_capsule, primitive_obj) # contained_objectids is a list of the objectids contained in obj
|
||||
return obj_capsule, contained_objectids
|
||||
|
||||
def deserialize(worker_capsule, capsule):
|
||||
primitive_obj = raylib.deserialize_object(worker_capsule, capsule)
|
||||
return from_primitive(primitive_obj)
|
||||
|
||||
def serialize_task(worker_capsule, func_name, args):
|
||||
primitive_args = [(arg if isinstance(arg, raylib.ObjectID) else to_primitive(arg)) for arg in args]
|
||||
return raylib.serialize_task(worker_capsule, func_name, primitive_args)
|
||||
|
||||
def deserialize_task(worker_capsule, task):
|
||||
func_name, primitive_args, return_objectids = task
|
||||
args = [(arg if isinstance(arg, raylib.ObjectID) else from_primitive(arg)) for arg in primitive_args]
|
||||
return func_name, args, return_objectids
|
||||
# Register the callbacks with numbuf.
|
||||
libnumbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
+118
-102
@@ -22,13 +22,31 @@ import services
|
||||
import libnumbuf
|
||||
import libraylib as raylib
|
||||
|
||||
contained_objectids = []
|
||||
def numbuf_serialize(value):
|
||||
"""This serializes a value and tracks the object IDs inside the value.
|
||||
|
||||
We also define a custom ObjectID serializer which also closes over the global
|
||||
variable contained_objectids, and whenever the custom serializer is called, it
|
||||
adds the relevant ObjectID to the list contained_objectids. The list
|
||||
contained_objectids should be reset between calls to numbuf_serialize.
|
||||
|
||||
Args:
|
||||
value: A Python object that will be serialized.
|
||||
|
||||
Returns:
|
||||
The serialized object.
|
||||
"""
|
||||
assert len(contained_objectids) == 0, "This should be unreachable."
|
||||
return libnumbuf.serialize_list([value])
|
||||
|
||||
class RayTaskError(Exception):
|
||||
"""An object used internally to represent a task that threw an exception.
|
||||
|
||||
If a task throws an exception during execution, a RayTaskError is stored in
|
||||
the object store for each of the task's outputs. When an object is retrieved
|
||||
from the object store, the Python method that retrieved it checks to see if
|
||||
the object is a RayTaskError and if it is then an exceptionis thrown
|
||||
the object is a RayTaskError and if it is then an exception is thrown
|
||||
propagating the error message.
|
||||
|
||||
Currently, we either use the exception attribute or the traceback attribute
|
||||
@@ -50,32 +68,6 @@ class RayTaskError(Exception):
|
||||
self.exception = None
|
||||
self.traceback_str = traceback_str
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayTaskError from a primitive object."""
|
||||
function_name, exception, traceback_str = primitives
|
||||
if exception[0] == "RayGetError":
|
||||
exception = RayGetError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentError":
|
||||
exception = RayGetArgumentError.deserialize(exception[1])
|
||||
elif exception[0] == "None":
|
||||
exception = None
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
return RayTaskError(function_name, exception, traceback_str)
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayTaskError into a primitive object."""
|
||||
if isinstance(self.exception, RayGetError):
|
||||
serialized_exception = ("RayGetError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentError):
|
||||
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
|
||||
elif self.exception is None:
|
||||
serialized_exception = ("None",)
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
return (self.function_name, serialized_exception, self.traceback_str)
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayTaskError as a string."""
|
||||
if self.traceback_str is None:
|
||||
@@ -99,16 +91,6 @@ class RayGetError(Exception):
|
||||
self.objectid = objectid
|
||||
self.task_error = task_error
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetError from a primitive object."""
|
||||
objectid, task_error = primitives
|
||||
return RayGetError(objectid, RayTaskError.deserialize(task_error))
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetError into a primitive object."""
|
||||
return (self.objectid, self.task_error.serialize())
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetError as a string."""
|
||||
return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
@@ -132,16 +114,6 @@ class RayGetArgumentError(Exception):
|
||||
self.objectid = objectid
|
||||
self.task_error = task_error
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetArgumentError from a primitive object."""
|
||||
function_name, argument_index, objectid, task_error = primitives
|
||||
return RayGetArgumentError(function_name, argument_index, objectid, RayTaskError.deserialize(task_error))
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetArgumentError into a primitive object."""
|
||||
return (self.function_name, self.argument_index, self.objectid, self.task_error.serialize())
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetArgumentError as a string."""
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
@@ -366,30 +338,27 @@ class Worker(object):
|
||||
objectid (raylib.ObjectID): The object ID of the value to be put.
|
||||
value (serializable object): The value to put in the object store.
|
||||
"""
|
||||
try:
|
||||
# We put the value into a list here because in arrow the concept of
|
||||
# "serializing a single object" does not exist.
|
||||
schema, size, serialized = libnumbuf.serialize_list([value])
|
||||
# TODO(pcm): Right now, metadata is serialized twice, change that in the future
|
||||
# in the following line, the "8" is for storing the metadata size,
|
||||
# the len(schema) is for storing the metadata and the 4096 is for storing
|
||||
# the metadata in the batch (see INITIAL_METADATA_SIZE in arrow)
|
||||
size = size + 8 + len(schema) + 4096
|
||||
buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size)
|
||||
# write the metadata length
|
||||
np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
|
||||
# metadata buffer
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema))
|
||||
# write the metadata
|
||||
metadata[:] = schema
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
|
||||
metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
|
||||
raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
|
||||
except:
|
||||
# At the moment, custom object and objects that contain object IDs take this path
|
||||
# TODO(pcm): Make sure that these are the only objects getting serialized to protobuf
|
||||
object_capsule, contained_objectids = serialization.serialize(self.handle, value) # contained_objectids is a list of the objectids contained in object_capsule
|
||||
raylib.put_object(self.handle, objectid, object_capsule, contained_objectids)
|
||||
# We put the value into a list here because in arrow the concept of
|
||||
# "serializing a single object" does not exist.
|
||||
schema, size, serialized = numbuf_serialize(value)
|
||||
global contained_objectids
|
||||
raylib.add_contained_objectids(self.handle, objectid, contained_objectids)
|
||||
contained_objectids = []
|
||||
# TODO(pcm): Right now, metadata is serialized twice, change that in the future
|
||||
# in the following line, the "8" is for storing the metadata size,
|
||||
# the len(schema) is for storing the metadata and the 8192 is for storing
|
||||
# the metadata in the batch (see INITIAL_METADATA_SIZE in arrow)
|
||||
size = size + 8 + len(schema) + 4096
|
||||
buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size)
|
||||
# write the metadata length
|
||||
np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
|
||||
# metadata buffer
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=len(schema))
|
||||
# write the metadata
|
||||
metadata[:] = schema
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
|
||||
metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
|
||||
raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
|
||||
|
||||
def get_object(self, objectid):
|
||||
"""Get the value in the local object store associated with objectid.
|
||||
@@ -400,32 +369,25 @@ class Worker(object):
|
||||
Args:
|
||||
objectid (raylib.ObjectID): The object ID of the value to retrieve.
|
||||
"""
|
||||
if raylib.is_arrow(self.handle, objectid):
|
||||
## this is the new codepath
|
||||
buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid)
|
||||
metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
|
||||
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
|
||||
# If there is currently no ObjectFixture for this ObjectID, then create a
|
||||
# new one. The object_fixtures object is a WeakValueDictionary, so entries
|
||||
# will be discarded when there are no strong references to their values.
|
||||
# We create object_fixture outside of the assignment because if we created
|
||||
# it inside the assignment it would immediately go out of scope.
|
||||
object_fixture = None
|
||||
if objectid.id not in object_fixtures:
|
||||
object_fixture = ObjectFixture(objectid, segmentid, self.handle)
|
||||
object_fixtures[objectid.id] = object_fixture
|
||||
deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id])
|
||||
# Unwrap the object from the list (it was wrapped in put_object)
|
||||
assert len(deserialized) == 1
|
||||
result = deserialized[0]
|
||||
## this is the old codepath
|
||||
# result, segmentid = raylib.get_arrow(self.handle, objectid)
|
||||
else:
|
||||
object_capsule, segmentid = raylib.get_object(self.handle, objectid)
|
||||
result = serialization.deserialize(self.handle, object_capsule)
|
||||
|
||||
assert raylib.is_arrow(self.handle, objectid), "All objects should be serialized using Arrow."
|
||||
buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid)
|
||||
metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
|
||||
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
|
||||
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
|
||||
serialized = libnumbuf.read_from_buffer(memoryview(data), bytearray(metadata), metadata_offset)
|
||||
# If there is currently no ObjectFixture for this ObjectID, then create a
|
||||
# new one. The object_fixtures object is a WeakValueDictionary, so entries
|
||||
# will be discarded when there are no strong references to their values.
|
||||
# We create object_fixture outside of the assignment because if we created
|
||||
# it inside the assignment it would immediately go out of scope.
|
||||
object_fixture = None
|
||||
if objectid.id not in object_fixtures:
|
||||
object_fixture = ObjectFixture(objectid, segmentid, self.handle)
|
||||
object_fixtures[objectid.id] = object_fixture
|
||||
deserialized = libnumbuf.deserialize_list(serialized, object_fixtures[objectid.id])
|
||||
# Unwrap the object from the list (it was wrapped in put_object)
|
||||
assert len(deserialized) == 1
|
||||
result = deserialized[0]
|
||||
return result
|
||||
|
||||
def alias_objectids(self, alias_objectid, target_objectid):
|
||||
@@ -445,7 +407,10 @@ class Worker(object):
|
||||
be object IDs or they can be values. If they are values, they
|
||||
must be serializable objects.
|
||||
"""
|
||||
task_capsule = serialization.serialize_task(self.handle, func_name, args)
|
||||
# Convert all of the arguments to object IDs. It is a little strange that we
|
||||
# are calling put, which is external to this class.
|
||||
args = [arg if isinstance(arg, raylib.ObjectID) else put(arg, worker=self) for arg in args]
|
||||
task_capsule = raylib.serialize_task(self.handle, func_name, args)
|
||||
objectids = raylib.submit_task(self.handle, task_capsule)
|
||||
return objectids
|
||||
|
||||
@@ -461,10 +426,13 @@ class Worker(object):
|
||||
not take any arguments. If it returns anything, its return values will
|
||||
not be used.
|
||||
"""
|
||||
if self.mode not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a driver.")
|
||||
# First run the function on the driver.
|
||||
function(self)
|
||||
# Then run the function on all of the workers.
|
||||
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
|
||||
if self.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
|
||||
|
||||
global_worker = Worker()
|
||||
"""Worker: The global Worker object for this worker process.
|
||||
@@ -568,6 +536,28 @@ def task_info(worker=global_worker):
|
||||
check_connected(worker)
|
||||
return raylib.task_info(worker.handle)
|
||||
|
||||
def initialize_numbuf(worker=global_worker):
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells numbuf to
|
||||
serialize several exception classes that we define for error handling.
|
||||
"""
|
||||
# Define a custom serializer and deserializer for handling Object IDs.
|
||||
def objectid_custom_serializer(obj):
|
||||
class_identifier = serialization.class_identifier(type(obj))
|
||||
contained_objectids.append(obj)
|
||||
return raylib.serialize_objectid(worker.handle, obj)
|
||||
def objectid_custom_deserializer(serialized_obj):
|
||||
return raylib.deserialize_objectid(worker.handle, serialized_obj)
|
||||
serialization.add_class_to_whitelist(raylib.ObjectID, pickle=False, custom_serializer=objectid_custom_serializer, custom_deserializer=objectid_custom_deserializer)
|
||||
|
||||
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
# These should only be called on the driver because register_class will
|
||||
# export the class to all of the workers.
|
||||
register_class(RayTaskError)
|
||||
register_class(RayGetError)
|
||||
register_class(RayGetArgumentError)
|
||||
|
||||
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE):
|
||||
"""Either connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
@@ -735,14 +725,16 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
|
||||
# the same.
|
||||
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
|
||||
current_directory = os.path.abspath(os.path.curdir)
|
||||
worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, current_directory))
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
|
||||
# Export cached remote functions to the workers.
|
||||
for function_name, function_to_export in worker.cached_remote_functions:
|
||||
raylib.export_remote_function(worker.handle, function_name, function_to_export)
|
||||
# Export cached reusable variables to the workers.
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
_export_reusable_variable(name, reusable_variable)
|
||||
# Initialize the serialization library.
|
||||
initialize_numbuf()
|
||||
worker.cached_remote_functions = None
|
||||
reusables._cached_reusables = None
|
||||
|
||||
@@ -757,6 +749,30 @@ def disconnect(worker=global_worker):
|
||||
worker.cached_remote_functions = []
|
||||
reusables._cached_reusables = []
|
||||
|
||||
def register_class(cls, pickle=False, worker=global_worker):
|
||||
"""Enable workers to serialize or deserialize objects of a particular class.
|
||||
|
||||
This method runs the register_class_for_serialization function defined below on every worker,
|
||||
which will enable libnumbuf to properly serialize and deserialize objects of
|
||||
this class.
|
||||
|
||||
Args:
|
||||
cls (type): The class that libnumbuf should serialize.
|
||||
pickle (bool): If False then objects of this class will be serialized by
|
||||
turning their __dict__ fields into a dictionary. If True, then objects
|
||||
of this class will be serialized using pickle.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if pickle=False and the class cannot be
|
||||
efficiently serialized by Ray.
|
||||
"""
|
||||
# Raise an exception if cls cannot be serialized efficiently by Ray.
|
||||
if not pickle:
|
||||
serialization.check_serializable(cls)
|
||||
def register_class_for_serialization(worker):
|
||||
serialization.add_class_to_whitelist(cls, pickle=pickle)
|
||||
worker.run_function_on_all_workers(register_class_for_serialization)
|
||||
|
||||
def get(objectid, worker=global_worker):
|
||||
"""Get a remote object or a list of remote objects from the object store.
|
||||
|
||||
@@ -915,7 +931,7 @@ def main_loop(worker=global_worker):
|
||||
After the task executes, the worker resets any reusable variables that were
|
||||
accessed by the task.
|
||||
"""
|
||||
function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
|
||||
function_name, args, return_objectids = task
|
||||
try:
|
||||
arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
|
||||
outputs = worker.functions[function_name].executor(arguments) # execute the function
|
||||
|
||||
Reference in New Issue
Block a user