mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 06:14:06 +08:00
[Streaming] Streaming Cross-Lang API (#7464)
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
import logging
|
||||
import pickle
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import Language
|
||||
from ray.actor import ActorHandle
|
||||
from ray.streaming import function
|
||||
from ray.streaming import message
|
||||
from ray.streaming import partition
|
||||
from ray.streaming.runtime import serialization
|
||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,19 +34,46 @@ class CollectionCollector(Collector):
|
||||
|
||||
|
||||
class OutputCollector(Collector):
|
||||
def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
|
||||
def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
|
||||
target_actors: typing.List[ActorHandle],
|
||||
partition_func: partition.Partition):
|
||||
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
|
||||
self._writer = writer
|
||||
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
|
||||
self._target_languages = []
|
||||
for actor in target_actors:
|
||||
if actor._ray_actor_language == Language.PYTHON:
|
||||
self._target_languages.append(function.Language.PYTHON)
|
||||
elif actor._ray_actor_language == Language.JAVA:
|
||||
self._target_languages.append(function.Language.JAVA)
|
||||
else:
|
||||
raise Exception("Unsupported language {}"
|
||||
.format(actor._ray_actor_language))
|
||||
self._partition_func = partition_func
|
||||
self.python_serializer = serialization.PythonSerializer()
|
||||
self.cross_lang_serializer = serialization.CrossLangSerializer()
|
||||
logger.info(
|
||||
"Create OutputCollector, channel_ids {}, partition_func {}".format(
|
||||
channel_ids, partition_func))
|
||||
|
||||
def collect(self, record):
|
||||
partitions = self._partition_func.partition(record,
|
||||
len(self._channel_ids))
|
||||
serialized_message = pickle.dumps(record)
|
||||
partitions = self._partition_func \
|
||||
.partition(record, len(self._channel_ids))
|
||||
python_buffer = None
|
||||
cross_lang_buffer = None
|
||||
for partition_index in partitions:
|
||||
self._writer.write(self._channel_ids[partition_index],
|
||||
serialized_message)
|
||||
if self._target_languages[partition_index] == \
|
||||
function.Language.PYTHON:
|
||||
# avoid repeated serialization
|
||||
if python_buffer is None:
|
||||
python_buffer = self.python_serializer.serialize(record)
|
||||
self._writer.write(
|
||||
self._channel_ids[partition_index],
|
||||
serialization._PYTHON_TYPE_ID + python_buffer)
|
||||
else:
|
||||
# avoid repeated serialization
|
||||
if cross_lang_buffer is None:
|
||||
cross_lang_buffer = self.cross_lang_serializer.serialize(
|
||||
record)
|
||||
self._writer.write(
|
||||
self._channel_ids[partition_index],
|
||||
serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray.streaming import function
|
||||
from ray.streaming import partition
|
||||
@@ -19,7 +19,6 @@ class Stream(ABC):
|
||||
self.streaming_context = input_stream.streaming_context
|
||||
else:
|
||||
self.streaming_context = streaming_context
|
||||
self.parallelism = 1
|
||||
|
||||
def get_streaming_context(self):
|
||||
return self.streaming_context
|
||||
@@ -29,7 +28,8 @@ class Stream(ABC):
|
||||
Returns:
|
||||
the parallelism of this transformation
|
||||
"""
|
||||
return self.parallelism
|
||||
return self._gateway_client(). \
|
||||
call_method(self._j_stream, "getParallelism")
|
||||
|
||||
def set_parallelism(self, parallelism: int):
|
||||
"""Sets the parallelism of this transformation
|
||||
@@ -40,7 +40,6 @@ class Stream(ABC):
|
||||
Returns:
|
||||
self
|
||||
"""
|
||||
self.parallelism = parallelism
|
||||
self._gateway_client(). \
|
||||
call_method(self._j_stream, "setParallelism", parallelism)
|
||||
return self
|
||||
@@ -60,6 +59,10 @@ class Stream(ABC):
|
||||
return self._gateway_client(). \
|
||||
call_method(self._j_stream, "getId")
|
||||
|
||||
@abstractmethod
|
||||
def get_language(self):
|
||||
pass
|
||||
|
||||
def _gateway_client(self):
|
||||
return self.get_streaming_context()._gateway_client
|
||||
|
||||
@@ -75,6 +78,9 @@ class DataStream(Stream):
|
||||
super().__init__(
|
||||
input_stream, j_stream, streaming_context=streaming_context)
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.PYTHON
|
||||
|
||||
def map(self, func):
|
||||
"""
|
||||
Applies a Map transformation on a :class:`DataStream`.
|
||||
@@ -158,6 +164,7 @@ class DataStream(Stream):
|
||||
Returns:
|
||||
A KeyDataStream
|
||||
"""
|
||||
self._check_partition_call()
|
||||
if not isinstance(func, function.KeyFunction):
|
||||
func = function.SimpleKeyFunction(func)
|
||||
j_func = self._gateway_client().create_py_func(
|
||||
@@ -175,6 +182,7 @@ class DataStream(Stream):
|
||||
Returns:
|
||||
The DataStream with broadcast partitioning set.
|
||||
"""
|
||||
self._check_partition_call()
|
||||
self._gateway_client().call_method(self._j_stream, "broadcast")
|
||||
return self
|
||||
|
||||
@@ -191,6 +199,7 @@ class DataStream(Stream):
|
||||
Returns:
|
||||
The DataStream with specified partitioning set.
|
||||
"""
|
||||
self._check_partition_call()
|
||||
if not isinstance(partition_func, partition.Partition):
|
||||
partition_func = partition.SimplePartition(partition_func)
|
||||
j_partition = self._gateway_client().create_py_func(
|
||||
@@ -199,6 +208,16 @@ class DataStream(Stream):
|
||||
call_method(self._j_stream, "partitionBy", j_partition)
|
||||
return self
|
||||
|
||||
def _check_partition_call(self):
|
||||
"""
|
||||
If parent stream is a java stream, we can't call partition related
|
||||
methods in the python stream
|
||||
"""
|
||||
if self.input_stream is not None and \
|
||||
self.input_stream.get_language() == function.Language.JAVA:
|
||||
raise Exception("Partition related methods can't be called on a "
|
||||
"python stream if parent stream is a java stream.")
|
||||
|
||||
def sink(self, func):
|
||||
"""
|
||||
Create a StreamSink with the given sink.
|
||||
@@ -217,8 +236,97 @@ class DataStream(Stream):
|
||||
call_method(self._j_stream, "sink", j_func)
|
||||
return StreamSink(self, j_stream, func)
|
||||
|
||||
def as_java_stream(self):
|
||||
"""
|
||||
Convert this stream as a java JavaDataStream.
|
||||
The converted stream and this stream are the same logical stream,
|
||||
which has same stream id. Changes in converted stream will be reflected
|
||||
in this stream and vice versa.
|
||||
"""
|
||||
j_stream = self._gateway_client(). \
|
||||
call_method(self._j_stream, "asJavaStream")
|
||||
return JavaDataStream(self, j_stream)
|
||||
|
||||
class KeyDataStream(Stream):
|
||||
|
||||
class JavaDataStream(Stream):
|
||||
"""
|
||||
Represents a stream of data which applies a transformation executed by
|
||||
java. It's also a wrapper of java
|
||||
`org.ray.streaming.api.stream.DataStream`
|
||||
"""
|
||||
|
||||
def __init__(self, input_stream, j_stream, streaming_context=None):
|
||||
super().__init__(
|
||||
input_stream, j_stream, streaming_context=streaming_context)
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.JAVA
|
||||
|
||||
def map(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.map"""
|
||||
return JavaDataStream(self, self._unary_call("map", java_func_class))
|
||||
|
||||
def flat_map(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.flatMap"""
|
||||
return JavaDataStream(self, self._unary_call("flatMap",
|
||||
java_func_class))
|
||||
|
||||
def filter(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.filter"""
|
||||
return JavaDataStream(self, self._unary_call("filter",
|
||||
java_func_class))
|
||||
|
||||
def key_by(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.keyBy"""
|
||||
self._check_partition_call()
|
||||
return JavaKeyDataStream(self,
|
||||
self._unary_call("keyBy", java_func_class))
|
||||
|
||||
def broadcast(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.broadcast"""
|
||||
self._check_partition_call()
|
||||
return JavaDataStream(self,
|
||||
self._unary_call("broadcast", java_func_class))
|
||||
|
||||
def partition_by(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.partitionBy"""
|
||||
self._check_partition_call()
|
||||
return JavaDataStream(self,
|
||||
self._unary_call("partitionBy", java_func_class))
|
||||
|
||||
def sink(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.sink"""
|
||||
return JavaStreamSink(self, self._unary_call("sink", java_func_class))
|
||||
|
||||
def as_python_stream(self):
|
||||
"""
|
||||
Convert this stream as a python DataStream.
|
||||
The converted stream and this stream are the same logical stream,
|
||||
which has same stream id. Changes in converted stream will be reflected
|
||||
in this stream and vice versa.
|
||||
"""
|
||||
j_stream = self._gateway_client(). \
|
||||
call_method(self._j_stream, "asPythonStream")
|
||||
return DataStream(self, j_stream)
|
||||
|
||||
def _check_partition_call(self):
|
||||
"""
|
||||
If parent stream is a python stream, we can't call partition related
|
||||
methods in the java stream
|
||||
"""
|
||||
if self.input_stream is not None and \
|
||||
self.input_stream.get_language() == function.Language.PYTHON:
|
||||
raise Exception("Partition related methods can't be called on a"
|
||||
"java stream if parent stream is a python stream.")
|
||||
|
||||
def _unary_call(self, func_name, java_func_class):
|
||||
j_func = self._gateway_client().new_instance(java_func_class)
|
||||
j_stream = self._gateway_client(). \
|
||||
call_method(self._j_stream, func_name, j_func)
|
||||
return j_stream
|
||||
|
||||
|
||||
class KeyDataStream(DataStream):
|
||||
"""Represents a DataStream returned by a key-by operation.
|
||||
Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
|
||||
"""
|
||||
@@ -251,6 +359,43 @@ class KeyDataStream(Stream):
|
||||
call_method(self._j_stream, "reduce", j_func)
|
||||
return DataStream(self, j_stream)
|
||||
|
||||
def as_java_stream(self):
|
||||
"""
|
||||
Convert this stream as a java KeyDataStream.
|
||||
The converted stream and this stream are the same logical stream,
|
||||
which has same stream id. Changes in converted stream will be reflected
|
||||
in this stream and vice versa.
|
||||
"""
|
||||
j_stream = self._gateway_client(). \
|
||||
call_method(self._j_stream, "asJavaStream")
|
||||
return JavaKeyDataStream(self, j_stream)
|
||||
|
||||
|
||||
class JavaKeyDataStream(JavaDataStream):
|
||||
"""
|
||||
Represents a DataStream returned by a key-by operation in java.
|
||||
Wrapper of org.ray.streaming.api.stream.KeyDataStream
|
||||
"""
|
||||
|
||||
def __init__(self, input_stream, j_stream):
|
||||
super().__init__(input_stream, j_stream)
|
||||
|
||||
def reduce(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.KeyDataStream.reduce"""
|
||||
return JavaDataStream(self,
|
||||
super()._unary_call("reduce", java_func_class))
|
||||
|
||||
def as_python_stream(self):
|
||||
"""
|
||||
Convert this stream as a python KeyDataStream.
|
||||
The converted stream and this stream are the same logical stream,
|
||||
which has same stream id. Changes in converted stream will be reflected
|
||||
in this stream and vice versa.
|
||||
"""
|
||||
j_stream = self._gateway_client(). \
|
||||
call_method(self._j_stream, "asPythonStream")
|
||||
return KeyDataStream(self, j_stream)
|
||||
|
||||
|
||||
class StreamSource(DataStream):
|
||||
"""Represents a source of the DataStream.
|
||||
@@ -261,9 +406,12 @@ class StreamSource(DataStream):
|
||||
super().__init__(None, j_stream, streaming_context=streaming_context)
|
||||
self.source_func = source_func
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.PYTHON
|
||||
|
||||
@staticmethod
|
||||
def build_source(streaming_context, func):
|
||||
"""Build a StreamSource source from a collection.
|
||||
"""Build a StreamSource source from a source function.
|
||||
Args:
|
||||
streaming_context: Stream context
|
||||
func: A instance of `SourceFunction`
|
||||
@@ -275,6 +423,34 @@ class StreamSource(DataStream):
|
||||
return StreamSource(j_stream, streaming_context, func)
|
||||
|
||||
|
||||
class JavaStreamSource(JavaDataStream):
|
||||
"""Represents a source of the java DataStream.
|
||||
Wrapper of java org.ray.streaming.api.stream.DataStreamSource
|
||||
"""
|
||||
|
||||
def __init__(self, j_stream, streaming_context):
|
||||
super().__init__(None, j_stream, streaming_context=streaming_context)
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.JAVA
|
||||
|
||||
@staticmethod
|
||||
def build_source(streaming_context, java_source_func_class):
|
||||
"""Build a java StreamSource source from a java source function.
|
||||
Args:
|
||||
streaming_context: Stream context
|
||||
java_source_func_class: qualified class name of java SourceFunction
|
||||
Returns:
|
||||
A java StreamSource
|
||||
"""
|
||||
j_func = streaming_context._gateway_client() \
|
||||
.new_instance(java_source_func_class)
|
||||
j_stream = streaming_context._gateway_client() \
|
||||
.call_function("org.ray.streaming.api.stream.DataStreamSource"
|
||||
"fromSource", streaming_context._j_ctx, j_func)
|
||||
return JavaStreamSource(j_stream, streaming_context)
|
||||
|
||||
|
||||
class StreamSink(Stream):
|
||||
"""Represents a sink of the DataStream.
|
||||
Wrapper of java io.ray.streaming.python.stream.PythonStreamSink
|
||||
@@ -282,3 +458,18 @@ class StreamSink(Stream):
|
||||
|
||||
def __init__(self, input_stream, j_stream, func):
|
||||
super().__init__(input_stream, j_stream)
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.PYTHON
|
||||
|
||||
|
||||
class JavaStreamSink(Stream):
|
||||
"""Represents a sink of the java DataStream.
|
||||
Wrapper of java org.ray.streaming.api.stream.StreamSink
|
||||
"""
|
||||
|
||||
def __init__(self, input_stream, j_stream):
|
||||
super().__init__(input_stream, j_stream)
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.JAVA
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import enum
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import cloudpickle
|
||||
from ray.streaming.runtime import gateway_client
|
||||
|
||||
|
||||
class Language(enum.Enum):
|
||||
JAVA = 0
|
||||
PYTHON = 1
|
||||
|
||||
|
||||
class Function(ABC):
|
||||
"""The base interface for all user-defined functions."""
|
||||
|
||||
@@ -60,6 +66,7 @@ class MapFunction(Function):
|
||||
for each input element.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def map(self, value):
|
||||
pass
|
||||
|
||||
@@ -70,6 +77,7 @@ class FlatMapFunction(Function):
|
||||
transform them into zero, one, or more elements.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def flat_map(self, value, collector):
|
||||
"""Takes an element from the input data set and transforms it into zero,
|
||||
one, or more elements.
|
||||
@@ -87,6 +95,7 @@ class FilterFunction(Function):
|
||||
The predicate decides whether to keep the element, or to discard it.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def filter(self, value):
|
||||
"""The filter function that evaluates the predicate.
|
||||
|
||||
@@ -106,6 +115,7 @@ class KeyFunction(Function):
|
||||
deterministic key for that object.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def key_by(self, value):
|
||||
"""User-defined function that deterministically extracts the key from
|
||||
an object.
|
||||
@@ -126,6 +136,7 @@ class ReduceFunction(Function):
|
||||
them into one.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reduce(self, old_value, new_value):
|
||||
"""
|
||||
The core method of ReduceFunction, combining two values into one value
|
||||
@@ -145,6 +156,7 @@ class ReduceFunction(Function):
|
||||
class SinkFunction(Function):
|
||||
"""Interface for implementing user defined sink functionality."""
|
||||
|
||||
@abstractmethod
|
||||
def sink(self, value):
|
||||
"""Writes the given value to the sink. This function is called for
|
||||
every record."""
|
||||
@@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
|
||||
Returns:
|
||||
a streaming function
|
||||
"""
|
||||
function_bytes, module_name, class_name, function_name, function_interface\
|
||||
assert len(descriptor_func_bytes) > 0
|
||||
function_bytes, module_name, function_name, function_interface\
|
||||
= gateway_client.deserialize(descriptor_func_bytes)
|
||||
if function_bytes:
|
||||
return deserialize(function_bytes)
|
||||
@@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
|
||||
assert function_interface
|
||||
function_interface = getattr(sys.modules[__name__], function_interface)
|
||||
mod = importlib.import_module(module_name)
|
||||
if class_name:
|
||||
assert function_name is None
|
||||
cls = getattr(mod, class_name)
|
||||
assert issubclass(cls, function_interface)
|
||||
return cls()
|
||||
else:
|
||||
assert function_name
|
||||
func = getattr(mod, function_name)
|
||||
assert function_name
|
||||
func = getattr(mod, function_name)
|
||||
# If func is a python function, user function is a simple python
|
||||
# function, which will be wrapped as a SimpleXXXFunction.
|
||||
# If func is a python class, user function is a sub class
|
||||
# of XXXFunction.
|
||||
if inspect.isfunction(func):
|
||||
simple_func_class = _get_simple_function_class(function_interface)
|
||||
return simple_func_class(func)
|
||||
else:
|
||||
assert issubclass(func, function_interface)
|
||||
return func()
|
||||
|
||||
|
||||
def _get_simple_function_class(function_interface):
|
||||
|
||||
@@ -8,6 +8,14 @@ class Record:
|
||||
def __repr__(self):
|
||||
return "Record(%s)".format(self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(self) is type(other):
|
||||
return (self.stream, self.value) == (other.stream, other.value)
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.stream, self.value))
|
||||
|
||||
|
||||
class KeyRecord(Record):
|
||||
"""Data record in a keyed data stream"""
|
||||
@@ -15,3 +23,12 @@ class KeyRecord(Record):
|
||||
def __init__(self, key, value):
|
||||
super().__init__(value)
|
||||
self.key = key
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(self) is type(other):
|
||||
return (self.stream, self.key, self.value) ==\
|
||||
(other.stream, other.key, other.value)
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.stream, self.key, self.value))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import cloudpickle
|
||||
@@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
|
||||
Returns:
|
||||
partition function
|
||||
"""
|
||||
partition_bytes, module_name, class_name, function_name =\
|
||||
assert len(descriptor_partition_bytes) > 0
|
||||
partition_bytes, module_name, function_name =\
|
||||
gateway_client.deserialize(descriptor_partition_bytes)
|
||||
if partition_bytes:
|
||||
return deserialize(partition_bytes)
|
||||
else:
|
||||
assert module_name
|
||||
mod = importlib.import_module(module_name)
|
||||
# If class_name is not None, user partition is a sub class
|
||||
# of Partition.
|
||||
# If function_name is not None, user partition is a simple python
|
||||
assert function_name
|
||||
func = getattr(mod, function_name)
|
||||
# If func is a python function, user partition is a simple python
|
||||
# function, which will be wrapped as a SimplePartition.
|
||||
if class_name:
|
||||
assert function_name is None
|
||||
cls = getattr(mod, class_name)
|
||||
return cls()
|
||||
else:
|
||||
assert function_name
|
||||
func = getattr(mod, function_name)
|
||||
# If func is a python class, user partition is a sub class
|
||||
# of Partition.
|
||||
if inspect.isfunction(func):
|
||||
return SimplePartition(func)
|
||||
else:
|
||||
assert issubclass(func, Partition)
|
||||
return func()
|
||||
|
||||
@@ -55,6 +55,11 @@ class GatewayClient:
|
||||
call = self._python_gateway_actor.callMethod.remote(java_params)
|
||||
return deserialize(ray.get(call))
|
||||
|
||||
def new_instance(self, java_class_name):
|
||||
call = self._python_gateway_actor.newInstance.remote(
|
||||
serialize(java_class_name))
|
||||
return deserialize(ray.get(call))
|
||||
|
||||
|
||||
def serialize(obj) -> bytes:
|
||||
"""Serialize a python object which can be deserialized by `PythonGateway`
|
||||
|
||||
@@ -53,7 +53,9 @@ class ExecutionEdge:
|
||||
self.src_node_id = edge_pb.src_node_id
|
||||
self.target_node_id = edge_pb.target_node_id
|
||||
partition_bytes = edge_pb.partition
|
||||
if language == Language.PYTHON:
|
||||
# Sink node doesn't have partition function,
|
||||
# so we only deserialize partition_bytes when it's not None or empty
|
||||
if language == Language.PYTHON and partition_bytes:
|
||||
self.partition = partition.load_partition(partition_bytes)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import pickle
|
||||
import msgpack
|
||||
from ray.streaming import message
|
||||
|
||||
_RECORD_TYPE_ID = 0
|
||||
_KEY_RECORD_TYPE_ID = 1
|
||||
_CROSS_LANG_TYPE_ID = b"0"
|
||||
_JAVA_TYPE_ID = b"1"
|
||||
_PYTHON_TYPE_ID = b"2"
|
||||
|
||||
|
||||
class Serializer(ABC):
|
||||
@abstractmethod
|
||||
def serialize(self, obj):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, serialized_bytes):
|
||||
pass
|
||||
|
||||
|
||||
class PythonSerializer(Serializer):
|
||||
def serialize(self, obj):
|
||||
return pickle.dumps(obj)
|
||||
|
||||
def deserialize(self, serialized_bytes):
|
||||
return pickle.loads(serialized_bytes)
|
||||
|
||||
|
||||
class CrossLangSerializer(Serializer):
|
||||
"""Serialize stream element between java/python"""
|
||||
|
||||
def serialize(self, obj):
|
||||
if type(obj) is message.Record:
|
||||
fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
|
||||
elif type(obj) is message.KeyRecord:
|
||||
fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
|
||||
else:
|
||||
raise Exception("Unsupported value {}".format(obj))
|
||||
return msgpack.packb(fields, use_bin_type=True)
|
||||
|
||||
def deserialize(self, data):
|
||||
fields = msgpack.unpackb(data, raw=False)
|
||||
if fields[0] == _RECORD_TYPE_ID:
|
||||
stream, value = fields[1:]
|
||||
record = message.Record(value)
|
||||
record.stream = stream
|
||||
return record
|
||||
elif fields[0] == _KEY_RECORD_TYPE_ID:
|
||||
stream, key, value = fields[1:]
|
||||
key_record = message.KeyRecord(key, value)
|
||||
key_record.stream = stream
|
||||
return key_record
|
||||
else:
|
||||
raise Exception("Unsupported type id {}, type {}".format(
|
||||
fields[0], type(fields[0])))
|
||||
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray.streaming.collector import OutputCollector
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.context import RuntimeContextImpl
|
||||
from ray.streaming.runtime import serialization
|
||||
from ray.streaming.runtime.serialization import \
|
||||
PythonSerializer, CrossLangSerializer
|
||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,36 +40,40 @@ class StreamTask(ABC):
|
||||
# writers
|
||||
collectors = []
|
||||
for edge in execution_node.output_edges:
|
||||
output_actor_ids = {}
|
||||
output_actors_map = {}
|
||||
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
|
||||
edge.target_node_id)
|
||||
for target_task_id, target_actor in task_id2_worker.items():
|
||||
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
|
||||
execution_graph.build_time())
|
||||
output_actor_ids[channel_name] = target_actor
|
||||
if len(output_actor_ids) > 0:
|
||||
channel_ids = list(output_actor_ids.keys())
|
||||
to_actor_ids = list(output_actor_ids.values())
|
||||
writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
|
||||
logger.info("Create DataWriter succeed.")
|
||||
output_actors_map[channel_name] = target_actor
|
||||
if len(output_actors_map) > 0:
|
||||
channel_ids = list(output_actors_map.keys())
|
||||
target_actors = list(output_actors_map.values())
|
||||
logger.info(
|
||||
"Create DataWriter channel_ids {}, target_actors {}."
|
||||
.format(channel_ids, target_actors))
|
||||
writer = DataWriter(channel_ids, target_actors, channel_conf)
|
||||
self.writers[edge] = writer
|
||||
collectors.append(
|
||||
OutputCollector(channel_ids, writer, edge.partition))
|
||||
OutputCollector(writer, channel_ids, target_actors,
|
||||
edge.partition))
|
||||
|
||||
# readers
|
||||
input_actor_ids = {}
|
||||
input_actor_map = {}
|
||||
for edge in execution_node.input_edges:
|
||||
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
|
||||
edge.src_node_id)
|
||||
for src_task_id, src_actor in task_id2_worker.items():
|
||||
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
|
||||
execution_graph.build_time())
|
||||
input_actor_ids[channel_name] = src_actor
|
||||
if len(input_actor_ids) > 0:
|
||||
channel_ids = list(input_actor_ids.keys())
|
||||
from_actor_ids = list(input_actor_ids.values())
|
||||
logger.info("Create DataReader, channels {}.".format(channel_ids))
|
||||
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
|
||||
input_actor_map[channel_name] = src_actor
|
||||
if len(input_actor_map) > 0:
|
||||
channel_ids = list(input_actor_map.keys())
|
||||
from_actors = list(input_actor_map.values())
|
||||
logger.info("Create DataReader, channels {}, input_actors {}."
|
||||
.format(channel_ids, from_actors))
|
||||
self.reader = DataReader(channel_ids, from_actors, channel_conf)
|
||||
|
||||
def exit_handler():
|
||||
# Make DataReader stop read data when MockQueue destructor
|
||||
@@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
|
||||
self.read_timeout_millis = \
|
||||
int(worker.config.get(Config.READ_TIMEOUT_MS,
|
||||
Config.DEFAULT_READ_TIMEOUT_MS))
|
||||
self.python_serializer = PythonSerializer()
|
||||
self.cross_lang_serializer = CrossLangSerializer()
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
@@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
|
||||
item = self.reader.read(self.read_timeout_millis)
|
||||
if item is not None:
|
||||
msg_data = item.body()
|
||||
msg = pickle.loads(msg_data)
|
||||
type_id = msg_data[:1]
|
||||
if (type_id == serialization._PYTHON_TYPE_ID):
|
||||
msg = self.python_serializer.deserialize(msg_data[1:])
|
||||
else:
|
||||
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
|
||||
self.processor.process(msg)
|
||||
self.stopped = True
|
||||
|
||||
|
||||
@@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
|
||||
wrap initial parameters needed by a streaming queue
|
||||
"""
|
||||
_java_reader_async_function_descriptor = JavaFunctionDescriptor(
|
||||
"io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V")
|
||||
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
|
||||
"([B)V")
|
||||
_java_reader_sync_function_descriptor = JavaFunctionDescriptor(
|
||||
"io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B")
|
||||
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
|
||||
"([B)[B")
|
||||
_java_writer_async_function_descriptor = JavaFunctionDescriptor(
|
||||
"io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V")
|
||||
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
|
||||
"([B)V")
|
||||
_java_writer_sync_function_descriptor = JavaFunctionDescriptor(
|
||||
"io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B")
|
||||
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
|
||||
"([B)[B")
|
||||
_python_reader_async_function_descriptor = PythonFunctionDescriptor(
|
||||
"ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
|
||||
_python_reader_sync_function_descriptor = PythonFunctionDescriptor(
|
||||
|
||||
@@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# special flag to indicate this actor not ready
|
||||
_NOT_READY_FLAG_ = b" " * 4
|
||||
|
||||
|
||||
@ray.remote
|
||||
class JobWorker(object):
|
||||
@@ -66,23 +69,31 @@ class JobWorker(object):
|
||||
type(self.stream_processor))
|
||||
|
||||
def on_reader_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by upstream queue writer to send data message to downstream
|
||||
queue reader.
|
||||
"""
|
||||
self.reader_client.on_reader_message(buffer)
|
||||
|
||||
def on_reader_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by upstream queue writer to send control message to downstream
|
||||
downstream queue reader.
|
||||
"""
|
||||
if self.reader_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
return _NOT_READY_FLAG_
|
||||
result = self.reader_client.on_reader_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
def on_writer_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by downstream queue reader to send notify message to
|
||||
upstream queue writer.
|
||||
"""
|
||||
self.writer_client.on_writer_message(buffer)
|
||||
|
||||
def on_writer_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by downstream queue reader to send control message to
|
||||
upstream queue writer.
|
||||
"""
|
||||
if self.writer_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
return _NOT_READY_FLAG_
|
||||
result = self.writer_client.on_writer_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
@@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):
|
||||
|
||||
|
||||
def test_load_function():
|
||||
# function_bytes, module_name, class_name, function_name,
|
||||
# function_bytes, module_name, function_name/class_name,
|
||||
# function_interface
|
||||
descriptor_func_bytes = gateway_client.serialize(
|
||||
[None, __name__, MapFunc.__name__, None, "MapFunction"])
|
||||
[None, __name__, MapFunc.__name__, "MapFunction"])
|
||||
func = function.load_function(descriptor_func_bytes)
|
||||
assert type(func) is MapFunc
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import json
|
||||
import ray
|
||||
from ray.streaming import StreamingContext
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
|
||||
def map_func1(x):
|
||||
print("HybridStreamTest map_func1", x)
|
||||
return str(x)
|
||||
|
||||
|
||||
def filter_func1(x):
|
||||
print("HybridStreamTest filter_func1", x)
|
||||
return "b" not in x
|
||||
|
||||
|
||||
def sink_func1(x):
|
||||
print("HybridStreamTest sink_func1 value:", x)
|
||||
|
||||
|
||||
def test_hybrid_stream():
|
||||
subprocess.check_call(
|
||||
["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
|
||||
current_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
jar_path = os.path.join(
|
||||
current_dir,
|
||||
"../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
|
||||
jar_path = os.path.abspath(jar_path)
|
||||
print("jar_path", jar_path)
|
||||
java_worker_options = json.dumps(["-classpath", jar_path])
|
||||
print("java_worker_options", java_worker_options)
|
||||
assert not ray.is_initialized()
|
||||
ray.init(
|
||||
load_code_from_local=True,
|
||||
include_java=True,
|
||||
java_worker_options=java_worker_options,
|
||||
_internal_config=json.dumps({
|
||||
"num_workers_per_process_java": 1
|
||||
}))
|
||||
|
||||
sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
|
||||
if os.path.exists(sink_file):
|
||||
os.remove(sink_file)
|
||||
|
||||
def sink_func(x):
|
||||
print("HybridStreamTest", x)
|
||||
with open(sink_file, "a") as f:
|
||||
f.write(str(x))
|
||||
|
||||
ctx = StreamingContext.Builder().build()
|
||||
ctx.from_values("a", "b", "c") \
|
||||
.as_java_stream() \
|
||||
.map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
|
||||
.filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
|
||||
.as_python_stream() \
|
||||
.sink(sink_func)
|
||||
ctx.submit("HybridStreamTest")
|
||||
import time
|
||||
time.sleep(3)
|
||||
ray.shutdown()
|
||||
with open(sink_file, "r") as f:
|
||||
result = f.read()
|
||||
assert "a" in result
|
||||
assert "b" not in result
|
||||
assert "c" in result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hybrid_stream()
|
||||
@@ -0,0 +1,13 @@
|
||||
from ray.streaming.runtime.serialization import CrossLangSerializer
|
||||
from ray.streaming.message import Record, KeyRecord
|
||||
|
||||
|
||||
def test_serialize():
|
||||
serializer = CrossLangSerializer()
|
||||
record = Record("value")
|
||||
record.stream = "stream1"
|
||||
key_record = KeyRecord("key", "value")
|
||||
key_record.stream = "stream2"
|
||||
assert record == serializer.deserialize(serializer.serialize(record))
|
||||
assert key_record == serializer.\
|
||||
deserialize(serializer.serialize(key_record))
|
||||
@@ -0,0 +1,31 @@
|
||||
import ray
|
||||
from ray.streaming import StreamingContext
|
||||
|
||||
|
||||
def test_data_stream():
|
||||
ray.init(load_code_from_local=True, include_java=True)
|
||||
ctx = StreamingContext.Builder().build()
|
||||
stream = ctx.from_values(1, 2, 3)
|
||||
java_stream = stream.as_java_stream()
|
||||
python_stream = java_stream.as_python_stream()
|
||||
assert stream.get_id() == java_stream.get_id()
|
||||
assert stream.get_id() == python_stream.get_id()
|
||||
python_stream.set_parallelism(10)
|
||||
assert stream.get_parallelism() == java_stream.get_parallelism()
|
||||
assert stream.get_parallelism() == python_stream.get_parallelism()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
def test_key_data_stream():
|
||||
ray.init(load_code_from_local=True, include_java=True)
|
||||
ctx = StreamingContext.Builder().build()
|
||||
key_stream = ctx.from_values(
|
||||
"a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
|
||||
java_stream = key_stream.as_java_stream()
|
||||
python_stream = java_stream.as_python_stream()
|
||||
assert key_stream.get_id() == java_stream.get_id()
|
||||
assert key_stream.get_id() == python_stream.get_id()
|
||||
python_stream.set_parallelism(10)
|
||||
assert key_stream.get_parallelism() == java_stream.get_parallelism()
|
||||
assert key_stream.get_parallelism() == python_stream.get_parallelism()
|
||||
ray.shutdown()
|
||||
@@ -32,7 +32,9 @@ def test_simple_word_count():
|
||||
|
||||
def sink_func(x):
|
||||
with open(sink_file, "a") as f:
|
||||
f.write("{}:{},".format(x[0], x[1]))
|
||||
line = "{}:{},".format(x[0], x[1])
|
||||
print("sink_func", line)
|
||||
f.write(line)
|
||||
|
||||
ctx.from_values("a", "b", "c") \
|
||||
.set_parallelism(1) \
|
||||
|
||||
Reference in New Issue
Block a user