Streaming rich function (#8602)

This commit is contained in:
chaokunyang
2020-05-27 18:36:07 +08:00
committed by GitHub
parent bd4fbcd7fc
commit bcdbe2d3d4
20 changed files with 264 additions and 71 deletions
+25 -1
View File
@@ -151,12 +151,30 @@ class RuntimeContext(ABC):
"""
pass
@abstractmethod
def get_config(self):
"""
Returns:
The config with which the parallel task runs.
"""
pass
@abstractmethod
def get_job_config(self):
"""
Returns:
The job config.
"""
pass
class RuntimeContextImpl(RuntimeContext):
def __init__(self, task_id, task_index, parallelism):
def __init__(self, task_id, task_index, parallelism, **kargs):
self.task_id = task_id
self.task_index = task_index
self.parallelism = parallelism
self.config = kargs.get("config", {})
self.job_config = kargs.get("job_config", {})
def get_task_id(self):
return self.task_id
@@ -166,3 +184,9 @@ class RuntimeContextImpl(RuntimeContext):
def get_parallelism(self):
return self.parallelism
def get_config(self):
return self.config
def get_job_config(self):
return self.job_config
+45 -13
View File
@@ -59,6 +59,38 @@ class Stream(ABC):
return self._gateway_client(). \
call_method(self._j_stream, "getId")
def with_config(self, key=None, value=None, conf=None):
"""Set stream config.
Args:
key: a key name string for configuration property
value: a value string for configuration property
conf: multi key-value pairs as a dict
Returns:
self
"""
if key is not None:
assert value
assert type(key) is str
assert type(value) is str
self._gateway_client(). \
call_method(self._j_stream, "withConfig", key, value)
if conf is not None:
for k, v in conf.items():
assert type(k) is str
assert type(v) is str
self._gateway_client(). \
call_method(self._j_stream, "withConfig", conf)
return self
def get_config(self):
"""
Returns:
A dict config for this stream
"""
return self._gateway_client().call_method(self._j_stream, "getConfig")
@abstractmethod
def get_language(self):
pass
@@ -252,7 +284,7 @@ class JavaDataStream(Stream):
"""
Represents a stream of data which applies a transformation executed by
java. It's also a wrapper of java
`org.ray.streaming.api.stream.DataStream`
`io.ray.streaming.api.stream.DataStream`
"""
def __init__(self, input_stream, j_stream, streaming_context=None):
@@ -263,39 +295,39 @@ class JavaDataStream(Stream):
return function.Language.JAVA
def map(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.map"""
"""See io.ray.streaming.api.stream.DataStream.map"""
return JavaDataStream(self, self._unary_call("map", java_func_class))
def flat_map(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.flatMap"""
"""See io.ray.streaming.api.stream.DataStream.flatMap"""
return JavaDataStream(self, self._unary_call("flatMap",
java_func_class))
def filter(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.filter"""
"""See io.ray.streaming.api.stream.DataStream.filter"""
return JavaDataStream(self, self._unary_call("filter",
java_func_class))
def key_by(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.keyBy"""
"""See io.ray.streaming.api.stream.DataStream.keyBy"""
self._check_partition_call()
return JavaKeyDataStream(self,
self._unary_call("keyBy", java_func_class))
def broadcast(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.broadcast"""
"""See io.ray.streaming.api.stream.DataStream.broadcast"""
self._check_partition_call()
return JavaDataStream(self,
self._unary_call("broadcast", java_func_class))
def partition_by(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.partitionBy"""
"""See io.ray.streaming.api.stream.DataStream.partitionBy"""
self._check_partition_call()
return JavaDataStream(self,
self._unary_call("partitionBy", java_func_class))
def sink(self, java_func_class):
"""See org.ray.streaming.api.stream.DataStream.sink"""
"""See io.ray.streaming.api.stream.DataStream.sink"""
return JavaStreamSink(self, self._unary_call("sink", java_func_class))
def as_python_stream(self):
@@ -374,14 +406,14 @@ class KeyDataStream(DataStream):
class JavaKeyDataStream(JavaDataStream):
"""
Represents a DataStream returned by a key-by operation in java.
Wrapper of org.ray.streaming.api.stream.KeyDataStream
Wrapper of io.ray.streaming.api.stream.KeyDataStream
"""
def __init__(self, input_stream, j_stream):
super().__init__(input_stream, j_stream)
def reduce(self, java_func_class):
"""See org.ray.streaming.api.stream.KeyDataStream.reduce"""
"""See io.ray.streaming.api.stream.KeyDataStream.reduce"""
return JavaDataStream(self,
super()._unary_call("reduce", java_func_class))
@@ -425,7 +457,7 @@ class StreamSource(DataStream):
class JavaStreamSource(JavaDataStream):
"""Represents a source of the java DataStream.
Wrapper of java org.ray.streaming.api.stream.DataStreamSource
Wrapper of java io.ray.streaming.api.stream.DataStreamSource
"""
def __init__(self, j_stream, streaming_context):
@@ -446,7 +478,7 @@ class JavaStreamSource(JavaDataStream):
j_func = streaming_context._gateway_client() \
.new_instance(java_source_func_class)
j_stream = streaming_context._gateway_client() \
.call_function("org.ray.streaming.api.stream.DataStreamSource"
.call_function("io.ray.streaming.api.stream.DataStreamSource"
"fromSource", streaming_context._j_ctx, j_func)
return JavaStreamSource(j_stream, streaming_context)
@@ -465,7 +497,7 @@ class StreamSink(Stream):
class JavaStreamSink(Stream):
"""Represents a sink of the java DataStream.
Wrapper of java org.ray.streaming.api.stream.StreamSink
Wrapper of java io.ray.streaming.api.stream.StreamSink
"""
def __init__(self, input_stream, j_stream):
+1 -5
View File
@@ -2,7 +2,6 @@ import enum
import importlib
import inspect
import sys
import typing
from abc import ABC, abstractmethod
from ray import cloudpickle
@@ -17,7 +16,7 @@ class Language(enum.Enum):
class Function(ABC):
"""The base interface for all user-defined functions."""
def open(self, conf: typing.Dict[str, str]):
def open(self, runtime_context):
pass
def close(self):
@@ -55,9 +54,6 @@ class SourceFunction(Function):
"""
pass
def close(self):
pass
class MapFunction(Function):
"""
+2 -1
View File
@@ -71,12 +71,13 @@ class StreamOperator(Operator, ABC):
def open(self, collectors, runtime_context):
self.collectors = collectors
self.runtime_context = runtime_context
self.func.open(runtime_context)
def finish(self):
pass
def close(self):
pass
self.func.close()
def collect(self, record):
for collector in self.collectors:
+1
View File
@@ -83,6 +83,7 @@ class StreamTask(ABC):
import atexit
atexit.register(exit_handler)
# TODO(chaokunyang) add task/job config
runtime_context = RuntimeContextImpl(
self.worker.execution_task.task_id,
self.worker.execution_task.task_index, execution_node.parallelism)
+18
View File
@@ -29,3 +29,21 @@ def test_key_data_stream():
assert key_stream.get_parallelism() == java_stream.get_parallelism()
assert key_stream.get_parallelism() == python_stream.get_parallelism()
ray.shutdown()
def test_stream_config():
ray.init(load_code_from_local=True, include_java=True)
ctx = StreamingContext.Builder().build()
stream = ctx.from_values(1, 2, 3)
stream.with_config("k1", "v1")
print("config", stream.get_config())
assert stream.get_config() == {"k1": "v1"}
stream.with_config(conf={"k2": "v2", "k3": "v3"})
print("config", stream.get_config())
assert stream.get_config() == {"k1": "v1", "k2": "v2", "k3": "v3"}
java_stream = stream.as_java_stream()
java_stream.with_config(conf={"k4": "v4"})
config = java_stream.get_config()
print("config", config)
assert config == {"k1": "v1", "k2": "v2", "k3": "v3", "k4": "v4"}
ray.shutdown()