mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 22:20:31 +08:00
[Streaming] operator chain (#8910)
This commit is contained in:
@@ -94,6 +94,18 @@ class Stream(ABC):
|
||||
def get_language(self):
    """Placeholder hook: this implementation does nothing and returns None.

    NOTE(review): presumably concrete streams override this to report
    their language -- confirm against the subclasses.
    """
    return None
|
||||
|
||||
def forward(self):
    """Switch this stream to forward partitioning.

    Asks the gateway to invoke ``forward`` on the underlying stream handle
    so that output elements are handed to the next operator locally.

    Returns:
        This stream, so calls can be chained.
    """
    client = self._gateway_client()
    client.call_method(self._j_stream, "forward")
    return self
|
||||
|
||||
def disable_chain(self):
    """Opt this stream out of operator chaining.

    Asks the gateway to invoke ``disableChain`` on the underlying stream
    handle so that the stream will be run in a separate task.

    Returns:
        This stream, so calls can be chained.
    """
    client = self._gateway_client()
    client.call_method(self._j_stream, "disableChain")
    return self
|
||||
|
||||
def _gateway_client(self):
|
||||
return self.get_streaming_context()._gateway_client
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import enum
|
||||
import importlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import streaming
|
||||
from ray.streaming import function
|
||||
from ray.streaming import message
|
||||
from ray.streaming.collector import Collector
|
||||
from ray.streaming.runtime import gateway_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperatorType(enum.Enum):
|
||||
SOURCE = 0 # Sources are where your program reads its input from
|
||||
@@ -227,15 +231,93 @@ class UnionOperator(StreamOperator, OneInputOperator):
|
||||
self.collect(record)
|
||||
|
||||
|
||||
_function_to_operator = {
|
||||
function.SourceFunction: SourceOperator,
|
||||
function.MapFunction: MapOperator,
|
||||
function.FlatMapFunction: FlatMapOperator,
|
||||
function.FilterFunction: FilterOperator,
|
||||
function.KeyFunction: KeyByOperator,
|
||||
function.ReduceFunction: ReduceOperator,
|
||||
function.SinkFunction: SinkOperator,
|
||||
}
|
||||
class ChainedOperator(StreamOperator, ABC):
    """Runs a list of operators back-to-back as a single operator.

    Records emitted by operator ``i`` are passed directly to operator
    ``i + 1`` through a :class:`ForwardCollector`; only the last operator
    writes to the collectors handed to :meth:`open`.
    """

    class ForwardCollector(Collector):
        """Collector that feeds every record into the succeeding operator."""

        def __init__(self, succeeding_operator):
            self.succeeding_operator = succeeding_operator

        def collect(self, record):
            self.succeeding_operator.process_element(record)

    def __init__(self, operators, configs):
        """Create a chain.

        Args:
            operators: Operators in execution order; must be non-empty
                because the first operator's ``func`` is passed to the
                base class.
            configs: Per-operator configs, indexed like ``operators``.
        """
        super().__init__(operators[0].func)
        self.operators = operators
        self.configs = configs

    def open(self, collectors, runtime_context):
        """Open every operator in the chain.

        Each operator except the last gets a single
        :class:`ForwardCollector` targeting its successor; the last one
        gets ``collectors``. Every operator receives its own runtime
        context whose ``get_config`` returns that operator's config.
        """
        # Don't call super().open(): every operator is opened separately.
        succeeding_collectors = [
            ChainedOperator.ForwardCollector(operator)
            for operator in self.operators[1:]
        ]
        # zip() stops at the shorter list, so this loop covers every
        # operator except the last one.
        for i, (operator, collector) in enumerate(
                zip(self.operators, succeeding_collectors)):
            operator.open(
                [collector],
                self.__create_runtime_context(runtime_context, i))
        last_index = len(self.operators) - 1
        self.operators[-1].open(
            collectors,
            self.__create_runtime_context(runtime_context, last_index))

    def operator_type(self) -> OperatorType:
        """Return the operator type of the first operator in the chain."""
        return self.operators[0].operator_type()

    def __create_runtime_context(self, runtime_context, index):
        """Build the runtime context for the operator at ``index``.

        A shallow copy is returned rather than the shared object. Patching
        ``get_config`` on ``runtime_context`` itself would overwrite the
        attribute on every call, so all operators would end up reading the
        config of whichever operator was opened last.

        NOTE(review): assumes ``runtime_context`` supports ``copy.copy``
        and arbitrary attribute assignment -- confirm for the concrete
        runtime context class.
        """
        import copy

        def get_config():
            # Looked up lazily so later changes to self.configs are seen.
            return self.configs[index]

        operator_context = copy.copy(runtime_context)
        operator_context.get_config = get_config
        return operator_context

    @staticmethod
    def new_chained_operator(operators, configs):
        """Create the ChainedOperator subclass matching the first operator.

        Args:
            operators: Operators in execution order.
            configs: Per-operator configs, indexed like ``operators``.

        Raises:
            Exception: If the first operator's type is not SOURCE,
                ONE_INPUT or TWO_INPUT.
        """
        operator_type = operators[0].operator_type()
        logger.info(
            "Building ChainedOperator from operators {} and configs {}."
            .format(operators, configs))
        if operator_type == OperatorType.SOURCE:
            return ChainedSourceOperator(operators, configs)
        elif operator_type == OperatorType.ONE_INPUT:
            return ChainedOneInputOperator(operators, configs)
        elif operator_type == OperatorType.TWO_INPUT:
            return ChainedTwoInputOperator(operators, configs)
        else:
            raise Exception("Current operator type is not supported")
|
||||
|
||||
|
||||
class ChainedSourceOperator(ChainedOperator):
    """Chained operator whose run loop is delegated to its first operator."""

    def __init__(self, operators, configs):
        super().__init__(operators, configs)

    def run(self):
        """Run the first operator of the chain."""
        head = self.operators[0]
        head.run()
|
||||
|
||||
|
||||
class ChainedOneInputOperator(ChainedOperator):
    """Chained operator that passes single records to its first operator."""

    def __init__(self, operators, configs):
        super().__init__(operators, configs)

    def process_element(self, record):
        """Hand ``record`` to the first operator of the chain."""
        head = self.operators[0]
        head.process_element(record)
|
||||
|
||||
|
||||
class ChainedTwoInputOperator(ChainedOperator):
    """Chained operator that passes record pairs to its first operator."""

    def __init__(self, operators, configs):
        super().__init__(operators, configs)

    def process_element(self, record1, record2):
        """Hand both records to the first operator of the chain."""
        head = self.operators[0]
        head.process_element(record1, record2)
|
||||
|
||||
|
||||
def load_chained_operator(chained_operator_bytes: bytes):
    """Load a chained operator from serialized operators and configs.

    The payload is deserialized into a pair ``(serialized_operators,
    configs)``; each serialized operator is loaded with ``load_operator``
    and the results are combined by
    ``ChainedOperator.new_chained_operator``.
    """
    payload = gateway_client.deserialize(chained_operator_bytes)
    serialized_operators, configs = payload
    operators = list(map(load_operator, serialized_operators))
    return ChainedOperator.new_chained_operator(operators, configs)
|
||||
|
||||
|
||||
def load_operator(descriptor_operator_bytes: bytes):
|
||||
@@ -267,6 +349,17 @@ def load_operator(descriptor_operator_bytes: bytes):
|
||||
return cls()
|
||||
|
||||
|
||||
# Lookup table: function class -> operator class paired with it.
_function_to_operator = {
    function.SourceFunction: SourceOperator,
    function.MapFunction: MapOperator,
    function.FlatMapFunction: FlatMapOperator,
    function.FilterFunction: FilterOperator,
    function.KeyFunction: KeyByOperator,
    function.ReduceFunction: ReduceOperator,
    function.SinkFunction: SinkOperator,
}
|
||||
|
||||
|
||||
def create_operator_with_func(func: function.Function):
|
||||
"""Create an operator according to a :class:`function.Function`
|
||||
|
||||
|
||||
@@ -60,6 +60,17 @@ class RoundRobinPartition(Partition):
|
||||
return self.__partitions
|
||||
|
||||
|
||||
class ForwardPartition(Partition):
    """Default partition for an operator that can be chained with its
    succeeding operators.

    Every call to :meth:`partition` yields the same single-entry list
    ``[0]``.
    """

    def __init__(self):
        # Built once; the very same list object is returned on every call.
        self.__partitions = [0]

    def partition(self, key_record, num_partition: int):
        """Return the fixed partition list; both arguments are ignored."""
        return self.__partitions
|
||||
|
||||
|
||||
class SimplePartition(Partition):
|
||||
"""Wrap a python function as subclass of :class:`Partition`"""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import enum
|
||||
import logging
|
||||
|
||||
import ray
|
||||
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
|
||||
@@ -6,6 +7,8 @@ import ray.streaming.operator as operator
|
||||
import ray.streaming.partition as partition
|
||||
from ray.streaming.generated.streaming_pb2 import Language
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeType(enum.Enum):
|
||||
"""
|
||||
@@ -43,7 +46,13 @@ class ExecutionVertex:
|
||||
self.parallelism = vertex_pb.parallelism
|
||||
if vertex_pb.language == Language.PYTHON:
|
||||
operator_bytes = vertex_pb.operator # python operator descriptor
|
||||
self.stream_operator = operator.load_operator(operator_bytes)
|
||||
if vertex_pb.chained:
|
||||
logger.info("Load chained operator")
|
||||
self.stream_operator = operator.load_chained_operator(
|
||||
operator_bytes)
|
||||
else:
|
||||
logger.info("Load operator")
|
||||
self.stream_operator = operator.load_operator(operator_bytes)
|
||||
self.worker_actor = ray.actor.ActorHandle. \
|
||||
_deserialization_helper(vertex_pb.worker_actor)
|
||||
self.container_id = vertex_pb.container_id
|
||||
|
||||
Reference in New Issue
Block a user