mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 03:59:39 +08:00
Remove vanilla pickle serialization for task arguments (#6948)
This commit is contained in:
+27
-168
@@ -98,9 +98,6 @@ from ray.utils import decode
|
||||
from ray.ray_constants import (
|
||||
DEFAULT_PUT_OBJECT_DELAY,
|
||||
DEFAULT_PUT_OBJECT_RETRIES,
|
||||
RAW_BUFFER_METADATA,
|
||||
PICKLE_BUFFER_METADATA,
|
||||
PICKLE5_BUFFER_METADATA,
|
||||
)
|
||||
|
||||
# pyarrow cannot be imported until after _raylet finishes initializing
|
||||
@@ -215,84 +212,6 @@ def compute_task_id(ObjectID object_id):
|
||||
return TaskID(object_id.native().TaskId().Binary())
|
||||
|
||||
|
||||
cdef c_bool is_simple_value(value, int64_t *num_elements_contained):
|
||||
num_elements_contained[0] += 1
|
||||
|
||||
if num_elements_contained[0] >= RayConfig.instance().num_elements_limit():
|
||||
return False
|
||||
|
||||
if (cpython.PyInt_Check(value) or cpython.PyLong_Check(value) or
|
||||
value is False or value is True or cpython.PyFloat_Check(value) or
|
||||
value is None):
|
||||
return True
|
||||
|
||||
if cpython.PyBytes_CheckExact(value):
|
||||
num_elements_contained[0] += cpython.PyBytes_Size(value)
|
||||
return (num_elements_contained[0] <
|
||||
RayConfig.instance().num_elements_limit())
|
||||
|
||||
if cpython.PyUnicode_CheckExact(value):
|
||||
num_elements_contained[0] += cpython.PyUnicode_GET_SIZE(value)
|
||||
return (num_elements_contained[0] <
|
||||
RayConfig.instance().num_elements_limit())
|
||||
|
||||
if (cpython.PyList_CheckExact(value) and
|
||||
cpython.PyList_Size(value) < RayConfig.instance().size_limit()):
|
||||
for item in value:
|
||||
if not is_simple_value(item, num_elements_contained):
|
||||
return False
|
||||
return (num_elements_contained[0] <
|
||||
RayConfig.instance().num_elements_limit())
|
||||
|
||||
if (cpython.PyDict_CheckExact(value) and
|
||||
cpython.PyDict_Size(value) < RayConfig.instance().size_limit()):
|
||||
# TODO(suquark): Using "items" in Python2 is not very efficient.
|
||||
for k, v in value.items():
|
||||
if not (is_simple_value(k, num_elements_contained) and
|
||||
is_simple_value(v, num_elements_contained)):
|
||||
return False
|
||||
return (num_elements_contained[0] <
|
||||
RayConfig.instance().num_elements_limit())
|
||||
|
||||
if (cpython.PyTuple_CheckExact(value) and
|
||||
cpython.PyTuple_Size(value) < RayConfig.instance().size_limit()):
|
||||
for item in value:
|
||||
if not is_simple_value(item, num_elements_contained):
|
||||
return False
|
||||
return (num_elements_contained[0] <
|
||||
RayConfig.instance().num_elements_limit())
|
||||
|
||||
if isinstance(value, numpy.ndarray):
|
||||
if value.dtype == "O":
|
||||
return False
|
||||
num_elements_contained[0] += value.nbytes
|
||||
return (num_elements_contained[0] <
|
||||
RayConfig.instance().num_elements_limit())
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_simple_value(value):
|
||||
"""Check if value is simple enough to be send by value.
|
||||
|
||||
This method checks if a Python object is sufficiently simple that it can
|
||||
be serialized and passed by value as an argument to a task (without being
|
||||
put in the object store). The details of which objects are sufficiently
|
||||
simple are defined by this method and are not particularly important.
|
||||
But for performance reasons, it is better to place "small" objects in
|
||||
the task itself and "large" objects in the object store.
|
||||
|
||||
Args:
|
||||
value: Python object that should be checked.
|
||||
|
||||
Returns:
|
||||
True if the value should be send by value, False otherwise.
|
||||
"""
|
||||
|
||||
cdef int64_t num_elements_contained = 0
|
||||
return is_simple_value(value, &num_elements_contained)
|
||||
|
||||
|
||||
cdef class Language:
|
||||
cdef CLanguage lang
|
||||
|
||||
@@ -357,56 +276,35 @@ cdef c_vector[c_string] string_vector_from_list(list string_list):
|
||||
out.push_back(s)
|
||||
return out
|
||||
|
||||
|
||||
cdef:
|
||||
c_string pickle_metadata_str = PICKLE_BUFFER_METADATA
|
||||
shared_ptr[CBuffer] pickle_metadata = dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(pickle_metadata_str.data()),
|
||||
pickle_metadata_str.size(), True))
|
||||
c_string raw_meta_str = RAW_BUFFER_METADATA
|
||||
shared_ptr[CBuffer] raw_metadata = dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(raw_meta_str.data()),
|
||||
raw_meta_str.size(), True))
|
||||
|
||||
cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
|
||||
cdef void prepare_args(
|
||||
CoreWorker core_worker, list args, c_vector[CTaskArg] *args_vector):
|
||||
cdef:
|
||||
c_string pickled_str
|
||||
const unsigned char[:] buffer
|
||||
size_t size
|
||||
int64_t put_threshold
|
||||
shared_ptr[CBuffer] arg_data
|
||||
shared_ptr[CBuffer] arg_metadata
|
||||
|
||||
# TODO be consistent with store_task_outputs
|
||||
worker = ray.worker.global_worker
|
||||
put_threshold = RayConfig.instance().max_direct_call_object_size()
|
||||
for arg in args:
|
||||
if isinstance(arg, ObjectID):
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByReference((<ObjectID>arg).native()))
|
||||
elif not ray._raylet.check_simple_value(arg):
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByReference((<ObjectID>ray.put(arg)).native()))
|
||||
elif type(arg) is bytes:
|
||||
buffer = arg
|
||||
size = buffer.nbytes
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(&buffer[0]), size, True))
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(
|
||||
make_shared[CRayObject](arg_data, raw_metadata)))
|
||||
|
||||
else:
|
||||
buffer = pickle.dumps(
|
||||
arg, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
size = buffer.nbytes
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(&buffer[0]), size, True))
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(
|
||||
make_shared[CRayObject](arg_data, pickle_metadata)))
|
||||
serialized_arg = worker.get_serialization_context().serialize(arg)
|
||||
size = serialized_arg.total_bytes
|
||||
if <int64_t>size <= put_threshold:
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](size))
|
||||
write_serialized_object(serialized_arg, arg_data)
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(make_shared[CRayObject](
|
||||
arg_data, string_to_buffer(serialized_arg.metadata))))
|
||||
else:
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByReference(
|
||||
(<ObjectID>core_worker.put_serialized_object(
|
||||
serialized_arg)).native()))
|
||||
|
||||
|
||||
cdef class RayletClient:
|
||||
@@ -465,51 +363,12 @@ cdef class RayletClient:
|
||||
cdef deserialize_args(
|
||||
const c_vector[shared_ptr[CRayObject]] &c_args,
|
||||
const c_vector[CObjectID] &arg_reference_ids):
|
||||
cdef:
|
||||
c_vector[shared_ptr[CRayObject]] objects_to_deserialize
|
||||
|
||||
if c_args.size() == 0:
|
||||
if c_args.empty():
|
||||
return [], {}
|
||||
|
||||
args = []
|
||||
ids_to_deserialize = []
|
||||
id_indices = []
|
||||
for i in range(c_args.size()):
|
||||
# Passed by value.
|
||||
if arg_reference_ids[i].IsNil():
|
||||
if (c_args[i].get().HasMetadata()
|
||||
and Buffer.make(
|
||||
c_args[i].get().GetMetadata()).to_pybytes()
|
||||
== RAW_BUFFER_METADATA):
|
||||
data = Buffer.make(c_args[i].get().GetData())
|
||||
args.append(data.to_pybytes())
|
||||
elif (c_args[i].get().HasMetadata() and Buffer.make(
|
||||
c_args[i].get().GetMetadata()).to_pybytes()
|
||||
== PICKLE_BUFFER_METADATA):
|
||||
# This is a pickled "simple python value" argument.
|
||||
data = Buffer.make(c_args[i].get().GetData())
|
||||
args.append(pickle.loads(data.to_pybytes()))
|
||||
else:
|
||||
# This is a Ray object inlined by the direct task submitter.
|
||||
ids_to_deserialize.append(
|
||||
ObjectID(arg_reference_ids[i].Binary()))
|
||||
id_indices.append(i)
|
||||
objects_to_deserialize.push_back(c_args[i])
|
||||
args.append(None)
|
||||
# Passed by reference.
|
||||
else:
|
||||
ids_to_deserialize.append(
|
||||
ObjectID(arg_reference_ids[i].Binary()))
|
||||
id_indices.append(i)
|
||||
objects_to_deserialize.push_back(c_args[i])
|
||||
args.append(None)
|
||||
|
||||
data_metadata_pairs = RayObjectsToDataMetadataPairs(
|
||||
objects_to_deserialize)
|
||||
for i, arg in enumerate(
|
||||
ray.worker.global_worker.deserialize_objects(
|
||||
data_metadata_pairs, ids_to_deserialize)):
|
||||
args[id_indices[i]] = arg
|
||||
args = ray.worker.global_worker.deserialize_objects(
|
||||
RayObjectsToDataMetadataPairs(c_args),
|
||||
VectorToObjectIDs(arg_reference_ids))
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, RayError):
|
||||
@@ -983,7 +842,7 @@ cdef class CoreWorker:
|
||||
num_return_vals, is_direct_call, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
|
||||
prepare_args(args, &args_vector)
|
||||
prepare_args(self, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().SubmitTask(
|
||||
@@ -1015,7 +874,7 @@ cdef class CoreWorker:
|
||||
prepare_resources(placement_resources, &c_placement_resources)
|
||||
ray_function = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
|
||||
prepare_args(args, &args_vector)
|
||||
prepare_args(self, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().CreateActor(
|
||||
@@ -1049,7 +908,7 @@ cdef class CoreWorker:
|
||||
task_options = CTaskOptions(num_return_vals, False, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
|
||||
prepare_args(args, &args_vector)
|
||||
prepare_args(self, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().SubmitActorTask(
|
||||
|
||||
@@ -48,10 +48,6 @@ cdef extern from "ray/common/ray_config.h" nogil:
|
||||
|
||||
int64_t max_time_for_handler_milliseconds() const
|
||||
|
||||
int64_t size_limit() const
|
||||
|
||||
int64_t num_elements_limit() const
|
||||
|
||||
int64_t max_time_for_loop() const
|
||||
|
||||
int64_t redis_db_connect_retries()
|
||||
@@ -86,4 +82,6 @@ cdef extern from "ray/common/ray_config.h" nogil:
|
||||
|
||||
uint32_t maximum_gcs_deletion_batch_size() const
|
||||
|
||||
int64_t max_direct_call_object_size() const
|
||||
|
||||
void initialize(const unordered_map[c_string, c_string] &config_map)
|
||||
|
||||
@@ -84,14 +84,6 @@ cdef class Config:
|
||||
def max_time_for_handler_milliseconds():
|
||||
return RayConfig.instance().max_time_for_handler_milliseconds()
|
||||
|
||||
@staticmethod
|
||||
def size_limit():
|
||||
return RayConfig.instance().size_limit()
|
||||
|
||||
@staticmethod
|
||||
def num_elements_limit():
|
||||
return RayConfig.instance().num_elements_limit()
|
||||
|
||||
@staticmethod
|
||||
def max_time_for_loop():
|
||||
return RayConfig.instance().max_time_for_loop()
|
||||
|
||||
@@ -95,8 +95,6 @@ cdef class TaskSpec:
|
||||
:self.task_spec.get().ArgMetadataSize(i)]
|
||||
if metadata == RAW_BUFFER_METADATA:
|
||||
obj = data
|
||||
elif metadata == PICKLE_BUFFER_METADATA:
|
||||
obj = pickle.loads(data)
|
||||
else:
|
||||
obj = data
|
||||
arg_list.append(obj)
|
||||
|
||||
@@ -317,7 +317,7 @@ class SerializationContext:
|
||||
# use a placeholder for 'self' argument
|
||||
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
||||
|
||||
def _deserialize_object_from_arrow(self, data, metadata, object_id):
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
if not self.use_pickle:
|
||||
@@ -437,8 +437,7 @@ class SerializationContext:
|
||||
data, metadata = data_metadata_pairs[i]
|
||||
try:
|
||||
results.append(
|
||||
self._deserialize_object_from_arrow(
|
||||
data, metadata, object_id))
|
||||
self._deserialize_object(data, metadata, object_id))
|
||||
i += 1
|
||||
except DeserializationError:
|
||||
# Wait a little bit for the import thread to import the class.
|
||||
|
||||
@@ -1523,8 +1523,6 @@ def put(value, weakref=False):
|
||||
"""Store an object in the object store.
|
||||
|
||||
The object may not be evicted while a reference to the returned ID exists.
|
||||
Note that this pinning only applies to the particular object ID returned
|
||||
by put, not object IDs in general.
|
||||
|
||||
Args:
|
||||
value: The Python object to be stored.
|
||||
|
||||
Reference in New Issue
Block a user