mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 05:41:19 +08:00
[Streaming] Streaming Cross-Lang API (#7464)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray.streaming.collector import OutputCollector
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.context import RuntimeContextImpl
|
||||
from ray.streaming.runtime import serialization
|
||||
from ray.streaming.runtime.serialization import \
|
||||
PythonSerializer, CrossLangSerializer
|
||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,36 +40,40 @@ class StreamTask(ABC):
|
||||
# writers
|
||||
collectors = []
|
||||
for edge in execution_node.output_edges:
|
||||
output_actor_ids = {}
|
||||
output_actors_map = {}
|
||||
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
|
||||
edge.target_node_id)
|
||||
for target_task_id, target_actor in task_id2_worker.items():
|
||||
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
|
||||
execution_graph.build_time())
|
||||
output_actor_ids[channel_name] = target_actor
|
||||
if len(output_actor_ids) > 0:
|
||||
channel_ids = list(output_actor_ids.keys())
|
||||
to_actor_ids = list(output_actor_ids.values())
|
||||
writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
|
||||
logger.info("Create DataWriter succeed.")
|
||||
output_actors_map[channel_name] = target_actor
|
||||
if len(output_actors_map) > 0:
|
||||
channel_ids = list(output_actors_map.keys())
|
||||
target_actors = list(output_actors_map.values())
|
||||
logger.info(
|
||||
"Create DataWriter channel_ids {}, target_actors {}."
|
||||
.format(channel_ids, target_actors))
|
||||
writer = DataWriter(channel_ids, target_actors, channel_conf)
|
||||
self.writers[edge] = writer
|
||||
collectors.append(
|
||||
OutputCollector(channel_ids, writer, edge.partition))
|
||||
OutputCollector(writer, channel_ids, target_actors,
|
||||
edge.partition))
|
||||
|
||||
# readers
|
||||
input_actor_ids = {}
|
||||
input_actor_map = {}
|
||||
for edge in execution_node.input_edges:
|
||||
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
|
||||
edge.src_node_id)
|
||||
for src_task_id, src_actor in task_id2_worker.items():
|
||||
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
|
||||
execution_graph.build_time())
|
||||
input_actor_ids[channel_name] = src_actor
|
||||
if len(input_actor_ids) > 0:
|
||||
channel_ids = list(input_actor_ids.keys())
|
||||
from_actor_ids = list(input_actor_ids.values())
|
||||
logger.info("Create DataReader, channels {}.".format(channel_ids))
|
||||
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
|
||||
input_actor_map[channel_name] = src_actor
|
||||
if len(input_actor_map) > 0:
|
||||
channel_ids = list(input_actor_map.keys())
|
||||
from_actors = list(input_actor_map.values())
|
||||
logger.info("Create DataReader, channels {}, input_actors {}."
|
||||
.format(channel_ids, from_actors))
|
||||
self.reader = DataReader(channel_ids, from_actors, channel_conf)
|
||||
|
||||
def exit_handler():
|
||||
# Make DataReader stop read data when MockQueue destructor
|
||||
@@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
|
||||
self.read_timeout_millis = \
|
||||
int(worker.config.get(Config.READ_TIMEOUT_MS,
|
||||
Config.DEFAULT_READ_TIMEOUT_MS))
|
||||
self.python_serializer = PythonSerializer()
|
||||
self.cross_lang_serializer = CrossLangSerializer()
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
@@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
|
||||
item = self.reader.read(self.read_timeout_millis)
|
||||
if item is not None:
|
||||
msg_data = item.body()
|
||||
msg = pickle.loads(msg_data)
|
||||
type_id = msg_data[:1]
|
||||
if (type_id == serialization._PYTHON_TYPE_ID):
|
||||
msg = self.python_serializer.deserialize(msg_data[1:])
|
||||
else:
|
||||
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
|
||||
self.processor.process(msg)
|
||||
self.stopped = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user