mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 05:41:19 +08:00
[Streaming] Streaming Python API (#6755)
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""Module to interact between java and python
|
||||
"""
|
||||
|
||||
import msgpack
|
||||
import ray
|
||||
|
||||
|
||||
class GatewayClient:
    """Client-side proxy for the `PythonGateway` java actor.

    Each public method issues one remote call on the gateway actor and
    blocks on `ray.get` until the call has completed.
    """

    _PYTHON_GATEWAY_CLASSNAME = (
        b"org.ray.streaming.runtime.python.PythonGateway")

    def __init__(self):
        # Create the java actor that all requests are forwarded to.
        gateway_class = ray.java_actor_class(
            GatewayClient._PYTHON_GATEWAY_CLASSNAME)
        self._python_gateway_actor = gateway_class.remote()

    def create_streaming_context(self):
        """Invoke `createStreamingContext` and return the decoded reply."""
        pending = self._python_gateway_actor.createStreamingContext.remote()
        reply = ray.get(pending)
        return deserialize(reply)

    def with_config(self, conf):
        """Send the serialized `conf` to `withConfig`; returns nothing."""
        pending = self._python_gateway_actor.withConfig.remote(
            serialize(conf))
        ray.get(pending)

    def execute(self, job_name):
        """Send the serialized `job_name` to `execute`; returns nothing."""
        pending = self._python_gateway_actor.execute.remote(
            serialize(job_name))
        ray.get(pending)

    def create_py_stream_source(self, serialized_func):
        """Invoke `createPythonStreamSource` with already-serialized bytes
        and return the decoded reply."""
        assert isinstance(serialized_func, bytes)
        remote_method = self._python_gateway_actor.createPythonStreamSource
        reply = ray.get(remote_method.remote(serialized_func))
        return deserialize(reply)

    def create_py_func(self, serialized_func):
        """Invoke `createPyFunc` with already-serialized bytes and return
        the decoded reply."""
        assert isinstance(serialized_func, bytes)
        remote_method = self._python_gateway_actor.createPyFunc
        reply = ray.get(remote_method.remote(serialized_func))
        return deserialize(reply)

    def create_py_partition(self, serialized_partition):
        """Invoke `createPyPartition` with already-serialized bytes and
        return the decoded reply."""
        assert isinstance(serialized_partition, bytes)
        remote_method = self._python_gateway_actor.createPyPartition
        reply = ray.get(remote_method.remote(serialized_partition))
        return deserialize(reply)

    def call_function(self, java_class, java_function, *args):
        """Invoke `callFunction` with `[java_class, java_function, *args]`
        serialized as one payload and return the decoded reply."""
        payload = serialize([java_class, java_function, *args])
        pending = self._python_gateway_actor.callFunction.remote(payload)
        return deserialize(ray.get(pending))

    def call_method(self, java_object, java_method, *args):
        """Invoke `callMethod` with `[java_object, java_method, *args]`
        serialized as one payload and return the decoded reply."""
        payload = serialize([java_object, java_method, *args])
        pending = self._python_gateway_actor.callMethod.remote(payload)
        return deserialize(ray.get(pending))
|
||||
|
||||
|
||||
def serialize(obj) -> bytes:
    """Pack a python object with msgpack so `PythonGateway` can decode it.

    `use_bin_type=True` is passed so that msgpack's bin type is used when
    packing.
    """
    packed = msgpack.packb(obj, use_bin_type=True)
    return packed
|
||||
|
||||
|
||||
def deserialize(data: bytes):
    """Unpack msgpack binary data produced by `PythonGateway`.

    `raw=False` is passed to `msgpack.unpackb`.
    """
    unpacked = msgpack.unpackb(data, raw=False)
    return unpacked
|
||||
@@ -0,0 +1,102 @@
|
||||
import enum
|
||||
|
||||
import ray
|
||||
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
|
||||
import ray.streaming.generated.streaming_pb2 as streaming_pb
|
||||
import ray.streaming.operator as operator
|
||||
import ray.streaming.partition as partition
|
||||
from ray.streaming import function
|
||||
from ray.streaming.generated.streaming_pb2 import Language
|
||||
|
||||
|
||||
class NodeType(enum.Enum):
    """Role of a node inside the streaming dataflow graph.

    SOURCE: where the program reads its input from.

    TRANSFORM: an operator that turns one or more DataStreams into a new
        DataStream; transformations can be combined into sophisticated
        dataflow topologies.

    SINK: consumes DataStreams and forwards them to files, sockets,
        external systems, or prints them.
    """

    SOURCE = 0
    TRANSFORM = 1
    SINK = 2
|
||||
|
||||
|
||||
class ExecutionNode:
    """A node of the execution graph, built from its protobuf message.

    Attributes:
        node_id: copied from ``node_pb.node_id``.
        node_type: the `NodeType` member whose name matches the protobuf
            enum name of ``node_pb.node_type``.
        parallelism: copied from ``node_pb.parallelism``.
        stream_operator: operator created from the node's function when
            the node language is Python, otherwise ``None``.
        execution_tasks: one `ExecutionTask` per entry of
            ``node_pb.execution_tasks``.
        input_edges / output_edges: `ExecutionEdge` objects built from
            ``node_pb.input_edges`` / ``node_pb.output_edges``.
    """

    def __init__(self, node_pb):
        self.node_id = node_pb.node_id
        # Map the protobuf enum value to the python enum through its name.
        node_type_name = streaming_pb.NodeType.Name(node_pb.node_type)
        self.node_type = NodeType[node_type_name]
        self.parallelism = node_pb.parallelism
        # Always define the attribute: previously it only existed for
        # Python nodes, so reading it on any other node raised
        # AttributeError.
        self.stream_operator = None
        if node_pb.language == Language.PYTHON:
            func_bytes = node_pb.function  # python function descriptor
            func = function.load_function(func_bytes)
            self.stream_operator = operator.create_operator(func)
        self.execution_tasks = [
            ExecutionTask(task) for task in node_pb.execution_tasks
        ]
        self.input_edges = [
            ExecutionEdge(edge, node_pb.language)
            for edge in node_pb.input_edges
        ]
        self.output_edges = [
            ExecutionEdge(edge, node_pb.language)
            for edge in node_pb.output_edges
        ]
|
||||
|
||||
|
||||
class ExecutionEdge:
    """A directed edge between two execution nodes.

    Attributes:
        src_node_id: copied from ``edge_pb.src_node_id``.
        target_node_id: copied from ``edge_pb.target_node_id``.
        partition: partition loaded from ``edge_pb.partition`` when
            `language` is Python, otherwise ``None``.
    """

    def __init__(self, edge_pb, language):
        self.src_node_id = edge_pb.src_node_id
        self.target_node_id = edge_pb.target_node_id
        partition_bytes = edge_pb.partition
        # Always define the attribute: previously it only existed for
        # Python edges, so reading it on any other edge raised
        # AttributeError.
        self.partition = None
        if language == Language.PYTHON:
            self.partition = partition.load_partition(partition_bytes)
|
||||
|
||||
|
||||
class ExecutionTask:
    """One task of an execution node, built from its protobuf message.

    Attributes:
        task_id: copied from ``task_pb.task_id``.
        task_index: copied from ``task_pb.task_index``.
        worker_actor: actor handle rebuilt from ``task_pb.worker_actor``.
    """

    def __init__(self, task_pb):
        self.task_id = task_pb.task_id
        self.task_index = task_pb.task_index
        # Rebuild the actor handle from its serialized form.
        rebuild_handle = ray.actor.ActorHandle._deserialization_helper
        self.worker_actor = rebuild_handle(task_pb.worker_actor, False)
|
||||
|
||||
|
||||
class ExecutionGraph:
    """Python-side view of an execution graph protobuf message.

    Attributes:
        execution_nodes: list with one `ExecutionNode` per entry of
            ``graph_pb.execution_nodes``. This is a plain attribute; the
            former ``execution_nodes()`` method was removed because the
            attribute of the same name always shadowed it on instances.
    """

    def __init__(self, graph_pb: "remote_call_pb.ExecutionGraph"):
        self._graph_pb = graph_pb
        self.execution_nodes = [
            ExecutionNode(node) for node in graph_pb.execution_nodes
        ]

    def build_time(self):
        """Return the ``build_time`` field of the underlying protobuf."""
        return self._graph_pb.build_time

    def _iter_node_tasks(self):
        """Yield ``(execution_node, task)`` for every task of every node."""
        for execution_node in self.execution_nodes:
            for task in execution_node.execution_tasks:
                yield execution_node, task

    def get_execution_task_by_task_id(self, task_id):
        """Return the execution task whose ``task_id`` equals `task_id`.

        Raises:
            Exception: if no node owns such a task.
        """
        for _, task in self._iter_node_tasks():
            if task.task_id == task_id:
                return task
        # The message previously mixed "%s" with str.format, so the id was
        # never interpolated.
        raise Exception("Task {} does not exist!".format(task_id))

    def get_execution_node_by_task_id(self, task_id):
        """Return the execution node that owns the task `task_id`.

        Raises:
            Exception: if no node owns such a task.
        """
        for execution_node, task in self._iter_node_tasks():
            if task.task_id == task_id:
                return execution_node
        raise Exception("Task {} does not exist!".format(task_id))

    def get_task_id2_worker_by_node_id(self, node_id):
        """Return a dict mapping task id to worker actor for node `node_id`.

        The first node whose ``node_id`` matches is used.

        Raises:
            Exception: if no node has that id.
        """
        for execution_node in self.execution_nodes:
            if execution_node.node_id == node_id:
                return {
                    task.task_id: task.worker_actor
                    for task in execution_node.execution_tasks
                }
        raise Exception("Node {} does not exist!".format(node_id))
|
||||
@@ -0,0 +1,113 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import ray.streaming.context as context
|
||||
from ray.streaming import message
|
||||
from ray.streaming.operator import OperatorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Processor(ABC):
    """Abstract interface that every processor implements."""

    @abstractmethod
    def open(self, collectors, runtime_context):
        """Receive the collectors and runtime context before processing."""

    @abstractmethod
    def process(self, record: message.Record):
        """Handle a single record."""

    @abstractmethod
    def close(self):
        """Shut the processor down."""
|
||||
|
||||
|
||||
class StreamingProcessor(Processor, ABC):
    """Processing unit wrapping a single operator.

    `process` is left abstract; subclasses decide how a record reaches
    the operator.
    """

    def __init__(self, operator):
        self.operator = operator
        # Both are filled in by open().
        self.collectors = None
        self.runtime_context = None

    def open(self, collectors, runtime_context: context.RuntimeContext):
        """Remember the collectors and context, then open the operator
        (if there is one) with the same arguments."""
        self.collectors = collectors
        self.runtime_context = runtime_context
        wrapped_operator = self.operator
        if wrapped_operator is not None:
            wrapped_operator.open(collectors, runtime_context)
        logger.info("Opened Processor {}".format(self))

    def close(self):
        """No-op."""
|
||||
|
||||
|
||||
class SourceProcessor(StreamingProcessor):
    """Processor for :class:`ray.streaming.operator.SourceOperator` """

    def __init__(self, operator):
        super().__init__(operator)

    def process(self, record):
        """Always raises: a source processor is not fed records."""
        raise Exception("SourceProcessor should not process record")

    def run(self):
        """Delegate to the wrapped operator's `run`."""
        source_operator = self.operator
        source_operator.run()
|
||||
|
||||
|
||||
class OneInputProcessor(StreamingProcessor):
    """Processor for a stream operator that has a single input."""

    def __init__(self, operator):
        super().__init__(operator)

    def process(self, record):
        """Hand `record` to the operator's `process_element`."""
        one_input_operator = self.operator
        one_input_operator.process_element(record)
|
||||
|
||||
|
||||
class TwoInputProcessor(StreamingProcessor):
    """Processor for a stream operator that has two inputs.

    `left_stream` and `right_stream` are properties backed by the private
    attributes `_left_stream` and `_right_stream`; both start as ``None``.
    """

    def __init__(self, operator):
        super().__init__(operator)
        # Write the backing fields directly. The properties used to refer
        # to themselves, which made construction raise RecursionError.
        self._left_stream = None
        self._right_stream = None

    def process(self, record: message.Record):
        """Dispatch `record` to the operator.

        A record whose ``stream`` equals `left_stream` is passed as the
        first argument of `process_element` (second is ``None``); any
        other record is passed as the second argument.
        """
        if record.stream == self.left_stream:
            self.operator.process_element(record, None)
        else:
            self.operator.process_element(None, record)

    @property
    def left_stream(self):
        """Stream whose records are treated as the left input."""
        return self._left_stream

    @left_stream.setter
    def left_stream(self, value):
        self._left_stream = value

    @property
    def right_stream(self):
        """Value stored for the right input stream."""
        return self._right_stream

    @right_stream.setter
    def right_stream(self, value):
        self._right_stream = value
|
||||
|
||||
|
||||
def build_processor(operator_instance):
    """Return the processor matching the operator's `operator_type()`.

    SOURCE, ONE_INPUT and TWO_INPUT map to `SourceProcessor`,
    `OneInputProcessor` and `TwoInputProcessor`; any other type raises.
    """
    operator_type = operator_instance.operator_type()
    message_text = (
        "Building StreamProcessor, operator type = {}, operator = {}.".format(
            operator_type, operator_instance))
    logger.info(message_text)

    if operator_type == OperatorType.SOURCE:
        return SourceProcessor(operator_instance)
    if operator_type == OperatorType.ONE_INPUT:
        return OneInputProcessor(operator_instance)
    if operator_type == OperatorType.TWO_INPUT:
        return TwoInputProcessor(operator_instance)
    raise Exception("Current operator type is not supported")
|
||||
@@ -0,0 +1,158 @@
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import ray
|
||||
from ray.streaming.collector import OutputCollector
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.context import RuntimeContextImpl
|
||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamTask(ABC):
    """Abstract base of all streaming tasks; each task runs one processor.

    Construction wires up the data channels and opens the processor
    (`prepare_task`), then creates, but does not start, a daemon thread
    whose target is `run`. `start` starts that thread.
    """

    def __init__(self, task_id, processor, worker):
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.reader = None  # DataReader
        self.writers = {}  # ExecutionEdge -> DataWriter
        self.thread = None
        self.prepare_task()
        self.thread = threading.Thread(target=self.run, daemon=True)

    def prepare_task(self):
        """Create the writers and reader for this task and open the
        processor with the resulting collectors and a runtime context."""
        channel_conf = self._build_channel_conf()
        graph = self.worker.execution_graph
        node = self.worker.execution_node

        collectors = self._create_writers(graph, node, channel_conf)
        self._create_reader(graph, node, channel_conf)

        execution_task = self.worker.execution_task
        runtime_context = RuntimeContextImpl(
            execution_task.task_id, execution_task.task_index,
            node.parallelism)
        logger.info("open Processor {}".format(self.processor))
        self.processor.open(collectors, runtime_context)

    def _build_channel_conf(self):
        """Copy the worker config and set the channel size, job id and
        channel type entries on the copy."""
        worker_conf = self.worker.config
        conf = dict(worker_conf)
        conf[Config.CHANNEL_SIZE] = int(
            worker_conf.get(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT))
        conf[Config.TASK_JOB_ID] = (
            ray.runtime_context._get_runtime_context().current_driver_id)
        conf[Config.CHANNEL_TYPE] = worker_conf.get(Config.CHANNEL_TYPE,
                                                    Config.NATIVE_CHANNEL)
        return conf

    def _create_writers(self, graph, node, channel_conf):
        """Create one DataWriter per output edge that has target tasks,
        record it in `self.writers`, and return the output collectors."""
        collectors = []
        for edge in node.output_edges:
            targets = graph.get_task_id2_worker_by_node_id(
                edge.target_node_id)
            # channel id -> actor on the receiving side
            output_actor_ids = {}
            for target_task_id, target_actor in targets.items():
                channel_id = ChannelID.gen_id(self.task_id, target_task_id,
                                              graph.build_time())
                output_actor_ids[channel_id] = target_actor
            if not output_actor_ids:
                continue
            channel_ids = list(output_actor_ids.keys())
            to_actor_ids = list(output_actor_ids.values())
            writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
            logger.info("Create DataWriter succeed.")
            self.writers[edge] = writer
            collectors.append(
                OutputCollector(channel_ids, writer, edge.partition))
        return collectors

    def _create_reader(self, graph, node, channel_conf):
        """Create a single DataReader covering the channels of all input
        edges (if there are any) and register an exit handler for it."""
        # channel id -> actor on the sending side, across all input edges
        input_actor_ids = {}
        for edge in node.input_edges:
            sources = graph.get_task_id2_worker_by_node_id(edge.src_node_id)
            for src_task_id, src_actor in sources.items():
                channel_id = ChannelID.gen_id(src_task_id, self.task_id,
                                              graph.build_time())
                input_actor_ids[channel_id] = src_actor
        if not input_actor_ids:
            return

        channel_ids = list(input_actor_ids.keys())
        from_actor_ids = list(input_actor_ids.values())
        logger.info("Create DataReader, channels {}.".format(channel_ids))
        self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)

        def exit_handler():
            # Make DataReader stop read data when MockQueue destructor
            # gets called to avoid crash
            self.cancel_task()

        import atexit
        atexit.register(exit_handler)

    @abstractmethod
    def init(self):
        """Subclass hook; no behaviour is defined here."""

    def start(self):
        """Start the thread that executes `run`."""
        self.thread.start()

    @abstractmethod
    def run(self):
        """Body of the task thread."""

    @abstractmethod
    def cancel_task(self):
        """Stop the task; invoked by the exit handler."""
|
||||
|
||||
|
||||
class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor`

    Attributes:
        running: loop flag; `cancel_task` clears it.
        stopped: set to True once `run` has left its loop.
        read_timeout_millis: timeout passed to every `reader.read` call,
            taken from the worker config.
    """

    # Seconds between checks while cancel_task waits for the loop to end.
    _CANCEL_POLL_INTERVAL_S = 0.01

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
        self.running = True
        self.stopped = False
        self.read_timeout_millis = int(
            worker.config.get(Config.READ_TIMEOUT_MS,
                              Config.DEFAULT_READ_TIMEOUT_MS))

    def init(self):
        pass

    def run(self):
        """Read items until cancelled, unpickle each body and pass the
        resulting message to the processor."""
        try:
            while self.running:
                item = self.reader.read(self.read_timeout_millis)
                if item is None:
                    continue
                # NOTE(review): pickle.loads runs on bytes received from
                # the data channel; this is only safe if every upstream
                # writer is trusted.
                msg = pickle.loads(item.body())
                self.processor.process(msg)
        finally:
            # Mark the loop as finished even when it exits through an
            # exception, so cancel_task cannot wait forever.
            self.stopped = True

    def cancel_task(self):
        """Ask the read loop to stop and wait until it has stopped.

        Returns immediately when the task thread does not exist, was never
        started, has already terminated, or is the calling thread. The
        previous implementation spun on `stopped` without yielding and
        never returned in those situations.
        """
        self.running = False
        task_thread = self.thread
        if task_thread is None or task_thread is threading.current_thread():
            return
        while not self.stopped and task_thread.is_alive():
            task_thread.join(timeout=self._CANCEL_POLL_INTERVAL_S)
|
||||
|
||||
|
||||
class OneInputStreamTask(InputStreamTask):
    """Stream task that executes a
    :class:`runtime.processor.OneInputProcessor`.

    Adds nothing to `InputStreamTask`; the constructor forwards its
    arguments unchanged.
    """

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
|
||||
|
||||
|
||||
class SourceStreamTask(StreamTask):
    """Stream task that executes a
    :class:`runtime.processor.SourceProcessor`.
    """

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)

    def init(self):
        """No-op."""

    def run(self):
        """Thread body: delegate to the processor's `run`."""
        source_processor = self.processor
        source_processor.run()

    def cancel_task(self):
        """No-op."""
|
||||
@@ -0,0 +1,104 @@
|
||||
import logging
|
||||
|
||||
import ray
|
||||
import ray.streaming._streaming as _streaming
|
||||
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
|
||||
import ray.streaming.runtime.processor as processor
|
||||
from ray._raylet import PythonFunctionDescriptor
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.runtime.graph import ExecutionGraph
|
||||
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ray.remote
class JobWorker(object):
    """A streaming job worker is used to execute user-defined function and
    interact with `JobMaster`

    All state is ``None`` after construction and is populated by `init`.
    """

    def __init__(self):
        self.worker_context = None
        self.task_id = None
        self.config = None
        self.execution_graph = None
        self.execution_task = None
        self.execution_node = None
        self.stream_processor = None
        self.task = None
        self.reader_client = None
        self.writer_client = None

    def init(self, worker_context_bytes):
        """Initialize the worker from a serialized `WorkerContext`.

        Parses the context, locates this worker's task and node in the
        execution graph, builds the processor for the node's operator,
        creates the reader/writer clients, then creates and starts the
        stream task.

        Args:
            worker_context_bytes: serialized `WorkerContext` protobuf.

        Returns:
            True once the task has been started.
        """
        worker_context = remote_call_pb.WorkerContext()
        worker_context.ParseFromString(worker_context_bytes)
        self.worker_context = worker_context
        self.task_id = worker_context.task_id
        self.config = worker_context.conf
        execution_graph = ExecutionGraph(worker_context.graph)
        self.execution_graph = execution_graph
        self.execution_task = self.execution_graph. \
            get_execution_task_by_task_id(self.task_id)
        self.execution_node = self.execution_graph. \
            get_execution_node_by_task_id(self.task_id)
        operator = self.execution_node.stream_operator
        self.stream_processor = processor.build_processor(operator)
        logger.info(
            "Initializing JobWorker, task_id: {}, operator: {}.".format(
                self.task_id, self.stream_processor))

        # NOTE(review): this tests the truthiness of the configured channel
        # type rather than comparing it with Config.NATIVE_CHANNEL, so the
        # branch is taken for any non-empty value. Left unchanged; confirm
        # whether an equality check was intended.
        if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
            core_worker = ray.worker.global_worker.core_worker
            # Descriptors naming the methods of this class that the
            # reader/writer clients are given as async and sync handlers.
            reader_async_func = PythonFunctionDescriptor(
                __name__, self.on_reader_message.__name__,
                self.__class__.__name__)
            reader_sync_func = PythonFunctionDescriptor(
                __name__, self.on_reader_message_sync.__name__,
                self.__class__.__name__)
            self.reader_client = _streaming.ReaderClient(
                core_worker, reader_async_func, reader_sync_func)
            writer_async_func = PythonFunctionDescriptor(
                __name__, self.on_writer_message.__name__,
                self.__class__.__name__)
            writer_sync_func = PythonFunctionDescriptor(
                __name__, self.on_writer_message_sync.__name__,
                self.__class__.__name__)
            self.writer_client = _streaming.WriterClient(
                core_worker, writer_async_func, writer_sync_func)

        self.task = self.create_stream_task()
        self.task.start()
        logger.info("JobWorker init succeed")
        return True

    def create_stream_task(self):
        """Return the stream task matching the type of `stream_processor`.

        Raises:
            Exception: if the processor is neither a `SourceProcessor` nor
                a `OneInputProcessor`.
        """
        if isinstance(self.stream_processor, processor.SourceProcessor):
            return SourceStreamTask(self.task_id, self.stream_processor, self)
        elif isinstance(self.stream_processor, processor.OneInputProcessor):
            return OneInputStreamTask(self.task_id, self.stream_processor,
                                      self)
        else:
            # Format the type instead of concatenating it: str + type
            # raised TypeError and hid the intended error.
            raise Exception("Unsupported processor type: {}".format(
                type(self.stream_processor)))

    def on_reader_message(self, buffer: bytes):
        """used in direct call mode"""
        self.reader_client.on_reader_message(buffer)

    def on_reader_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.reader_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        result = self.reader_client.on_reader_message_sync(buffer)
        return result.to_pybytes()

    def on_writer_message(self, buffer: bytes):
        """used in direct call mode"""
        self.writer_client.on_writer_message(buffer)

    def on_writer_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.writer_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        result = self.writer_client.on_writer_message_sync(buffer)
        return result.to_pybytes()
|
||||
Reference in New Issue
Block a user