From bcdbe2d3d4a55f5de978b8b7a6da235a0b288672 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 27 May 2020 18:36:07 +0800 Subject: [PATCH] Streaming rich function (#8602) --- .../streaming/api/context/RuntimeContext.java | 14 +++-- .../streaming/api/function/RichFunction.java | 24 ++++++++ .../api/function/internal/Functions.java | 40 +++++++++++++ .../io/ray/streaming/api/stream/Stream.java | 22 +++++++ .../streaming/jobgraph/JobGraphBuilder.java | 1 + .../io/ray/streaming/jobgraph/JobVertex.java | 19 ++++-- .../streaming/operator/StreamOperator.java | 11 +++- .../java/io/ray/streaming/util/Config.java | 11 ---- .../runtime/core/graph/ExecutionNode.java | 9 ++- .../runtime/python/PythonGateway.java | 38 ++++++++++-- .../runtime/schedule/TaskAssignerImpl.java | 2 +- .../worker/context/RayRuntimeContext.java | 30 ++++------ .../streaming/runtime/demo/WordCountTest.java | 1 - .../streamingqueue/StreamingQueueTest.java | 1 - streaming/python/context.py | 26 ++++++++- streaming/python/datastream.py | 58 ++++++++++++++----- streaming/python/function.py | 6 +- streaming/python/operator.py | 3 +- streaming/python/runtime/task.py | 1 + streaming/python/tests/test_stream.py | 18 ++++++ 20 files changed, 264 insertions(+), 71 deletions(-) create mode 100644 streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java create mode 100644 streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java index ce92142fb..0ff16330c 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java @@ -20,14 +20,20 @@ public interface RuntimeContext { int getParallelism(); + /** + * @return config of current function + */ + Map getConfig(); + + /** + * @return config of the job + */ + Map getJobConfig(); + Long getCheckpointId(); void setCheckpointId(long checkpointId); - Long getMaxBatch(); - - Map getConfig(); - void setCurrentKey(Object key); KeyStateBackend getKeyStateBackend(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java new file mode 100644 index 000000000..8ae2d05fa --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java @@ -0,0 +1,24 @@ +package io.ray.streaming.api.function; + +import io.ray.streaming.api.context.RuntimeContext; + +/** + * An interface for all user-defined functions to define the life cycle methods of the + * functions, and access the task context where the functions get executed. + */ +public interface RichFunction extends Function { + + /** + * Initialization method for user function which called before the first call to the user + * function. + * @param runtimeContext runtime context + */ + void open(RuntimeContext runtimeContext); + + /** + * Tear-down method for the user function which called after the last call to + * the user function. + */ + void close(); + +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java new file mode 100644 index 000000000..3472da79e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java @@ -0,0 +1,40 @@ +package io.ray.streaming.api.function.internal; + +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.RichFunction; + +/** + * A util class for {@link Function} + */ +public class Functions { + + private static class DefaultRichFunction implements RichFunction { + private final Function function; + + private DefaultRichFunction(Function function) { + this.function = function; + } + + @Override + public void open(RuntimeContext runtimeContext) { + } + + @Override + public void close() { + } + + public Function getFunction() { + return function; + } + } + + public static RichFunction wrap(Function function) { + if (function instanceof RichFunction) { + return (RichFunction) function; + } else { + return new DefaultRichFunction(function); + } + } + +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java index 4c74780cd..241cef4f5 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java @@ -9,6 +9,8 @@ import io.ray.streaming.operator.Operator; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.python.PythonPartition; import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; /** * Abstract base class of all stream types. @@ -23,6 +25,7 @@ public abstract class Stream, T> private final Stream inputStream; private final StreamOperator operator; private int parallelism = 1; + private Map config = new HashMap<>(); private Partition partition; private Stream originalStream; @@ -134,6 +137,25 @@ public abstract class Stream, T> return self(); } + public S withConfig(Map config) { + config.forEach(this::withConfig); + return self(); + } + + public S withConfig(String key, String value) { + if (isProxyStream()) { + originalStream.withConfig(key, value); + } else { + this.config.put(key, value); + } + return self(); + } + + @SuppressWarnings("unchecked") + public Map getConfig() { + return isProxyStream() ? originalStream.getConfig() : config; + } + public boolean isProxyStream() { return originalStream != null; } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java index d0f6a7dc3..30e26ce9b 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java @@ -77,6 +77,7 @@ public class JobGraphBuilder { } else { throw new UnsupportedOperationException("Unsupported stream: " + stream); } + jobVertex.setConfig(stream.getConfig()); this.jobGraph.addVertex(jobVertex); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java index 6ab13cb66..98bb14b62 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java @@ -4,20 +4,23 @@ import com.google.common.base.MoreObjects; import io.ray.streaming.api.Language; import io.ray.streaming.operator.StreamOperator; import java.io.Serializable; +import java.util.Map; /** * Job vertex is a cell node where logic is executed. */ public class JobVertex implements Serializable { - private int vertexId; private int parallelism; private VertexType vertexType; private Language language; private StreamOperator streamOperator; + private Map config; - public JobVertex(int vertexId, int parallelism, VertexType vertexType, - StreamOperator streamOperator) { + public JobVertex(int vertexId, + int parallelism, + VertexType vertexType, + StreamOperator streamOperator) { this.vertexId = vertexId; this.parallelism = parallelism; this.vertexType = vertexType; @@ -45,6 +48,14 @@ public class JobVertex implements Serializable { return language; } + public Map getConfig() { + return config; + } + + public void setConfig(Map config) { + this.config = config; + } + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -53,7 +64,7 @@ public class JobVertex implements Serializable { .add("vertexType", vertexType) .add("language", language) .add("streamOperator", streamOperator) + .add("config", config) .toString(); } - } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java index 3ae688ccf..4160c736f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java @@ -4,25 +4,30 @@ import io.ray.streaming.api.Language; import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.RichFunction; +import io.ray.streaming.api.function.internal.Functions; import io.ray.streaming.message.KeyRecord; import io.ray.streaming.message.Record; import java.util.List; public abstract class StreamOperator implements Operator { - protected String name; - protected F function; + protected final String name; + protected final F function; + protected final RichFunction richFunction; protected List collectorList; protected RuntimeContext runtimeContext; public StreamOperator(F function) { this.name = getClass().getSimpleName(); this.function = function; + this.richFunction = Functions.wrap(function); } @Override public void open(List collectorList, RuntimeContext runtimeContext) { this.collectorList = collectorList; this.runtimeContext = runtimeContext; + richFunction.open(runtimeContext); } @Override @@ -32,7 +37,7 @@ public abstract class StreamOperator implements Operator { @Override public void close() { - + richFunction.close(); } @Override diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java index 9862c1f92..b998ddf9f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java @@ -2,17 +2,6 @@ package io.ray.streaming.util; public class Config { - /** - * Maximum number of batches to run in a streaming job. - */ - public static final String STREAMING_BATCH_MAX_COUNT = "streaming.batch.max.count"; - - /** - * batch frequency in milliseconds - */ - public static final String STREAMING_BATCH_FREQUENCY = "streaming.batch.frequency"; - public static final long STREAMING_BATCH_FREQUENCY_DEFAULT = 1000; - public static final String STREAMING_JOB_NAME = "streaming.job.name"; public static final String STREAMING_OP_NAME = "streaming.op_name"; public static final String STREAMING_WORKER_NAME = "streaming.worker_name"; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/ExecutionNode.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/ExecutionNode.java index c4e29b4e8..409d7d49d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/ExecutionNode.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/ExecutionNode.java @@ -6,6 +6,7 @@ import io.ray.streaming.operator.StreamOperator; import java.io.Serializable; import java.util.ArrayList; import java.util.List; +import java.util.Map; /** * A node in the physical execution graph. @@ -13,15 +14,17 @@ import java.util.List; public class ExecutionNode implements Serializable { private int nodeId; private int parallelism; + private Map config; private NodeType nodeType; private StreamOperator streamOperator; private List executionTasks; private List inputsEdges; private List outputEdges; - public ExecutionNode(int nodeId, int parallelism) { + public ExecutionNode(int nodeId, int parallelism, Map config) { this.nodeId = nodeId; this.parallelism = parallelism; + this.config = config; this.executionTasks = new ArrayList<>(); this.inputsEdges = new ArrayList<>(); this.outputEdges = new ArrayList<>(); @@ -43,6 +46,10 @@ public class ExecutionNode implements Serializable { this.parallelism = parallelism; } + public Map getConfig() { + return config; + } + public List getExecutionTasks() { return executionTasks; } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java index 826f1c935..b5ca58e78 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java @@ -31,13 +31,12 @@ import org.slf4j.LoggerFactory; public class PythonGateway { private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class); private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__"; + private static MsgPackSerializer serializer = new MsgPackSerializer(); - private MsgPackSerializer serializer; private Map referenceMap; private StreamingContext streamingContext; public PythonGateway() { - serializer = new MsgPackSerializer(); referenceMap = new HashMap<>(); LOG.info("PythonGateway created"); } @@ -156,8 +155,23 @@ public class PythonGateway { .map((Function) Primitives::unwrap) .toArray(Class[]::new); Optional any = methods.stream() - .filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) || - Arrays.equals(m.getParameterTypes(), unwrappedTypes)) + .filter(m -> { + boolean exactMatch = Arrays.equals(m.getParameterTypes(), paramsTypes) || + Arrays.equals(m.getParameterTypes(), unwrappedTypes); + if (exactMatch) { + return true; + } else if (paramsTypes.length == m.getParameterTypes().length) { + for (int i = 0; i < m.getParameterTypes().length; i++) { + Class parameterType = m.getParameterTypes()[i]; + if (!parameterType.isAssignableFrom(paramsTypes[i])) { + return false; + } + } + return true; + } else { + return false; + } + }) .findAny(); Preconditions.checkArgument(any.isPresent(), String.format("Method %s with type %s doesn't exist on class %s", @@ -166,7 +180,21 @@ public class PythonGateway { } private static boolean returnReference(Object value) { - return !(value instanceof Number) && !(value instanceof String) && !(value instanceof byte[]); + if (isBasic(value)) { + return false; + } else { + try { + serializer.serialize(value); + return false; + } catch (Exception e) { + return true; + } + } + } + + private static boolean isBasic(Object value) { + return value == null || (value instanceof Boolean) || (value instanceof Number) || + (value instanceof String) || (value instanceof byte[]); } public byte[] newInstance(byte[] classNameBytes) { diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java index 04520b441..1b68f9dde 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java @@ -38,7 +38,7 @@ public class TaskAssignerImpl implements TaskAssigner { Map idToExecutionNode = new HashMap<>(); for (JobVertex jobVertex : jobVertices) { ExecutionNode executionNode = new ExecutionNode(jobVertex.getVertexId(), - jobVertex.getParallelism()); + jobVertex.getParallelism(), jobVertex.getConfig()); executionNode.setNodeType(jobVertex.getVertexType()); List vertexTasks = new ArrayList<>(); for (int taskIndex = 0; taskIndex < jobVertex.getParallelism(); taskIndex++) { diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/RayRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/RayRuntimeContext.java index 345f67a38..6a12d1832 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/RayRuntimeContext.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/RayRuntimeContext.java @@ -1,7 +1,5 @@ package io.ray.streaming.runtime.worker.context; -import static io.ray.streaming.util.Config.STREAMING_BATCH_MAX_COUNT; - import com.google.common.base.Preconditions; import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.runtime.core.graph.ExecutionTask; @@ -21,8 +19,6 @@ import java.util.Map; * Use Ray to implement RuntimeContext. */ public class RayRuntimeContext implements RuntimeContext { - - private final Long maxBatch; /** * Backend for keyed state. This might be empty if we're not on a keyed stream. */ @@ -43,11 +39,6 @@ public class RayRuntimeContext implements RuntimeContext { this.config = config; this.taskIndex = executionTask.getTaskIndex(); this.parallelism = parallelism; - if (config.containsKey(STREAMING_BATCH_MAX_COUNT)) { - this.maxBatch = Long.valueOf(config.get(STREAMING_BATCH_MAX_COUNT)); - } else { - this.maxBatch = Long.MAX_VALUE; - } } @Override @@ -65,6 +56,16 @@ public class RayRuntimeContext implements RuntimeContext { return parallelism; } + @Override + public Map getConfig() { + return config; + } + + @Override + public Map getJobConfig() { + return config; + } + @Override public Long getCheckpointId() { return checkpointId; @@ -81,17 +82,6 @@ public class RayRuntimeContext implements RuntimeContext { this.checkpointId = checkpointId; } - @Override - public Long getMaxBatch() { - return maxBatch; - } - - @Override - public Map getConfig() { - return config; - } - - @Override public void setCurrentKey(Object key) { this.keyStateBackend.setCurrentKey(key); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java index 5669ad12f..d88e8a5ec 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java @@ -33,7 +33,6 @@ public class WordCountTest extends BaseUnitTest implements Serializable { Ray.shutdown(); StreamingContext streamingContext = StreamingContext.buildContext(); Map config = new HashMap<>(); - config.put(Config.STREAMING_BATCH_MAX_COUNT, "1"); config.put(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL); streamingContext.withConfig(config); List text = new ArrayList<>(); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java index c48293cea..381900332 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -160,7 +160,6 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { Map wordCount = new ConcurrentHashMap<>(); StreamingContext streamingContext = StreamingContext.buildContext(); Map config = new HashMap<>(); - config.put(Config.STREAMING_BATCH_MAX_COUNT, "1"); config.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL); config.put(Config.CHANNEL_SIZE, "100000"); streamingContext.withConfig(config); diff --git a/streaming/python/context.py b/streaming/python/context.py index 7ad80a909..8c3b9a5f6 100644 --- a/streaming/python/context.py +++ b/streaming/python/context.py @@ -151,12 +151,30 @@ class RuntimeContext(ABC): """ pass + @abstractmethod + def get_config(self): + """ + Returns: + The config with which the parallel task runs. + """ + pass + + @abstractmethod + def get_job_config(self): + """ + Returns: + The job config. + """ + pass + class RuntimeContextImpl(RuntimeContext): - def __init__(self, task_id, task_index, parallelism): + def __init__(self, task_id, task_index, parallelism, **kargs): self.task_id = task_id self.task_index = task_index self.parallelism = parallelism + self.config = kargs.get("config", {}) + self.job_config = kargs.get("job_config", {}) def get_task_id(self): return self.task_id @@ -166,3 +184,9 @@ class RuntimeContextImpl(RuntimeContext): def get_parallelism(self): return self.parallelism + + def get_config(self): + return self.config + + def get_job_config(self): + return self.job_config diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py index 26297da11..21837e072 100644 --- a/streaming/python/datastream.py +++ b/streaming/python/datastream.py @@ -59,6 +59,38 @@ class Stream(ABC): return self._gateway_client(). \ call_method(self._j_stream, "getId") + def with_config(self, key=None, value=None, conf=None): + """Set stream config. + + Args: + key: a key name string for configuration property + value: a value string for configuration property + conf: multi key-value pairs as a dict + + Returns: + self + """ + if key is not None: + assert value + assert type(key) is str + assert type(value) is str + self._gateway_client(). \ + call_method(self._j_stream, "withConfig", key, value) + if conf is not None: + for k, v in conf.items(): + assert type(k) is str + assert type(v) is str + self._gateway_client(). \ + call_method(self._j_stream, "withConfig", conf) + return self + + def get_config(self): + """ + Returns: + A dict config for this stream + """ + return self._gateway_client().call_method(self._j_stream, "getConfig") + @abstractmethod def get_language(self): pass @@ -252,7 +284,7 @@ class JavaDataStream(Stream): """ Represents a stream of data which applies a transformation executed by java. It's also a wrapper of java - `org.ray.streaming.api.stream.DataStream` + `io.ray.streaming.api.stream.DataStream` """ def __init__(self, input_stream, j_stream, streaming_context=None): @@ -263,39 +295,39 @@ class JavaDataStream(Stream): return function.Language.JAVA def map(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.map""" + """See io.ray.streaming.api.stream.DataStream.map""" return JavaDataStream(self, self._unary_call("map", java_func_class)) def flat_map(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.flatMap""" + """See io.ray.streaming.api.stream.DataStream.flatMap""" return JavaDataStream(self, self._unary_call("flatMap", java_func_class)) def filter(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.filter""" + """See io.ray.streaming.api.stream.DataStream.filter""" return JavaDataStream(self, self._unary_call("filter", java_func_class)) def key_by(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.keyBy""" + """See io.ray.streaming.api.stream.DataStream.keyBy""" self._check_partition_call() return JavaKeyDataStream(self, self._unary_call("keyBy", java_func_class)) def broadcast(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.broadcast""" + """See io.ray.streaming.api.stream.DataStream.broadcast""" self._check_partition_call() return JavaDataStream(self, self._unary_call("broadcast", java_func_class)) def partition_by(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.partitionBy""" + """See io.ray.streaming.api.stream.DataStream.partitionBy""" self._check_partition_call() return JavaDataStream(self, self._unary_call("partitionBy", java_func_class)) def sink(self, java_func_class): - """See org.ray.streaming.api.stream.DataStream.sink""" + """See io.ray.streaming.api.stream.DataStream.sink""" return JavaStreamSink(self, self._unary_call("sink", java_func_class)) def as_python_stream(self): @@ -374,14 +406,14 @@ class KeyDataStream(DataStream): class JavaKeyDataStream(JavaDataStream): """ Represents a DataStream returned by a key-by operation in java. - Wrapper of org.ray.streaming.api.stream.KeyDataStream + Wrapper of io.ray.streaming.api.stream.KeyDataStream """ def __init__(self, input_stream, j_stream): super().__init__(input_stream, j_stream) def reduce(self, java_func_class): - """See org.ray.streaming.api.stream.KeyDataStream.reduce""" + """See io.ray.streaming.api.stream.KeyDataStream.reduce""" return JavaDataStream(self, super()._unary_call("reduce", java_func_class)) @@ -425,7 +457,7 @@ class StreamSource(DataStream): class JavaStreamSource(JavaDataStream): """Represents a source of the java DataStream. - Wrapper of java org.ray.streaming.api.stream.DataStreamSource + Wrapper of java io.ray.streaming.api.stream.DataStreamSource """ def __init__(self, j_stream, streaming_context): @@ -446,7 +478,7 @@ class JavaStreamSource(JavaDataStream): j_func = streaming_context._gateway_client() \ .new_instance(java_source_func_class) j_stream = streaming_context._gateway_client() \ - .call_function("org.ray.streaming.api.stream.DataStreamSource" + .call_function("io.ray.streaming.api.stream.DataStreamSource" "fromSource", streaming_context._j_ctx, j_func) return JavaStreamSource(j_stream, streaming_context) @@ -465,7 +497,7 @@ class StreamSink(Stream): class JavaStreamSink(Stream): """Represents a sink of the java DataStream. - Wrapper of java org.ray.streaming.api.stream.StreamSink + Wrapper of java io.ray.streaming.api.stream.StreamSink """ def __init__(self, input_stream, j_stream): diff --git a/streaming/python/function.py b/streaming/python/function.py index 8d38ae6bc..b4d4d2383 100644 --- a/streaming/python/function.py +++ b/streaming/python/function.py @@ -2,7 +2,6 @@ import enum import importlib import inspect import sys -import typing from abc import ABC, abstractmethod from ray import cloudpickle @@ -17,7 +16,7 @@ class Language(enum.Enum): class Function(ABC): """The base interface for all user-defined functions.""" - def open(self, conf: typing.Dict[str, str]): + def open(self, runtime_context): pass def close(self): @@ -55,9 +54,6 @@ class SourceFunction(Function): """ pass - def close(self): - pass - class MapFunction(Function): """ diff --git a/streaming/python/operator.py b/streaming/python/operator.py index e03be9707..d6937543f 100644 --- a/streaming/python/operator.py +++ b/streaming/python/operator.py @@ -71,12 +71,13 @@ class StreamOperator(Operator, ABC): def open(self, collectors, runtime_context): self.collectors = collectors self.runtime_context = runtime_context + self.func.open(runtime_context) def finish(self): pass def close(self): - pass + self.func.close() def collect(self, record): for collector in self.collectors: diff --git a/streaming/python/runtime/task.py b/streaming/python/runtime/task.py index c207c4727..df878d624 100644 --- a/streaming/python/runtime/task.py +++ b/streaming/python/runtime/task.py @@ -83,6 +83,7 @@ class StreamTask(ABC): import atexit atexit.register(exit_handler) + # TODO(chaokunyang) add task/job config runtime_context = RuntimeContextImpl( self.worker.execution_task.task_id, self.worker.execution_task.task_index, execution_node.parallelism) diff --git a/streaming/python/tests/test_stream.py b/streaming/python/tests/test_stream.py index 8eb0fbe6a..c88d5b933 100644 --- a/streaming/python/tests/test_stream.py +++ b/streaming/python/tests/test_stream.py @@ -29,3 +29,21 @@ def test_key_data_stream(): assert key_stream.get_parallelism() == java_stream.get_parallelism() assert key_stream.get_parallelism() == python_stream.get_parallelism() ray.shutdown() + + +def test_stream_config(): + ray.init(load_code_from_local=True, include_java=True) + ctx = StreamingContext.Builder().build() + stream = ctx.from_values(1, 2, 3) + stream.with_config("k1", "v1") + print("config", stream.get_config()) + assert stream.get_config() == {"k1": "v1"} + stream.with_config(conf={"k2": "v2", "k3": "v3"}) + print("config", stream.get_config()) + assert stream.get_config() == {"k1": "v1", "k2": "v2", "k3": "v3"} + java_stream = stream.as_java_stream() + java_stream.with_config(conf={"k4": "v4"}) + config = java_stream.get_config() + print("config", config) + assert config == {"k1": "v1", "k2": "v2", "k3": "v3", "k4": "v4"} + ray.shutdown()