mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 07:52:05 +08:00
This reverts commit 1b1466748f.
This commit is contained in:
@@ -11,9 +11,6 @@ from libcpp.vector cimport vector as c_vector
|
||||
from libcpp.list cimport list as c_list
|
||||
from cpython cimport PyObject
|
||||
cimport cpython
|
||||
from libcpp.unordered_map cimport unordered_map as c_unordered_map
|
||||
from cython.operator cimport dereference, postincrement
|
||||
|
||||
|
||||
cdef inline object PyObject_to_object(PyObject* o):
|
||||
# Cast to "object" increments reference count
|
||||
@@ -35,7 +32,7 @@ from ray.includes.unique_ids cimport (
|
||||
CObjectID,
|
||||
)
|
||||
|
||||
cdef extern from "common/status.h" namespace "ray::streaming" nogil:
|
||||
cdef extern from "status.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus":
|
||||
pass
|
||||
cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK"
|
||||
@@ -73,21 +70,9 @@ cdef extern from "message/message.h" namespace "ray::streaming" nogil:
|
||||
cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message"
|
||||
cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage":
|
||||
inline uint8_t *RawData() const
|
||||
inline uint8_t *Payload() const
|
||||
inline uint32_t PayloadSize() const
|
||||
inline uint32_t GetDataSize() const
|
||||
inline CStreamingMessageType GetMessageType() const
|
||||
inline uint64_t GetMessageId() const
|
||||
@staticmethod
|
||||
inline void GetBarrierIdFromRawData(const uint8_t *data,
|
||||
CStreamingBarrierHeader *barrier_header)
|
||||
cdef struct CStreamingBarrierHeader "ray::streaming::StreamingBarrierHeader":
|
||||
CStreamingBarrierType barrier_type;
|
||||
uint64_t barrier_id;
|
||||
cdef cppclass CStreamingBarrierType "ray::streaming::StreamingBarrierType":
|
||||
pass
|
||||
cdef uint32_t kMessageHeaderSize;
|
||||
cdef uint32_t kBarrierHeaderSize;
|
||||
inline uint64_t GetMessageSeqId() const
|
||||
|
||||
cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CStreamingMessageBundleType "ray::streaming::StreamingMessageBundleType":
|
||||
@@ -112,40 +97,13 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
|
||||
void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums,
|
||||
c_list[shared_ptr[CStreamingMessage]] &msg_list);
|
||||
|
||||
cdef extern from "channel/channel.h" namespace "ray::streaming" nogil:
|
||||
cdef extern from "channel.h" namespace "ray::streaming" nogil:
|
||||
cdef struct CChannelCreationParameter "ray::streaming::ChannelCreationParameter":
|
||||
CChannelCreationParameter()
|
||||
CActorID actor_id;
|
||||
shared_ptr[CRayFunction] async_function;
|
||||
shared_ptr[CRayFunction] sync_function;
|
||||
|
||||
cdef struct CStreamingQueueInfo "ray::streaming::StreamingQueueInfo":
|
||||
uint64_t first_seq_id;
|
||||
uint64_t last_message_id;
|
||||
uint64_t target_message_id;
|
||||
uint64_t consumed_message_id;
|
||||
|
||||
cdef struct CConsumerChannelInfo "ray::streaming::ConsumerChannelInfo":
|
||||
CObjectID channel_id;
|
||||
uint64_t current_message_id;
|
||||
uint64_t barrier_id;
|
||||
uint64_t partial_barrier_id;
|
||||
CStreamingQueueInfo queue_info;
|
||||
uint64_t last_queue_item_delay;
|
||||
uint64_t last_queue_item_latency;
|
||||
uint64_t last_queue_target_diff;
|
||||
uint64_t get_queue_item_times;
|
||||
uint64_t notify_cnt;
|
||||
CChannelCreationParameter parameter;
|
||||
|
||||
cdef enum CTransferCreationStatus "ray::streaming::TransferCreationStatus":
|
||||
FreshStarted = 0
|
||||
PullOk = 1
|
||||
Timeout = 2
|
||||
DataLost = 3
|
||||
Invalid = 999
|
||||
|
||||
|
||||
cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CReaderClient "ray::streaming::ReaderClient":
|
||||
CReaderClient()
|
||||
@@ -170,12 +128,11 @@ cdef extern from "data_reader.h" namespace "ray::streaming" nogil:
|
||||
CDataReader(shared_ptr[CRuntimeContext] &runtime_context)
|
||||
void Init(const c_vector[CObjectID] &input_ids,
|
||||
const c_vector[CChannelCreationParameter] ¶ms,
|
||||
const c_vector[uint64_t] &seq_ids,
|
||||
const c_vector[uint64_t] &msg_ids,
|
||||
c_vector[CTransferCreationStatus] &creation_status,
|
||||
int64_t timer_interval);
|
||||
CStreamingStatus GetBundle(const uint32_t timeout_ms,
|
||||
shared_ptr[CDataBundle] &message)
|
||||
void GetOffsetInfo(c_unordered_map[CObjectID, CConsumerChannelInfo] *&offset_map);
|
||||
void Stop()
|
||||
|
||||
|
||||
@@ -188,9 +145,6 @@ cdef extern from "data_writer.h" namespace "ray::streaming" nogil:
|
||||
const c_vector[uint64_t] &queue_size_vec);
|
||||
long WriteMessageToBufferRing(
|
||||
const CObjectID &q_id, uint8_t *data, uint32_t data_size)
|
||||
void BroadcastBarrier(uint64_t checkpoint_id, const uint8_t *data, uint32_t data_size)
|
||||
void GetChannelOffset(c_vector[uint64_t] &result)
|
||||
void ClearCheckpoint(uint64_t checkpoint_id)
|
||||
void Run()
|
||||
void Stop()
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
|
||||
from libcpp.string cimport string as c_string
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
from libcpp.list cimport list as c_list
|
||||
from libcpp.unordered_map cimport unordered_map as c_unordered_map
|
||||
from cython.operator cimport dereference, postincrement
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CRayFunction,
|
||||
@@ -40,10 +38,6 @@ from ray.streaming.includes.libstreaming cimport (
|
||||
CWriterClient,
|
||||
CLocalMemoryBuffer,
|
||||
CChannelCreationParameter,
|
||||
CTransferCreationStatus,
|
||||
CConsumerChannelInfo,
|
||||
CStreamingBarrierHeader,
|
||||
kBarrierHeaderSize,
|
||||
)
|
||||
from ray._raylet import JavaFunctionDescriptor
|
||||
|
||||
@@ -197,7 +191,7 @@ cdef class DataWriter:
|
||||
self.writer = NULL
|
||||
|
||||
def write(self, ObjectRef qid, const unsigned char[:] value):
|
||||
"""support zero-copy bytes, byte array, array of unsigned char"""
|
||||
"""support zero-copy bytes, bytearray, array of unsigned char"""
|
||||
cdef:
|
||||
CObjectID native_id = qid.data
|
||||
uint64_t msg_id
|
||||
@@ -207,25 +201,6 @@ cdef class DataWriter:
|
||||
msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size)
|
||||
return msg_id
|
||||
|
||||
def broadcast_barrier(self, uint64_t checkpoint_id, const unsigned char[:] value):
|
||||
cdef:
|
||||
uint8_t *data = <uint8_t *>(&value[0])
|
||||
uint32_t size = value.nbytes
|
||||
with nogil:
|
||||
self.writer.BroadcastBarrier(checkpoint_id, data, size)
|
||||
|
||||
def get_output_checkpoints(self):
|
||||
cdef:
|
||||
c_vector[uint64_t] results
|
||||
self.writer.GetChannelOffset(results)
|
||||
return results
|
||||
|
||||
def clear_checkpoint(self, checkpoint_id):
|
||||
cdef:
|
||||
uint64_t c_checkpoint_id = checkpoint_id
|
||||
with nogil:
|
||||
self.writer.ClearCheckpoint(c_checkpoint_id)
|
||||
|
||||
def stop(self):
|
||||
self.writer.Stop()
|
||||
channel_logger.info("stopped DataWriter")
|
||||
@@ -243,22 +218,25 @@ cdef class DataReader:
|
||||
@staticmethod
|
||||
def create(list py_input_queues,
|
||||
list input_creation_parameters: list[ChannelCreationParameter],
|
||||
list py_seq_ids,
|
||||
list py_msg_ids,
|
||||
int64_t timer_interval,
|
||||
c_bool is_recreate,
|
||||
bytes config_bytes,
|
||||
c_bool is_mock):
|
||||
cdef:
|
||||
c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
|
||||
c_vector[CChannelCreationParameter] initial_parameters
|
||||
c_vector[uint64_t] seq_ids
|
||||
c_vector[uint64_t] msg_ids
|
||||
c_vector[CTransferCreationStatus] c_creation_status
|
||||
CDataReader *c_reader
|
||||
ChannelCreationParameter parameter
|
||||
cdef const unsigned char[:] config_data
|
||||
for param in input_creation_parameters:
|
||||
parameter = param
|
||||
initial_parameters.push_back(parameter.get_parameter())
|
||||
|
||||
for py_seq_id in py_seq_ids:
|
||||
seq_ids.push_back(<uint64_t>py_seq_id)
|
||||
for py_msg_id in py_msg_ids:
|
||||
msg_ids.push_back(<uint64_t>py_msg_id)
|
||||
cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
|
||||
@@ -269,19 +247,11 @@ cdef class DataReader:
|
||||
if is_mock:
|
||||
ctx.get().MarkMockTest()
|
||||
c_reader = new CDataReader(ctx)
|
||||
c_reader.Init(queue_id_vec, initial_parameters, msg_ids, c_creation_status, timer_interval)
|
||||
|
||||
creation_status_map = {}
|
||||
if not c_creation_status.empty():
|
||||
for i in range(queue_id_vec.size()):
|
||||
k = queue_id_vec[i].Binary()
|
||||
v = <uint64_t>c_creation_status[i]
|
||||
creation_status_map[k] = v
|
||||
|
||||
c_reader.Init(queue_id_vec, initial_parameters, seq_ids, msg_ids, timer_interval)
|
||||
channel_logger.info("create native reader succeed")
|
||||
cdef DataReader reader = DataReader.__new__(DataReader)
|
||||
reader.reader = c_reader
|
||||
return reader, creation_status_map
|
||||
return reader
|
||||
|
||||
def __dealloc__(self):
|
||||
if self.reader != NULL:
|
||||
@@ -295,33 +265,23 @@ cdef class DataReader:
|
||||
CStreamingStatus status
|
||||
with nogil:
|
||||
status = self.reader.GetBundle(timeout_millis, bundle)
|
||||
cdef uint32_t bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())
|
||||
if <uint32_t> status != <uint32_t> libstreaming.StatusOK:
|
||||
if <uint32_t> status == <uint32_t> libstreaming.StatusInterrupted:
|
||||
# avoid cyclic import
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
raise transfer.ChannelInterruptException("reader interrupted")
|
||||
elif <uint32_t> status == <uint32_t> libstreaming.StatusInitQueueFailed:
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
raise transfer.ChannelInitException("init channel failed")
|
||||
elif <uint32_t> status == <uint32_t> libstreaming.StatusGetBundleTimeOut:
|
||||
return []
|
||||
else:
|
||||
raise Exception("no such status " + str(<uint32_t>status))
|
||||
raise Exception("init channel failed")
|
||||
elif <uint32_t> status == <uint32_t> libstreaming.StatusWaitQueueTimeOut:
|
||||
raise Exception("wait channel object timeout")
|
||||
cdef:
|
||||
uint32_t msg_nums
|
||||
CObjectID queue_id = bundle.get().c_from
|
||||
CObjectID queue_id
|
||||
c_list[shared_ptr[CStreamingMessage]] msg_list
|
||||
list msgs = []
|
||||
uint64_t timestamp
|
||||
uint64_t msg_id
|
||||
c_unordered_map[CObjectID, CConsumerChannelInfo] *offset_map = NULL
|
||||
shared_ptr[CStreamingMessage] barrier
|
||||
CStreamingBarrierHeader barrier_header
|
||||
c_unordered_map[CObjectID, CConsumerChannelInfo].iterator it
|
||||
|
||||
cdef uint32_t bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())
|
||||
# avoid cyclic import
|
||||
from ray.streaming.runtime.transfer import DataMessage
|
||||
if bundle_type == <uint32_t> libstreaming.BundleTypeBundle:
|
||||
msg_nums = bundle.get().meta.get().GetMessageListSize()
|
||||
CStreamingMessageBundle.GetMessageListFromRawData(
|
||||
@@ -331,48 +291,16 @@ cdef class DataReader:
|
||||
msg_list)
|
||||
timestamp = bundle.get().meta.get().GetMessageBundleTs()
|
||||
for msg in msg_list:
|
||||
msg_bytes = msg.get().Payload()[:msg.get().PayloadSize()]
|
||||
msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()]
|
||||
qid_bytes = queue_id.Binary()
|
||||
msg_id = msg.get().GetMessageId()
|
||||
msgs.append(
|
||||
DataMessage(msg_bytes, timestamp, msg_id, qid_bytes))
|
||||
msg_id = msg.get().GetMessageSeqId()
|
||||
msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
|
||||
return msgs
|
||||
elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
|
||||
timestamp = bundle.get().meta.get().GetMessageBundleTs()
|
||||
msg_id = bundle.get().meta.get().GetLastMessageId()
|
||||
return [DataMessage(None, timestamp, msg_id, queue_id.Binary(), True)]
|
||||
elif bundle.get().meta.get().IsBarrier():
|
||||
py_offset_map = {}
|
||||
self.reader.GetOffsetInfo(offset_map)
|
||||
it = offset_map.begin()
|
||||
while it != offset_map.end():
|
||||
queue_id_bytes = dereference(it).first.Binary()
|
||||
current_message_id = dereference(it).second.current_message_id
|
||||
py_offset_map[queue_id_bytes] = current_message_id
|
||||
postincrement(it)
|
||||
msg_nums = bundle.get().meta.get().GetMessageListSize()
|
||||
CStreamingMessageBundle.GetMessageListFromRawData(
|
||||
bundle.get().data + libstreaming.kMessageBundleHeaderSize,
|
||||
bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
|
||||
msg_nums,
|
||||
msg_list)
|
||||
timestamp = bundle.get().meta.get().GetMessageBundleTs()
|
||||
barrier = msg_list.front()
|
||||
msg_id = barrier.get().GetMessageId()
|
||||
CStreamingMessage.GetBarrierIdFromRawData(barrier.get().Payload(), &barrier_header)
|
||||
barrier_id = barrier_header.barrier_id
|
||||
barrier_data = (barrier.get().Payload() + kBarrierHeaderSize)[
|
||||
:barrier.get().PayloadSize() - kBarrierHeaderSize]
|
||||
barrier_type = <uint64_t> barrier_header.barrier_type
|
||||
py_queue_id = queue_id.Binary()
|
||||
from ray.streaming.runtime.transfer import CheckpointBarrier
|
||||
return [CheckpointBarrier(
|
||||
barrier_data, timestamp, msg_id, py_queue_id, py_offset_map,
|
||||
barrier_id, barrier_type)]
|
||||
return []
|
||||
else:
|
||||
raise Exception("Unsupported bundle type {}".format(bundle_type))
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.reader.Stop()
|
||||
channel_logger.info("stopped DataReader")
|
||||
|
||||
Reference in New Issue
Block a user