[Streaming] Streaming Cross-Lang API (#7464)

This commit is contained in:
chaokunyang
2020-04-29 13:42:08 +08:00
committed by GitHub
parent 101255f782
commit 91f630f709
72 changed files with 1612 additions and 408 deletions
+29 -17
View File
@@ -1,11 +1,13 @@
import logging
import pickle
import threading
from abc import ABC, abstractmethod
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
logger = logging.getLogger(__name__)
@@ -38,36 +40,40 @@ class StreamTask(ABC):
# writers
collectors = []
for edge in execution_node.output_edges:
output_actor_ids = {}
output_actors_map = {}
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.target_node_id)
for target_task_id, target_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
execution_graph.build_time())
output_actor_ids[channel_name] = target_actor
if len(output_actor_ids) > 0:
channel_ids = list(output_actor_ids.keys())
to_actor_ids = list(output_actor_ids.values())
writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
logger.info("Create DataWriter succeed.")
output_actors_map[channel_name] = target_actor
if len(output_actors_map) > 0:
channel_ids = list(output_actors_map.keys())
target_actors = list(output_actors_map.values())
logger.info(
"Create DataWriter channel_ids {}, target_actors {}."
.format(channel_ids, target_actors))
writer = DataWriter(channel_ids, target_actors, channel_conf)
self.writers[edge] = writer
collectors.append(
OutputCollector(channel_ids, writer, edge.partition))
OutputCollector(writer, channel_ids, target_actors,
edge.partition))
# readers
input_actor_ids = {}
input_actor_map = {}
for edge in execution_node.input_edges:
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.src_node_id)
for src_task_id, src_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
execution_graph.build_time())
input_actor_ids[channel_name] = src_actor
if len(input_actor_ids) > 0:
channel_ids = list(input_actor_ids.keys())
from_actor_ids = list(input_actor_ids.values())
logger.info("Create DataReader, channels {}.".format(channel_ids))
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
input_actor_map[channel_name] = src_actor
if len(input_actor_map) > 0:
channel_ids = list(input_actor_map.keys())
from_actors = list(input_actor_map.values())
logger.info("Create DataReader, channels {}, input_actors {}."
.format(channel_ids, from_actors))
self.reader = DataReader(channel_ids, from_actors, channel_conf)
def exit_handler():
# Make DataReader stop read data when MockQueue destructor
@@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
self.read_timeout_millis = \
int(worker.config.get(Config.READ_TIMEOUT_MS,
Config.DEFAULT_READ_TIMEOUT_MS))
self.python_serializer = PythonSerializer()
self.cross_lang_serializer = CrossLangSerializer()
def init(self):
    """Do nothing; the body of this method is an explicit no-op.

    NOTE(review): presumably this satisfies an ``init`` hook declared on
    the ``StreamTask`` base class -- confirm against the full file.
    """
    pass
@@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
item = self.reader.read(self.read_timeout_millis)
if item is not None:
msg_data = item.body()
msg = pickle.loads(msg_data)
type_id = msg_data[:1]
if (type_id == serialization._PYTHON_TYPE_ID):
msg = self.python_serializer.deserialize(msg_data[1:])
else:
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
self.processor.process(msg)
self.stopped = True