mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
Collect object IDs during serialization (#6946)
This commit is contained in:
@@ -60,8 +60,8 @@ cdef class CoreWorker:
|
||||
|
||||
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
|
||||
size_t data_size, ObjectID object_id,
|
||||
c_vector[CObjectID] contained_ids,
|
||||
CObjectID *c_object_id, shared_ptr[CBuffer] *data)
|
||||
# TODO: handle noreturn better
|
||||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns)
|
||||
|
||||
+20
-9
@@ -287,7 +287,13 @@ cdef void prepare_args(
|
||||
else:
|
||||
serialized_arg = worker.get_serialization_context().serialize(arg)
|
||||
size = serialized_arg.total_bytes
|
||||
if <int64_t>size <= put_threshold:
|
||||
|
||||
# TODO(edoakes): any objects containing ObjectIDs are spilled to
|
||||
# plasma here. This is inefficient for small objects, but inlined
|
||||
# arguments aren't associated with ObjectIDs right now, so this is a
|
||||
# simple fix for reference counting purposes.
|
||||
if (<int64_t>size <= put_threshold and
|
||||
len(serialized_arg.contained_object_ids) == 0):
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](size))
|
||||
write_serialized_object(serialized_arg, arg_data)
|
||||
@@ -645,6 +651,7 @@ cdef class CoreWorker:
|
||||
|
||||
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
|
||||
size_t data_size, ObjectID object_id,
|
||||
c_vector[CObjectID] contained_ids,
|
||||
CObjectID *c_object_id, shared_ptr[CBuffer] *data):
|
||||
delay = ray_constants.DEFAULT_PUT_OBJECT_DELAY
|
||||
for attempt in reversed(
|
||||
@@ -653,13 +660,14 @@ cdef class CoreWorker:
|
||||
if object_id is None:
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Create(
|
||||
metadata, data_size,
|
||||
metadata, data_size, contained_ids,
|
||||
c_object_id, data))
|
||||
else:
|
||||
c_object_id[0] = object_id.native()
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Create(
|
||||
metadata, data_size, c_object_id[0], data))
|
||||
metadata, data_size, contained_ids,
|
||||
c_object_id[0], data))
|
||||
break
|
||||
except ObjectStoreFullError as e:
|
||||
if attempt:
|
||||
@@ -685,22 +693,22 @@ cdef class CoreWorker:
|
||||
CObjectID c_object_id
|
||||
shared_ptr[CBuffer] data
|
||||
shared_ptr[CBuffer] metadata
|
||||
# The object won't be pinned if an ObjectID is provided by the
|
||||
# user (because we can't track its lifetime to unpin). Note that
|
||||
# the API to do this isn't supported as a public API.
|
||||
c_bool owns_object = object_id is None
|
||||
|
||||
metadata = string_to_buffer(serialized_object.metadata)
|
||||
total_bytes = serialized_object.total_bytes
|
||||
object_already_exists = self._create_put_buffer(
|
||||
metadata, total_bytes, object_id,
|
||||
ObjectIDsToVector(serialized_object.contained_object_ids),
|
||||
&c_object_id, &data)
|
||||
|
||||
if not object_already_exists:
|
||||
write_serialized_object(serialized_object, data)
|
||||
with nogil:
|
||||
# Using custom object IDs is not supported because we can't
|
||||
# track their lifecycle, so don't pin the object in that case.
|
||||
check_status(
|
||||
self.core_worker.get().Seal(
|
||||
c_object_id, owns_object, pin_object))
|
||||
c_object_id, pin_object and object_id is None))
|
||||
|
||||
return ObjectID(c_object_id.Binary())
|
||||
|
||||
@@ -942,6 +950,7 @@ cdef class CoreWorker:
|
||||
cdef:
|
||||
c_vector[size_t] data_sizes
|
||||
c_vector[shared_ptr[CBuffer]] metadatas
|
||||
c_vector[c_vector[CObjectID]] contained_ids
|
||||
|
||||
if return_ids.size() == 0:
|
||||
return
|
||||
@@ -963,9 +972,11 @@ cdef class CoreWorker:
|
||||
metadatas.push_back(
|
||||
string_to_buffer(serialized_object.metadata))
|
||||
serialized_objects.append(serialized_object)
|
||||
contained_ids.push_back(
|
||||
ObjectIDsToVector(serialized_object.contained_object_ids))
|
||||
|
||||
check_status(self.core_worker.get().AllocateReturnObjects(
|
||||
return_ids, data_sizes, metadatas, returns))
|
||||
return_ids, data_sizes, metadatas, contained_ids, returns))
|
||||
|
||||
for i, serialized_object in enumerate(serialized_objects):
|
||||
# A nullptr is returned if the object already exists.
|
||||
|
||||
@@ -52,7 +52,8 @@ def dump(obj, file, protocol=None, buffer_callback=None):
|
||||
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
|
||||
compatibility with older versions of Python.
|
||||
"""
|
||||
CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)
|
||||
CloudPickler(file, protocol=protocol,
|
||||
buffer_callback=buffer_callback).dump(obj)
|
||||
|
||||
|
||||
def dumps(obj, protocol=None, buffer_callback=None):
|
||||
@@ -66,7 +67,8 @@ def dumps(obj, protocol=None, buffer_callback=None):
|
||||
compatibility with older versions of Python.
|
||||
"""
|
||||
with io.BytesIO() as file:
|
||||
cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
|
||||
cp = CloudPickler(file, protocol=protocol,
|
||||
buffer_callback=buffer_callback)
|
||||
cp.dump(obj)
|
||||
return file.getvalue()
|
||||
|
||||
@@ -79,9 +81,9 @@ def _class_getnewargs(obj):
|
||||
if hasattr(obj, "__slots__"):
|
||||
type_kwargs["__slots__"] = obj.__slots__
|
||||
|
||||
__dict__ = obj.__dict__.get('__dict__', None)
|
||||
__dict__ = obj.__dict__.get("__dict__", None)
|
||||
if isinstance(__dict__, property):
|
||||
type_kwargs['__dict__'] = __dict__
|
||||
type_kwargs["__dict__"] = __dict__
|
||||
|
||||
return (type(obj), obj.__name__, obj.__bases__, type_kwargs,
|
||||
_ensure_tracking(obj), None)
|
||||
@@ -141,7 +143,7 @@ def _function_getstate(func):
|
||||
|
||||
def _class_getstate(obj):
|
||||
clsdict = _extract_class_dict(obj)
|
||||
clsdict.pop('__weakref__', None)
|
||||
clsdict.pop("__weakref__", None)
|
||||
|
||||
# For ABCMeta in python3.7+, remove _abc_impl as it is not picklable.
|
||||
# This is a fix which breaks the cache but this only makes the first
|
||||
@@ -160,7 +162,7 @@ def _class_getstate(obj):
|
||||
for k in obj.__slots__:
|
||||
clsdict.pop(k, None)
|
||||
|
||||
clsdict.pop('__dict__', None) # unpicklable property object
|
||||
clsdict.pop("__dict__", None) # unpicklable property object
|
||||
|
||||
return (clsdict, {})
|
||||
|
||||
@@ -428,10 +430,10 @@ def _numpy_ndarray_reduce(array):
|
||||
# the PickleBuffer instance will hold a view on the transpose
|
||||
# of the initial array, that is C-contiguous.
|
||||
if not array.flags.c_contiguous and array.flags.f_contiguous:
|
||||
order = 'F'
|
||||
order = "F"
|
||||
picklebuf_args = array.transpose()
|
||||
else:
|
||||
order = 'C'
|
||||
order = "C"
|
||||
picklebuf_args = array
|
||||
try:
|
||||
buffer = picklebuf_class(picklebuf_args)
|
||||
@@ -485,7 +487,8 @@ class CloudPickler(Pickler):
|
||||
def __init__(self, file, protocol=None, buffer_callback=None):
|
||||
if protocol is None:
|
||||
protocol = DEFAULT_PROTOCOL
|
||||
Pickler.__init__(self, file, protocol=protocol, buffer_callback=buffer_callback)
|
||||
Pickler.__init__(self, file, protocol=protocol,
|
||||
buffer_callback=buffer_callback)
|
||||
# map functions __globals__ attribute ids, to ensure that functions
|
||||
# sharing the same global namespace at pickling time also share their
|
||||
# global namespace at unpickling time.
|
||||
@@ -531,8 +534,9 @@ class CloudPickler(Pickler):
|
||||
# This is a patch for python3.5
|
||||
if isinstance(obj, numpy.ndarray):
|
||||
if (self.proto < 5 or
|
||||
(not obj.flags.c_contiguous and not obj.flags.f_contiguous) or
|
||||
obj.dtype == 'O' or obj.itemsize == 0):
|
||||
(not obj.flags.c_contiguous and
|
||||
not obj.flags.f_contiguous) or
|
||||
obj.dtype == "O" or obj.itemsize == 0):
|
||||
return NotImplemented
|
||||
return _numpy_ndarray_reduce(obj)
|
||||
|
||||
|
||||
@@ -106,6 +106,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
const c_vector[CObjectID] &object_ids,
|
||||
const c_vector[size_t] &data_sizes,
|
||||
const c_vector[shared_ptr[CBuffer]] &metadatas,
|
||||
const c_vector[c_vector[CObjectID]] &contained_object_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *return_objects)
|
||||
|
||||
CJobID GetCurrentJobId()
|
||||
@@ -129,16 +130,22 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
CAddress &owner_address)
|
||||
|
||||
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
|
||||
CRayStatus Put(const CRayObject &object, CObjectID *object_id)
|
||||
CRayStatus Put(const CRayObject &object, const CObjectID &object_id)
|
||||
CRayStatus Put(const CRayObject &object,
|
||||
const c_vector[CObjectID] &contained_object_ids,
|
||||
CObjectID *object_id)
|
||||
CRayStatus Put(const CRayObject &object,
|
||||
const c_vector[CObjectID] &contained_object_ids,
|
||||
const CObjectID &object_id)
|
||||
CRayStatus Create(const shared_ptr[CBuffer] &metadata,
|
||||
const size_t data_size,
|
||||
const c_vector[CObjectID] &contained_object_ids,
|
||||
CObjectID *object_id, shared_ptr[CBuffer] *data)
|
||||
CRayStatus Create(const shared_ptr[CBuffer] &metadata,
|
||||
const size_t data_size, const CObjectID &object_id,
|
||||
const size_t data_size,
|
||||
const c_vector[CObjectID] &contained_object_ids,
|
||||
const CObjectID &object_id,
|
||||
shared_ptr[CBuffer] *data)
|
||||
CRayStatus Seal(const CObjectID &object_id, c_bool owns_object,
|
||||
c_bool pin_object)
|
||||
CRayStatus Seal(const CObjectID &object_id, c_bool pin_object)
|
||||
CRayStatus Get(const c_vector[CObjectID] &ids, int64_t timeout_ms,
|
||||
c_vector[shared_ptr[CRayObject]] *results)
|
||||
CRayStatus Contains(const CObjectID &object_id, c_bool *has_object)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
|
||||
import pyarrow.plasma as plasma
|
||||
|
||||
@@ -34,8 +35,9 @@ class DeserializationError(Exception):
|
||||
|
||||
|
||||
class SerializedObject:
|
||||
def __init__(self, metadata):
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
@@ -45,11 +47,15 @@ class SerializedObject:
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def contained_object_ids(self):
|
||||
return self._contained_object_ids
|
||||
|
||||
|
||||
class Pickle5SerializedObject(SerializedObject):
|
||||
def __init__(self, inband, writer):
|
||||
super(Pickle5SerializedObject,
|
||||
self).__init__(ray_constants.PICKLE5_BUFFER_METADATA)
|
||||
def __init__(self, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(
|
||||
ray_constants.PICKLE5_BUFFER_METADATA, contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
@@ -126,6 +132,7 @@ class SerializationContext:
|
||||
self.worker = worker
|
||||
assert worker.use_pickle
|
||||
self.use_pickle = worker.use_pickle
|
||||
self._thread_local = threading.local()
|
||||
|
||||
def actor_handle_serializer(obj):
|
||||
return obj._serialization_helper(True)
|
||||
@@ -147,6 +154,7 @@ class SerializationContext:
|
||||
return serialized_obj[0](*serialized_obj[1])
|
||||
|
||||
def object_id_serializer(obj):
|
||||
self.add_contained_object_id(obj)
|
||||
owner_id = ""
|
||||
owner_address = ""
|
||||
if obj.is_direct_call_type():
|
||||
@@ -192,6 +200,21 @@ class SerializationContext:
|
||||
# construct a reducer
|
||||
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
||||
|
||||
def get_and_clear_contained_object_ids(self):
    """Return the ObjectIDs recorded on this thread and reset the record.

    The IDs are accumulated in ``self._thread_local.object_ids`` by
    ``add_contained_object_id``. After this call that attribute always
    refers to a brand-new empty set, so the returned set is never shared
    with later recordings.

    Returns:
        The set that was stored on the current thread, or a new empty set
        if nothing had been recorded on this thread yet.
    """
    local_state = self._thread_local
    try:
        collected = local_state.object_ids
    except AttributeError:
        # First use on this thread: there is nothing to hand back.
        collected = set()
    # Install a separate empty set so the caller owns `collected` outright.
    local_state.object_ids = set()
    return collected
|
||||
|
||||
def add_contained_object_id(self, object_id):
    """Record ``object_id`` in the current thread's set of contained IDs.

    The set lives at ``self._thread_local.object_ids`` and is created on
    first use. Adding an ID that is already present leaves the set
    unchanged.

    Args:
        object_id: The (hashable) ID to remember until the set is next
            fetched and cleared.
    """
    local_state = self._thread_local
    try:
        pending = local_state.object_ids
    except AttributeError:
        # Lazily create the per-thread set the first time an ID is seen.
        pending = local_state.object_ids = set()
    pending.add(object_id)
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
@@ -291,7 +314,8 @@ class SerializationContext:
|
||||
writer = Pickle5Writer()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
return Pickle5SerializedObject(inband, writer)
|
||||
return Pickle5SerializedObject(
|
||||
inband, writer, self.get_and_clear_contained_object_ids())
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
|
||||
@@ -293,13 +293,9 @@ class Worker:
|
||||
should_warn_of_slow_puts = False
|
||||
return result
|
||||
|
||||
def deserialize_objects(self,
|
||||
data_metadata_pairs,
|
||||
object_ids,
|
||||
error_timeout=10):
|
||||
def deserialize_objects(self, data_metadata_pairs, object_ids):
|
||||
context = self.get_serialization_context()
|
||||
return context.deserialize_objects(data_metadata_pairs, object_ids,
|
||||
error_timeout)
|
||||
return context.deserialize_objects(data_metadata_pairs, object_ids)
|
||||
|
||||
def get_objects(self, object_ids, timeout=None):
|
||||
"""Get the values in the object store associated with the IDs.
|
||||
|
||||
Reference in New Issue
Block a user