mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 16:48:11 +08:00
This reverts commit 1b1466748f.
This commit is contained in:
@@ -2,8 +2,6 @@ import logging
|
||||
import random
|
||||
from queue import Queue
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import ray
|
||||
import ray.streaming._streaming as _streaming
|
||||
@@ -15,7 +13,6 @@ from ray._raylet import PythonFunctionDescriptor
|
||||
from ray._raylet import Language
|
||||
|
||||
CHANNEL_ID_LEN = 20
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChannelID:
|
||||
@@ -100,70 +97,40 @@ def channel_bytes_to_str(id_bytes):
|
||||
return bytes.hex(id_bytes)
|
||||
|
||||
|
||||
class Message(ABC):
    """Abstract interface shared by the messages read from a channel.

    Concrete subclasses must provide all four read-only properties below;
    the class cannot be instantiated until every one of them is overridden.
    """

    @property
    @abstractmethod
    def body(self):
        """Message data"""

    @property
    @abstractmethod
    def timestamp(self):
        """Get timestamp when item is written by upstream DataWriter"""

    @property
    @abstractmethod
    def channel_id(self):
        """Get string id of channel where data is coming from"""

    @property
    @abstractmethod
    def message_id(self):
        """Get message id of the message"""
|
||||
|
||||
|
||||
class DataMessage(Message):
|
||||
class DataMessage:
|
||||
"""
|
||||
DataMessage represents data between upstream and downstream operator.
|
||||
DataMessage represents data between upstream and downstream operator
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
body,
|
||||
timestamp,
|
||||
message_id,
|
||||
channel_id,
|
||||
message_id_,
|
||||
is_empty_message=False):
|
||||
self.__body = body
|
||||
self.__timestamp = timestamp
|
||||
self.__channel_id = channel_id
|
||||
self.__message_id = message_id
|
||||
self.__message_id = message_id_
|
||||
self.__is_empty_message = is_empty_message
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__body)
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Message data"""
|
||||
return self.__body
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
"""Get timestamp when item is written by upstream DataWriter
|
||||
"""
|
||||
return self.__timestamp
|
||||
|
||||
@property
|
||||
def channel_id(self):
|
||||
"""Get string id of channel where data is coming from
|
||||
"""
|
||||
return self.__channel_id
|
||||
|
||||
@property
|
||||
def message_id(self):
|
||||
return self.__message_id
|
||||
|
||||
@property
|
||||
def is_empty_message(self):
|
||||
"""Whether this message is an empty message.
|
||||
Upstream DataWriter will send an empty message when this is no data
|
||||
@@ -171,47 +138,10 @@ class DataMessage(Message):
|
||||
"""
|
||||
return self.__is_empty_message
|
||||
|
||||
|
||||
class CheckpointBarrier(Message):
    """
    CheckpointBarrier separates the records in the data stream into the set of
    records that goes into the current snapshot, and the records that go into
    the next snapshot. Each barrier carries the ID of the snapshot whose
    records it pushed in front of it.
    """

    def __init__(self, barrier_data, timestamp, message_id, channel_id,
                 offsets, barrier_id, barrier_type):
        """Store the barrier payload and its metadata.

        Args:
            barrier_data: payload exposed through ``body``.
            timestamp: value exposed through ``timestamp``.
            message_id: value exposed through ``message_id``.
            channel_id: value exposed through ``channel_id``.
            offsets: value returned by ``get_input_checkpoints``.
            barrier_id: stored as the public ``checkpoint_id`` attribute.
            barrier_type: stored unchanged as ``barrier_type``.
        """
        self.__barrier_data = barrier_data
        self.__timestamp = timestamp
        self.__message_id = message_id
        self.__channel_id = channel_id
        # Note the rename: callers pass ``barrier_id`` but read it back as
        # ``checkpoint_id``.
        self.checkpoint_id = barrier_id
        self.offsets = offsets
        self.barrier_type = barrier_type

    @property
    def body(self):
        """Barrier payload passed in as ``barrier_data``."""
        return self.__barrier_data

    @property
    def timestamp(self):
        """Timestamp supplied at construction."""
        return self.__timestamp

    @property
    def channel_id(self):
        """Channel id supplied at construction."""
        return self.__channel_id

    @property
    def message_id(self):
        """Message id supplied at construction."""
        return self.__message_id

    def get_input_checkpoints(self):
        """Return the ``offsets`` value supplied at construction."""
        return self.offsets

    def __str__(self):
        return "Barrier(Checkpoint id : {})".format(self.checkpoint_id)
|
||||
|
||||
|
||||
class ChannelCreationParametersBuilder:
|
||||
"""
|
||||
@@ -288,6 +218,9 @@ class ChannelCreationParametersBuilder:
|
||||
_python_reader_sync_function_descriptor = sync_function
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataWriter:
|
||||
"""Data Writer is a wrapper of streaming c++ DataWriter, which sends data
|
||||
to downstream workers
|
||||
@@ -331,26 +264,6 @@ class DataWriter:
|
||||
msg_id = self.writer.write(channel_id.object_qid, item)
|
||||
return msg_id
|
||||
|
||||
def broadcast_barrier(self, checkpoint_id: int, body: bytes):
    """Broadcast barriers to all downstream channels

    Delegates directly to the wrapped writer's ``broadcast_barrier``.

    Args:
        checkpoint_id: the checkpoint_id
        body: barrier payload
    """
    self.writer.broadcast_barrier(checkpoint_id, body)
|
||||
|
||||
def get_output_checkpoints(self) -> List[int]:
    """Get output offsets of all downstream channels

    Delegates directly to the wrapped writer's ``get_output_checkpoints``.

    Returns:
        a list contains current msg_id of each downstream channel
    """
    return self.writer.get_output_checkpoints()
|
||||
|
||||
def clear_checkpoint(self, checkpoint_id):
    """Log the request, then ask the wrapped writer to clear a checkpoint.

    Args:
        checkpoint_id: id forwarded unchanged to the writer's
            ``clear_checkpoint``.
    """
    # Lazy %-style args: the message is only formatted when INFO is enabled.
    logger.info("producer start to clear checkpoint, checkpoint_id=%s",
                checkpoint_id)
    self.writer.clear_checkpoint(checkpoint_id)
|
||||
|
||||
def stop(self):
|
||||
logger.info("stopping channel writer.")
|
||||
self.writer.stop()
|
||||
@@ -381,20 +294,18 @@ class DataReader:
|
||||
]
|
||||
creation_parameters = ChannelCreationParametersBuilder()
|
||||
creation_parameters.build_input_queue_parameters(from_actors)
|
||||
py_seq_ids = [0 for _ in range(len(input_channels))]
|
||||
py_msg_ids = [0 for _ in range(len(input_channels))]
|
||||
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
|
||||
is_recreate = bool(conf.get(Config.IS_RECREATE, False))
|
||||
config_bytes = _to_native_conf(conf)
|
||||
self.__queue = Queue(10000)
|
||||
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
|
||||
self.reader, queues_creation_status = _streaming.DataReader.create(
|
||||
self.reader = _streaming.DataReader.create(
|
||||
py_input_channels, creation_parameters.get_parameters(),
|
||||
py_msg_ids, timer_interval, config_bytes, is_mock)
|
||||
|
||||
self.__creation_status = {}
|
||||
for q, status in queues_creation_status.items():
|
||||
self.__creation_status[q] = ChannelCreationStatus(status)
|
||||
logger.info("create DataReader succeed, creation_status={}".format(
|
||||
self.__creation_status))
|
||||
py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes,
|
||||
is_mock)
|
||||
logger.info("create DataReader succeed")
|
||||
|
||||
def read(self, timeout_millis):
|
||||
"""Read data from channel
|
||||
@@ -405,17 +316,16 @@ class DataReader:
|
||||
channel item
|
||||
"""
|
||||
if self.__queue.empty():
|
||||
messages = self.reader.read(timeout_millis)
|
||||
for message in messages:
|
||||
self.__queue.put(message)
|
||||
|
||||
msgs = self.reader.read(timeout_millis)
|
||||
for msg in msgs:
|
||||
msg_bytes, msg_id, timestamp, qid_bytes = msg
|
||||
data_msg = DataMessage(msg_bytes, timestamp,
|
||||
channel_bytes_to_str(qid_bytes), msg_id)
|
||||
self.__queue.put(data_msg)
|
||||
if self.__queue.empty():
|
||||
return None
|
||||
return self.__queue.get()
|
||||
|
||||
def get_channel_recover_info(self):
    """Wrap this reader's per-channel creation status in a ChannelRecoverInfo.

    Returns:
        A new ``ChannelRecoverInfo`` built from the creation-status mapping
        held by this reader.
    """
    return ChannelRecoverInfo(self.__creation_status)
|
||||
|
||||
def stop(self):
|
||||
logger.info("stopping Data Reader.")
|
||||
self.reader.stop()
|
||||
@@ -462,45 +372,3 @@ class ChannelInitException(Exception):
|
||||
class ChannelInterruptException(Exception):
    """Exception carrying an optional ``msg`` attribute.

    NOTE(review): the name suggests it signals an interrupted channel
    operation; the raising sites are not visible here — confirm.
    """

    def __init__(self, msg=None):
        """Record ``msg`` (default ``None``) on the instance."""
        self.msg = msg
|
||||
|
||||
|
||||
class ChannelRecoverInfo:
    """Holds the creation status of each queue and reports data-lost queues."""

    def __init__(self, queue_creation_status_map=None):
        """Store the status mapping.

        Args:
            queue_creation_status_map: mapping of queue -> creation status.
                ``None`` (the default) is replaced by a fresh empty dict so
                instances never share a default mapping. A mapping passed in
                is stored by reference, not copied.
        """
        if queue_creation_status_map is None:
            queue_creation_status_map = {}
        self.__queue_creation_status_map = queue_creation_status_map

    def get_creation_status(self):
        """Return the stored queue -> creation status mapping."""
        return self.__queue_creation_status_map

    def get_data_lost_queues(self):
        """Return the set of queues whose status is ``DataLost``.

        Returns:
            A new set; empty when no queue has that status.
        """
        return {
            q
            for q, status in self.__queue_creation_status_map.items()
            if status == ChannelCreationStatus.DataLost
        }

    def __str__(self):
        # Explicit 1-tuple so the argument is always treated as a single
        # value by the % operator.
        return "QueueRecoverInfo [dataLostQueues=%s]" \
            % (self.get_data_lost_queues(),)
|
||||
|
||||
|
||||
class ChannelCreationStatus(Enum):
    """Creation status of a channel, keyed by integer code.

    NOTE(review): the integer values presumably mirror status codes returned
    by the native reader — confirm against the native definition before
    changing them.
    """

    FreshStarted = 0
    PullOk = 1
    Timeout = 2
    DataLost = 3
|
||||
|
||||
|
||||
def channel_id_bytes_to_str(id_bytes):
    """Convert a channel id to its string representation.

    Args:
        id_bytes: bytes representation of channel id; a ``str`` is accepted
            too and returned unchanged.

    Returns:
        string representation of channel id (lower-case hex of the bytes)

    Raises:
        AssertionError: if ``id_bytes`` is neither ``str`` nor ``bytes``.
    """
    # isinstance (rather than an exact type comparison) also admits
    # subclasses of str/bytes. The assert is kept so the exception type
    # seen by callers is unchanged; note it is stripped under ``python -O``.
    assert isinstance(id_bytes, (str, bytes))
    if isinstance(id_bytes, str):
        return id_bytes
    return bytes.hex(id_bytes)
|
||||
|
||||
Reference in New Issue
Block a user