mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 03:04:28 +08:00
[Streaming] Streaming Python API (#6755)
This commit is contained in:
@@ -0,0 +1,158 @@
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import ray
|
||||
from ray.streaming.collector import OutputCollector
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.context import RuntimeContextImpl
|
||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamTask(ABC):
|
||||
"""Base class for all streaming tasks. Each task runs a processor."""
|
||||
|
||||
def __init__(self, task_id, processor, worker):
|
||||
self.task_id = task_id
|
||||
self.processor = processor
|
||||
self.worker = worker
|
||||
self.reader = None # DataReader
|
||||
self.writers = {} # ExecutionEdge -> DataWriter
|
||||
self.thread = None
|
||||
self.prepare_task()
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
|
||||
def prepare_task(self):
|
||||
channel_conf = dict(self.worker.config)
|
||||
channel_size = int(
|
||||
self.worker.config.get(Config.CHANNEL_SIZE,
|
||||
Config.CHANNEL_SIZE_DEFAULT))
|
||||
channel_conf[Config.CHANNEL_SIZE] = channel_size
|
||||
channel_conf[Config.TASK_JOB_ID] = ray.runtime_context.\
|
||||
_get_runtime_context().current_driver_id
|
||||
channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
|
||||
.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)
|
||||
|
||||
execution_graph = self.worker.execution_graph
|
||||
execution_node = self.worker.execution_node
|
||||
# writers
|
||||
collectors = []
|
||||
for edge in execution_node.output_edges:
|
||||
output_actor_ids = {}
|
||||
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
|
||||
edge.target_node_id)
|
||||
for target_task_id, target_actor in task_id2_worker.items():
|
||||
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
|
||||
execution_graph.build_time())
|
||||
output_actor_ids[channel_name] = target_actor
|
||||
if len(output_actor_ids) > 0:
|
||||
channel_ids = list(output_actor_ids.keys())
|
||||
to_actor_ids = list(output_actor_ids.values())
|
||||
writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
|
||||
logger.info("Create DataWriter succeed.")
|
||||
self.writers[edge] = writer
|
||||
collectors.append(
|
||||
OutputCollector(channel_ids, writer, edge.partition))
|
||||
|
||||
# readers
|
||||
input_actor_ids = {}
|
||||
for edge in execution_node.input_edges:
|
||||
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
|
||||
edge.src_node_id)
|
||||
for src_task_id, src_actor in task_id2_worker.items():
|
||||
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
|
||||
execution_graph.build_time())
|
||||
input_actor_ids[channel_name] = src_actor
|
||||
if len(input_actor_ids) > 0:
|
||||
channel_ids = list(input_actor_ids.keys())
|
||||
from_actor_ids = list(input_actor_ids.values())
|
||||
logger.info("Create DataReader, channels {}.".format(channel_ids))
|
||||
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
|
||||
|
||||
def exit_handler():
|
||||
# Make DataReader stop read data when MockQueue destructor
|
||||
# gets called to avoid crash
|
||||
self.cancel_task()
|
||||
|
||||
import atexit
|
||||
atexit.register(exit_handler)
|
||||
|
||||
runtime_context = RuntimeContextImpl(
|
||||
self.worker.execution_task.task_id,
|
||||
self.worker.execution_task.task_index, execution_node.parallelism)
|
||||
logger.info("open Processor {}".format(self.processor))
|
||||
self.processor.open(collectors, runtime_context)
|
||||
|
||||
@abstractmethod
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
self.thread.start()
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_task(self):
|
||||
pass
|
||||
|
||||
|
||||
class InputStreamTask(StreamTask):
|
||||
"""Base class for stream tasks that execute a
|
||||
:class:`runtime.processor.OneInputProcessor` or
|
||||
:class:`runtime.processor.TwoInputProcessor` """
|
||||
|
||||
def __init__(self, task_id, processor_instance, worker):
|
||||
super().__init__(task_id, processor_instance, worker)
|
||||
self.running = True
|
||||
self.stopped = False
|
||||
self.read_timeout_millis = \
|
||||
int(worker.config.get(Config.READ_TIMEOUT_MS,
|
||||
Config.DEFAULT_READ_TIMEOUT_MS))
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
item = self.reader.read(self.read_timeout_millis)
|
||||
if item is not None:
|
||||
msg_data = item.body()
|
||||
msg = pickle.loads(msg_data)
|
||||
self.processor.process(msg)
|
||||
self.stopped = True
|
||||
|
||||
def cancel_task(self):
|
||||
self.running = False
|
||||
while not self.stopped:
|
||||
pass
|
||||
|
||||
|
||||
class OneInputStreamTask(InputStreamTask):
|
||||
"""A stream task for executing :class:`runtime.processor.OneInputProcessor`
|
||||
"""
|
||||
|
||||
def __init__(self, task_id, processor_instance, worker):
|
||||
super().__init__(task_id, processor_instance, worker)
|
||||
|
||||
|
||||
class SourceStreamTask(StreamTask):
|
||||
"""A stream task for executing :class:`runtime.processor.SourceProcessor`
|
||||
"""
|
||||
|
||||
def __init__(self, task_id, processor_instance, worker):
|
||||
super().__init__(task_id, processor_instance, worker)
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
self.processor.run()
|
||||
|
||||
def cancel_task(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user