[Streaming] Streaming Cross-Lang API (#7464)

This commit is contained in:
chaokunyang
2020-04-29 13:42:08 +08:00
committed by GitHub
parent 101255f782
commit 91f630f709
72 changed files with 1612 additions and 408 deletions
+38 -8
View File
@@ -1,10 +1,13 @@
import logging
import pickle
import typing
from abc import ABC, abstractmethod
from ray import Language
from ray.actor import ActorHandle
from ray.streaming import function
from ray.streaming import message
from ray.streaming import partition
from ray.streaming.runtime import serialization
from ray.streaming.runtime.transfer import ChannelID, DataWriter
logger = logging.getLogger(__name__)
@@ -31,19 +34,46 @@ class CollectionCollector(Collector):
class OutputCollector(Collector):
def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
target_actors: typing.List[ActorHandle],
partition_func: partition.Partition):
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._writer = writer
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._target_languages = []
for actor in target_actors:
if actor._ray_actor_language == Language.PYTHON:
self._target_languages.append(function.Language.PYTHON)
elif actor._ray_actor_language == Language.JAVA:
self._target_languages.append(function.Language.JAVA)
else:
raise Exception("Unsupported language {}"
.format(actor._ray_actor_language))
self._partition_func = partition_func
self.python_serializer = serialization.PythonSerializer()
self.cross_lang_serializer = serialization.CrossLangSerializer()
logger.info(
"Create OutputCollector, channel_ids {}, partition_func {}".format(
channel_ids, partition_func))
def collect(self, record):
partitions = self._partition_func.partition(record,
len(self._channel_ids))
serialized_message = pickle.dumps(record)
partitions = self._partition_func \
.partition(record, len(self._channel_ids))
python_buffer = None
cross_lang_buffer = None
for partition_index in partitions:
self._writer.write(self._channel_ids[partition_index],
serialized_message)
if self._target_languages[partition_index] == \
function.Language.PYTHON:
# avoid repeated serialization
if python_buffer is None:
python_buffer = self.python_serializer.serialize(record)
self._writer.write(
self._channel_ids[partition_index],
serialization._PYTHON_TYPE_ID + python_buffer)
else:
# avoid repeated serialization
if cross_lang_buffer is None:
cross_lang_buffer = self.cross_lang_serializer.serialize(
record)
self._writer.write(
self._channel_ids[partition_index],
serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)
+197 -6
View File
@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod
from ray.streaming import function
from ray.streaming import partition
@@ -19,7 +19,6 @@ class Stream(ABC):
self.streaming_context = input_stream.streaming_context
else:
self.streaming_context = streaming_context
self.parallelism = 1
def get_streaming_context(self):
return self.streaming_context
@@ -29,7 +28,8 @@ class Stream(ABC):
Returns:
the parallelism of this transformation
"""
return self.parallelism
return self._gateway_client(). \
call_method(self._j_stream, "getParallelism")
def set_parallelism(self, parallelism: int):
"""Sets the parallelism of this transformation
@@ -40,7 +40,6 @@ class Stream(ABC):
Returns:
self
"""
self.parallelism = parallelism
self._gateway_client(). \
call_method(self._j_stream, "setParallelism", parallelism)
return self
@@ -60,6 +59,10 @@ class Stream(ABC):
return self._gateway_client(). \
call_method(self._j_stream, "getId")
@abstractmethod
def get_language(self):
    """Return the language this stream's transformation is executed in.

    Concrete streams in this module return a member of
    `function.Language`.
    """
def _gateway_client(self):
return self.get_streaming_context()._gateway_client
@@ -75,6 +78,9 @@ class DataStream(Stream):
super().__init__(
input_stream, j_stream, streaming_context=streaming_context)
def get_language(self):
    """Report that this stream is executed by python."""
    language = function.Language.PYTHON
    return language
def map(self, func):
"""
Applies a Map transformation on a :class:`DataStream`.
@@ -158,6 +164,7 @@ class DataStream(Stream):
Returns:
A KeyDataStream
"""
self._check_partition_call()
if not isinstance(func, function.KeyFunction):
func = function.SimpleKeyFunction(func)
j_func = self._gateway_client().create_py_func(
@@ -175,6 +182,7 @@ class DataStream(Stream):
Returns:
The DataStream with broadcast partitioning set.
"""
self._check_partition_call()
self._gateway_client().call_method(self._j_stream, "broadcast")
return self
@@ -191,6 +199,7 @@ class DataStream(Stream):
Returns:
The DataStream with specified partitioning set.
"""
self._check_partition_call()
if not isinstance(partition_func, partition.Partition):
partition_func = partition.SimplePartition(partition_func)
j_partition = self._gateway_client().create_py_func(
@@ -199,6 +208,16 @@ class DataStream(Stream):
call_method(self._j_stream, "partitionBy", j_partition)
return self
def _check_partition_call(self):
    """Guard used by the partition related methods of a python stream.

    Raises:
        Exception: if this stream has a parent stream and that parent
            reports `function.Language.JAVA` as its language.
    """
    parent = self.input_stream
    if parent is None:
        # No parent stream: nothing restricts partitioning.
        return
    if parent.get_language() == function.Language.JAVA:
        raise Exception("Partition related methods can't be called on a "
                        "python stream if parent stream is a java stream.")
def sink(self, func):
"""
Create a StreamSink with the given sink.
@@ -217,8 +236,97 @@ class DataStream(Stream):
call_method(self._j_stream, "sink", j_func)
return StreamSink(self, j_stream, func)
def as_java_stream(self):
    """View this stream as a `JavaDataStream`.

    The returned stream and this stream are the same logical stream and
    share one stream id, so a change made through either of them is
    visible through the other.
    """
    client = self._gateway_client()
    converted = client.call_method(self._j_stream, "asJavaStream")
    return JavaDataStream(self, converted)
class KeyDataStream(Stream):
class JavaDataStream(Stream):
    """
    Represents a stream of data which applies a transformation executed by
    java. It's also a wrapper of java
    `org.ray.streaming.api.stream.DataStream`

    Every transformation takes the qualified class name of a java function;
    the class is instantiated through the gateway client and handed to the
    java method of the wrapped stream.
    """

    def __init__(self, input_stream, j_stream, streaming_context=None):
        super().__init__(
            input_stream, j_stream, streaming_context=streaming_context)

    def get_language(self):
        """Report that this stream is executed by java."""
        return function.Language.JAVA

    def map(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.map"""
        return JavaDataStream(self, self._unary_call("map", java_func_class))

    def flat_map(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.flatMap"""
        return JavaDataStream(self, self._unary_call("flatMap",
                                                     java_func_class))

    def filter(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.filter"""
        return JavaDataStream(self, self._unary_call("filter",
                                                     java_func_class))

    def key_by(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.keyBy"""
        self._check_partition_call()
        return JavaKeyDataStream(self,
                                 self._unary_call("keyBy", java_func_class))

    def broadcast(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.broadcast"""
        # NOTE(review): the python DataStream.broadcast in this file invokes
        # the java "broadcast" method without arguments; confirm that the
        # java side accepts a function instance here.
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("broadcast", java_func_class))

    def partition_by(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.partitionBy"""
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("partitionBy", java_func_class))

    def sink(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.sink"""
        return JavaStreamSink(self, self._unary_call("sink", java_func_class))

    def as_python_stream(self):
        """
        Convert this stream as a python DataStream.
        The converted stream and this stream are the same logical stream,
        which has same stream id. Changes in converted stream will be reflected
        in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asPythonStream")
        return DataStream(self, j_stream)

    def _check_partition_call(self):
        """
        If parent stream is a python stream, we can't call partition related
        methods in the java stream

        Raises:
            Exception: if the parent stream exists and is a python stream.
        """
        if self.input_stream is not None and \
                self.input_stream.get_language() == function.Language.PYTHON:
            # The trailing space matters: the two literals are concatenated.
            raise Exception("Partition related methods can't be called on a "
                            "java stream if parent stream is a python stream.")

    def _unary_call(self, func_name, java_func_class):
        """Instantiate `java_func_class` and pass it to the java method
        `func_name` of the wrapped stream, returning the java result."""
        j_func = self._gateway_client().new_instance(java_func_class)
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, func_name, j_func)
        return j_stream
class KeyDataStream(DataStream):
"""Represents a DataStream returned by a key-by operation.
Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
"""
@@ -251,6 +359,43 @@ class KeyDataStream(Stream):
call_method(self._j_stream, "reduce", j_func)
return DataStream(self, j_stream)
def as_java_stream(self):
    """View this keyed stream as a `JavaKeyDataStream`.

    The returned stream and this stream are the same logical stream and
    share one stream id, so a change made through either of them is
    visible through the other.
    """
    client = self._gateway_client()
    converted = client.call_method(self._j_stream, "asJavaStream")
    return JavaKeyDataStream(self, converted)
class JavaKeyDataStream(JavaDataStream):
    """A java DataStream produced by a key-by operation.

    Wraps org.ray.streaming.api.stream.KeyDataStream.
    """

    def __init__(self, input_stream, j_stream):
        # No streaming context is forwarded; the base class default is used.
        super().__init__(input_stream, j_stream)

    def reduce(self, java_func_class):
        """See org.ray.streaming.api.stream.KeyDataStream.reduce"""
        reduced = super()._unary_call("reduce", java_func_class)
        return JavaDataStream(self, reduced)

    def as_python_stream(self):
        """View this keyed stream as a python `KeyDataStream`.

        The returned stream and this stream are the same logical stream
        and share one stream id, so a change made through either of them
        is visible through the other.
        """
        client = self._gateway_client()
        converted = client.call_method(self._j_stream, "asPythonStream")
        return KeyDataStream(self, converted)
class StreamSource(DataStream):
"""Represents a source of the DataStream.
@@ -261,9 +406,12 @@ class StreamSource(DataStream):
super().__init__(None, j_stream, streaming_context=streaming_context)
self.source_func = source_func
def get_language(self):
    """Report that this source is executed by python."""
    language = function.Language.PYTHON
    return language
@staticmethod
def build_source(streaming_context, func):
"""Build a StreamSource source from a collection.
"""Build a StreamSource source from a source function.
Args:
streaming_context: Stream context
func: A instance of `SourceFunction`
@@ -275,6 +423,34 @@ class StreamSource(DataStream):
return StreamSource(j_stream, streaming_context, func)
class JavaStreamSource(JavaDataStream):
    """Represents a source of the java DataStream.
    Wrapper of java org.ray.streaming.api.stream.DataStreamSource
    """

    def __init__(self, j_stream, streaming_context):
        # A source has no input stream, hence None.
        super().__init__(None, j_stream, streaming_context=streaming_context)

    def get_language(self):
        """Report that this source is executed by java."""
        return function.Language.JAVA

    @staticmethod
    def build_source(streaming_context, java_source_func_class):
        """Build a java StreamSource source from a java source function.

        Args:
            streaming_context: Stream context
            java_source_func_class: qualified class name of java SourceFunction
        Returns:
            A java StreamSource
        """
        # `_gateway_client` is an attribute holding the client (it is read
        # that way by Stream._gateway_client), so it must not be called.
        gateway_client = streaming_context._gateway_client
        j_func = gateway_client.new_instance(java_source_func_class)
        # Class name and method name are two separate arguments; without the
        # comma the adjacent literals were fused into a single string.
        # NOTE(review): other wrappers in this file reference the `io.ray`
        # package prefix; confirm `org.ray` is still the right package here.
        j_stream = gateway_client.call_function(
            "org.ray.streaming.api.stream.DataStreamSource", "fromSource",
            streaming_context._j_ctx, j_func)
        return JavaStreamSource(j_stream, streaming_context)
class StreamSink(Stream):
"""Represents a sink of the DataStream.
Wrapper of java io.ray.streaming.python.stream.PythonStreamSink
@@ -282,3 +458,18 @@ class StreamSink(Stream):
def __init__(self, input_stream, j_stream, func):
super().__init__(input_stream, j_stream)
def get_language(self):
    """Report that this sink is executed by python."""
    language = function.Language.PYTHON
    return language
class JavaStreamSink(Stream):
    """Sink of a java DataStream.

    Wraps java org.ray.streaming.api.stream.StreamSink.
    """

    def __init__(self, input_stream, j_stream):
        super(JavaStreamSink, self).__init__(input_stream, j_stream)

    def get_language(self):
        """Report that this sink is executed by java."""
        language = function.Language.JAVA
        return language
+25 -10
View File
@@ -1,13 +1,19 @@
import enum
import importlib
import inspect
import sys
from abc import ABC, abstractmethod
import typing
from abc import ABC, abstractmethod
from ray import cloudpickle
from ray.streaming.runtime import gateway_client
class Language(enum.Enum):
    """Languages in which a stream or streaming function is executed."""

    # Executed by java.
    JAVA = 0
    # Executed by python.
    PYTHON = 1
class Function(ABC):
"""The base interface for all user-defined functions."""
@@ -60,6 +66,7 @@ class MapFunction(Function):
for each input element.
"""
@abstractmethod
def map(self, value):
pass
@@ -70,6 +77,7 @@ class FlatMapFunction(Function):
transform them into zero, one, or more elements.
"""
@abstractmethod
def flat_map(self, value, collector):
"""Takes an element from the input data set and transforms it into zero,
one, or more elements.
@@ -87,6 +95,7 @@ class FilterFunction(Function):
The predicate decides whether to keep the element, or to discard it.
"""
@abstractmethod
def filter(self, value):
"""The filter function that evaluates the predicate.
@@ -106,6 +115,7 @@ class KeyFunction(Function):
deterministic key for that object.
"""
@abstractmethod
def key_by(self, value):
"""User-defined function that deterministically extracts the key from
an object.
@@ -126,6 +136,7 @@ class ReduceFunction(Function):
them into one.
"""
@abstractmethod
def reduce(self, old_value, new_value):
"""
The core method of ReduceFunction, combining two values into one value
@@ -145,6 +156,7 @@ class ReduceFunction(Function):
class SinkFunction(Function):
"""Interface for implementing user defined sink functionality."""
@abstractmethod
def sink(self, value):
"""Writes the given value to the sink. This function is called for
every record."""
@@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
Returns:
a streaming function
"""
function_bytes, module_name, class_name, function_name, function_interface\
assert len(descriptor_func_bytes) > 0
function_bytes, module_name, function_name, function_interface\
= gateway_client.deserialize(descriptor_func_bytes)
if function_bytes:
return deserialize(function_bytes)
@@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
assert function_interface
function_interface = getattr(sys.modules[__name__], function_interface)
mod = importlib.import_module(module_name)
if class_name:
assert function_name is None
cls = getattr(mod, class_name)
assert issubclass(cls, function_interface)
return cls()
else:
assert function_name
func = getattr(mod, function_name)
assert function_name
func = getattr(mod, function_name)
# If func is a python function, user function is a simple python
# function, which will be wrapped as a SimpleXXXFunction.
# If func is a python class, user function is a sub class
# of XXXFunction.
if inspect.isfunction(func):
simple_func_class = _get_simple_function_class(function_interface)
return simple_func_class(func)
else:
assert issubclass(func, function_interface)
return func()
def _get_simple_function_class(function_interface):
+17
View File
@@ -8,6 +8,14 @@ class Record:
def __repr__(self):
    """Return a debug representation of the form ``Record(<value>)``."""
    # str.format only substitutes "{}" fields; the former "%s" placeholder
    # was emitted literally and the value was dropped.
    return "Record({})".format(self.value)
def __eq__(self, other):
    """Records are equal when they have exactly the same type and the same
    stream and value."""
    if type(self) is not type(other):
        return False
    mine = (self.stream, self.value)
    theirs = (other.stream, other.value)
    return mine == theirs
def __hash__(self):
    """Hash over the same fields that `__eq__` compares."""
    identity = (self.stream, self.value)
    return hash(identity)
class KeyRecord(Record):
"""Data record in a keyed data stream"""
@@ -15,3 +23,12 @@ class KeyRecord(Record):
def __init__(self, key, value):
super().__init__(value)
self.key = key
def __eq__(self, other):
    """Key records are equal when they have exactly the same type and the
    same stream, key and value."""
    if type(self) is not type(other):
        return False
    mine = (self.stream, self.key, self.value)
    theirs = (other.stream, other.key, other.value)
    return mine == theirs
def __hash__(self):
    """Hash over the same fields that `__eq__` compares."""
    identity = (self.stream, self.key, self.value)
    return hash(identity)
+12 -11
View File
@@ -1,4 +1,5 @@
import importlib
import inspect
from abc import ABC, abstractmethod
from ray import cloudpickle
@@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
Returns:
partition function
"""
partition_bytes, module_name, class_name, function_name =\
assert len(descriptor_partition_bytes) > 0
partition_bytes, module_name, function_name =\
gateway_client.deserialize(descriptor_partition_bytes)
if partition_bytes:
return deserialize(partition_bytes)
else:
assert module_name
mod = importlib.import_module(module_name)
# If class_name is not None, user partition is a sub class
# of Partition.
# If function_name is not None, user partition is a simple python
assert function_name
func = getattr(mod, function_name)
# If func is a python function, user partition is a simple python
# function, which will be wrapped as a SimplePartition.
if class_name:
assert function_name is None
cls = getattr(mod, class_name)
return cls()
else:
assert function_name
func = getattr(mod, function_name)
# If func is a python class, user partition is a sub class
# of Partition.
if inspect.isfunction(func):
return SimplePartition(func)
else:
assert issubclass(func, Partition)
return func()
@@ -55,6 +55,11 @@ class GatewayClient:
call = self._python_gateway_actor.callMethod.remote(java_params)
return deserialize(ray.get(call))
def new_instance(self, java_class_name):
    """Call `newInstance` on the python gateway actor for `java_class_name`.

    The class name is sent serialized; the actor's reply is awaited and
    returned deserialized.
    """
    request = serialize(java_class_name)
    pending = self._python_gateway_actor.newInstance.remote(request)
    reply = ray.get(pending)
    return deserialize(reply)
def serialize(obj) -> bytes:
"""Serialize a python object which can be deserialized by `PythonGateway`
+3 -1
View File
@@ -53,7 +53,9 @@ class ExecutionEdge:
self.src_node_id = edge_pb.src_node_id
self.target_node_id = edge_pb.target_node_id
partition_bytes = edge_pb.partition
if language == Language.PYTHON:
# Sink node doesn't have partition function,
# so we only deserialize partition_bytes when it's not None or empty
if language == Language.PYTHON and partition_bytes:
self.partition = partition.load_partition(partition_bytes)
+57
View File
@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
import pickle
import msgpack
from ray.streaming import message
# First field of the msgpack list written by CrossLangSerializer; tells
# deserialize() whether the remaining fields describe a Record or a KeyRecord.
_RECORD_TYPE_ID = 0
_KEY_RECORD_TYPE_ID = 1
# One-byte prefixes identifying which serializer produced a message buffer.
# NOTE(review): _JAVA_TYPE_ID is not referenced in the code visible here;
# presumably it mirrors a constant on the java side -- confirm.
_CROSS_LANG_TYPE_ID = b"0"
_JAVA_TYPE_ID = b"1"
_PYTHON_TYPE_ID = b"2"
class Serializer(ABC):
    """Interface of an object that converts values to and from a
    serialized form."""

    @abstractmethod
    def serialize(self, obj):
        """Return the serialized form of `obj`."""

    @abstractmethod
    def deserialize(self, serialized_bytes):
        """Rebuild a value from `serialized_bytes`."""
class PythonSerializer(Serializer):
    """Serializer backed by `pickle`."""

    def serialize(self, obj):
        """Pickle `obj` and return the resulting bytes."""
        pickled = pickle.dumps(obj)
        return pickled

    def deserialize(self, serialized_bytes):
        """Unpickle `serialized_bytes` and return the resulting object."""
        restored = pickle.loads(serialized_bytes)
        return restored
class CrossLangSerializer(Serializer):
    """Serialize stream element between java/python"""

    def serialize(self, obj):
        """Pack a Record or KeyRecord as a msgpack list.

        The list starts with a type id followed by the record's fields.
        Only the exact types `message.Record` and `message.KeyRecord` are
        accepted; anything else raises an Exception.
        """
        kind = type(obj)
        if kind is message.Record:
            payload = [_RECORD_TYPE_ID, obj.stream, obj.value]
        elif kind is message.KeyRecord:
            payload = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
        else:
            raise Exception("Unsupported value {}".format(obj))
        return msgpack.packb(payload, use_bin_type=True)

    def deserialize(self, data):
        """Unpack a msgpack list produced by `serialize` into a record.

        Raises an Exception when the leading type id is unknown.
        """
        fields = msgpack.unpackb(data, raw=False)
        type_id = fields[0]
        if type_id == _RECORD_TYPE_ID:
            stream, value = fields[1:]
            restored = message.Record(value)
            restored.stream = stream
            return restored
        if type_id == _KEY_RECORD_TYPE_ID:
            stream, key, value = fields[1:]
            restored = message.KeyRecord(key, value)
            restored.stream = stream
            return restored
        raise Exception("Unsupported type id {}, type {}".format(
            type_id, type(type_id)))
+29 -17
View File
@@ -1,11 +1,13 @@
import logging
import pickle
import threading
from abc import ABC, abstractmethod
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
logger = logging.getLogger(__name__)
@@ -38,36 +40,40 @@ class StreamTask(ABC):
# writers
collectors = []
for edge in execution_node.output_edges:
output_actor_ids = {}
output_actors_map = {}
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.target_node_id)
for target_task_id, target_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
execution_graph.build_time())
output_actor_ids[channel_name] = target_actor
if len(output_actor_ids) > 0:
channel_ids = list(output_actor_ids.keys())
to_actor_ids = list(output_actor_ids.values())
writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
logger.info("Create DataWriter succeed.")
output_actors_map[channel_name] = target_actor
if len(output_actors_map) > 0:
channel_ids = list(output_actors_map.keys())
target_actors = list(output_actors_map.values())
logger.info(
"Create DataWriter channel_ids {}, target_actors {}."
.format(channel_ids, target_actors))
writer = DataWriter(channel_ids, target_actors, channel_conf)
self.writers[edge] = writer
collectors.append(
OutputCollector(channel_ids, writer, edge.partition))
OutputCollector(writer, channel_ids, target_actors,
edge.partition))
# readers
input_actor_ids = {}
input_actor_map = {}
for edge in execution_node.input_edges:
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.src_node_id)
for src_task_id, src_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
execution_graph.build_time())
input_actor_ids[channel_name] = src_actor
if len(input_actor_ids) > 0:
channel_ids = list(input_actor_ids.keys())
from_actor_ids = list(input_actor_ids.values())
logger.info("Create DataReader, channels {}.".format(channel_ids))
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
input_actor_map[channel_name] = src_actor
if len(input_actor_map) > 0:
channel_ids = list(input_actor_map.keys())
from_actors = list(input_actor_map.values())
logger.info("Create DataReader, channels {}, input_actors {}."
.format(channel_ids, from_actors))
self.reader = DataReader(channel_ids, from_actors, channel_conf)
def exit_handler():
# Make DataReader stop read data when MockQueue destructor
@@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
self.read_timeout_millis = \
int(worker.config.get(Config.READ_TIMEOUT_MS,
Config.DEFAULT_READ_TIMEOUT_MS))
self.python_serializer = PythonSerializer()
self.cross_lang_serializer = CrossLangSerializer()
def init(self):
pass
@@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
item = self.reader.read(self.read_timeout_millis)
if item is not None:
msg_data = item.body()
msg = pickle.loads(msg_data)
type_id = msg_data[:1]
if (type_id == serialization._PYTHON_TYPE_ID):
msg = self.python_serializer.deserialize(msg_data[1:])
else:
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
self.processor.process(msg)
self.stopped = True
+8 -4
View File
@@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
wrap initial parameters needed by a streaming queue
"""
_java_reader_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V")
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
"([B)V")
_java_reader_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B")
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
"([B)[B")
_java_writer_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V")
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
"([B)V")
_java_writer_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B")
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
"([B)[B")
_python_reader_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
_python_reader_sync_function_descriptor = PythonFunctionDescriptor(
+17 -6
View File
@@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
logger = logging.getLogger(__name__)
# special flag to indicate this actor not ready
_NOT_READY_FLAG_ = b" " * 4
@ray.remote
class JobWorker(object):
@@ -66,23 +69,31 @@ class JobWorker(object):
type(self.stream_processor))
def on_reader_message(self, buffer: bytes):
"""used in direct call mode"""
"""Called by upstream queue writer to send data message to downstream
queue reader.
"""
self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes):
"""used in direct call mode"""
"""Called by upstream queue writer to send control message to downstream
downstream queue reader.
"""
if self.reader_client is None:
return b" " * 4 # special flag to indicate this actor not ready
return _NOT_READY_FLAG_
result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes()
def on_writer_message(self, buffer: bytes):
"""used in direct call mode"""
"""Called by downstream queue reader to send notify message to
upstream queue writer.
"""
self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes):
"""used in direct call mode"""
"""Called by downstream queue reader to send control message to
upstream queue writer.
"""
if self.writer_client is None:
return b" " * 4 # special flag to indicate this actor not ready
return _NOT_READY_FLAG_
result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes()
+2 -2
View File
@@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):
def test_load_function():
# function_bytes, module_name, class_name, function_name,
# function_bytes, module_name, function_name/class_name,
# function_interface
descriptor_func_bytes = gateway_client.serialize(
[None, __name__, MapFunc.__name__, None, "MapFunction"])
[None, __name__, MapFunc.__name__, "MapFunction"])
func = function.load_function(descriptor_func_bytes)
assert type(func) is MapFunc
@@ -0,0 +1,70 @@
import json
import ray
from ray.streaming import StreamingContext
import subprocess
import os
def map_func1(x):
    """Log the element and return its string form."""
    print("HybridStreamTest map_func1", x)
    as_text = str(x)
    return as_text
def filter_func1(x):
    """Log the element and keep it only when it does not contain "b"."""
    print("HybridStreamTest filter_func1", x)
    contains_b = "b" in x
    return not contains_b
def sink_func1(x):
    """Log the element that reached the sink; nothing is returned."""
    label = "HybridStreamTest sink_func1 value:"
    print(label, x)
def test_hybrid_stream():
    """Run a python -> java -> python pipeline and check the sink output.

    The java test jar is built with bazel and put on the java worker
    classpath. The values "a", "b", "c" go through a java map and a java
    filter and are appended to a file by a python sink; the file must then
    contain "a" and "c" but not "b".
    """
    import time

    subprocess.check_call(
        ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
    current_dir = os.path.abspath(os.path.dirname(__file__))
    jar_path = os.path.join(
        current_dir,
        "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
    jar_path = os.path.abspath(jar_path)
    print("jar_path", jar_path)
    java_worker_options = json.dumps(["-classpath", jar_path])
    print("java_worker_options", java_worker_options)
    assert not ray.is_initialized()
    ray.init(
        load_code_from_local=True,
        include_java=True,
        java_worker_options=java_worker_options,
        _internal_config=json.dumps({
            "num_workers_per_process_java": 1
        }))
    sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
    try:
        # Start from a clean file because the sink appends.
        if os.path.exists(sink_file):
            os.remove(sink_file)

        def sink_func(x):
            print("HybridStreamTest", x)
            with open(sink_file, "a") as f:
                f.write(str(x))

        ctx = StreamingContext.Builder().build()
        ctx.from_values("a", "b", "c") \
            .as_java_stream() \
            .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
            .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
            .as_python_stream() \
            .sink(sink_func)
        ctx.submit("HybridStreamTest")
        # Give the job time to push all values to the sink.
        time.sleep(3)
    finally:
        # Always stop ray, otherwise a failure here leaves the runtime
        # initialized for whatever test runs next.
        ray.shutdown()
    with open(sink_file, "r") as f:
        result = f.read()
    assert "a" in result
    assert "b" not in result
    assert "c" in result
if __name__ == "__main__":
test_hybrid_stream()
@@ -0,0 +1,13 @@
from ray.streaming.runtime.serialization import CrossLangSerializer
from ray.streaming.message import Record, KeyRecord
def test_serialize():
    """A Record and a KeyRecord survive a CrossLangSerializer round trip."""
    serializer = CrossLangSerializer()

    record = Record("value")
    record.stream = "stream1"
    key_record = KeyRecord("key", "value")
    key_record.stream = "stream2"

    restored_record = serializer.deserialize(serializer.serialize(record))
    assert record == restored_record

    restored_key_record = serializer.deserialize(
        serializer.serialize(key_record))
    assert key_record == restored_key_record
+31
View File
@@ -0,0 +1,31 @@
import ray
from ray.streaming import StreamingContext
def test_data_stream():
    """A stream converted to java and back keeps its id and parallelism."""
    ray.init(load_code_from_local=True, include_java=True)
    try:
        ctx = StreamingContext.Builder().build()
        stream = ctx.from_values(1, 2, 3)
        java_stream = stream.as_java_stream()
        python_stream = java_stream.as_python_stream()
        assert stream.get_id() == java_stream.get_id()
        assert stream.get_id() == python_stream.get_id()
        # A change made through one view must be visible through the others.
        python_stream.set_parallelism(10)
        assert stream.get_parallelism() == java_stream.get_parallelism()
        assert stream.get_parallelism() == python_stream.get_parallelism()
    finally:
        # Stop ray even when an assertion fails so later tests start clean.
        ray.shutdown()
def test_key_data_stream():
    """A keyed stream converted to java and back keeps its id and
    parallelism."""
    ray.init(load_code_from_local=True, include_java=True)
    try:
        ctx = StreamingContext.Builder().build()
        key_stream = ctx.from_values(
            "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
        java_stream = key_stream.as_java_stream()
        python_stream = java_stream.as_python_stream()
        assert key_stream.get_id() == java_stream.get_id()
        assert key_stream.get_id() == python_stream.get_id()
        # A change made through one view must be visible through the others.
        python_stream.set_parallelism(10)
        assert key_stream.get_parallelism() == java_stream.get_parallelism()
        assert key_stream.get_parallelism() == python_stream.get_parallelism()
    finally:
        # Stop ray even when an assertion fails so later tests start clean.
        ray.shutdown()
+3 -1
View File
@@ -32,7 +32,9 @@ def test_simple_word_count():
def sink_func(x):
with open(sink_file, "a") as f:
f.write("{}:{},".format(x[0], x[1]))
line = "{}:{},".format(x[0], x[1])
print("sink_func", line)
f.write(line)
ctx.from_values("a", "b", "c") \
.set_parallelism(1) \