From 8b6784de061a313da9d5be5fe8448862d048b3ef Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 25 Feb 2020 10:33:33 +0800 Subject: [PATCH] [Streaming] Streaming Python API (#6755) --- ci/travis/install-dependencies.sh | 5 + java/BUILD.bazel | 1 + python/setup.py | 2 + streaming/BUILD.bazel | 37 +- streaming/java/BUILD.bazel | 11 +- streaming/java/dependencies.bzl | 1 + .../java/org/ray/streaming/api/Language.java | 6 + .../api/context/StreamingContext.java | 4 + .../streaming/jobgraph/JobGraphBuilder.java | 4 +- .../org/ray/streaming/jobgraph/JobVertex.java | 8 + .../ray/streaming/jobgraph/VertexType.java | 1 - .../org/ray/streaming/operator/Operator.java | 6 + .../ray/streaming/operator/OperatorType.java | 1 - .../streaming/operator/StreamOperator.java | 12 +- .../ray/streaming/python/PythonFunction.java | 37 +- .../ray/streaming/python/PythonOperator.java | 6 +- .../ray/streaming/python/PythonPartition.java | 15 + .../python/stream/PythonDataStream.java | 4 + .../python/stream/PythonKeyDataStream.java | 8 +- .../python/stream/PythonStreamSink.java | 2 +- streaming/java/streaming-runtime/pom.xml | 5 + .../runtime/cluster/ResourceManager.java | 23 - .../runtime/core/graph/ExecutionGraph.java | 17 +- .../runtime/core/graph/ExecutionNode.java | 12 +- .../runtime/core/graph/ExecutionTask.java | 9 +- .../runtime/python/GraphPbBuilder.java | 101 +++ .../runtime/python/MsgPackSerializer.java | 119 +++ .../runtime/python/PythonGateway.java | 152 ++++ .../runtime/schedule/JobSchedulerImpl.java | 61 +- .../runtime/schedule/TaskAssigner.java | 5 +- .../runtime/schedule/TaskAssignerImpl.java | 28 +- .../runtime/util/ReflectionUtils.java | 86 +++ .../streaming/runtime/worker/JobWorker.java | 1 + .../runtime/worker/tasks/StreamTask.java | 4 +- .../runtime/python/MsgPackSerializerTest.java | 39 + .../runtime/python/PythonGatewayTest.java | 48 ++ .../schedule/TaskAssignerImplTest.java | 29 +- .../runtime/util/ReflectionUtilsTest.java | 38 + streaming/java/test.sh | 3 +- 
streaming/python/README.rst | 16 - streaming/python/__init__.py | 3 + streaming/python/collector.py | 49 ++ streaming/python/communication.py | 279 ------- streaming/python/config.py | 3 +- streaming/python/context.py | 168 +++++ streaming/python/datastream.py | 284 ++++++++ streaming/python/examples/key_selectors.py | 67 -- streaming/python/examples/simple.py | 52 -- streaming/python/examples/toy.txt | 5 - streaming/python/examples/wordcount.py | 52 +- streaming/python/function.py | 315 ++++++++ streaming/python/includes/transfer.pxi | 6 +- streaming/python/jobworker.py | 120 --- streaming/python/message.py | 17 + streaming/python/operator.py | 324 +++++--- streaming/python/partition.py | 117 +++ streaming/python/processor.py | 222 ------ streaming/python/runtime/gateway_client.py | 67 ++ streaming/python/runtime/graph.py | 102 +++ streaming/python/runtime/processor.py | 113 +++ streaming/python/runtime/task.py | 158 ++++ streaming/python/runtime/worker.py | 104 +++ streaming/python/streaming.py | 689 ------------------ streaming/python/tests/test_function.py | 22 + streaming/python/tests/test_logical_graph.py | 206 ------ streaming/python/tests/test_operator.py | 8 + streaming/python/tests/test_word_count.py | 25 +- streaming/src/config/streaming_config.cc | 4 +- streaming/src/config/streaming_config.h | 5 +- streaming/src/protobuf/remote_call.proto | 59 ++ streaming/src/protobuf/streaming.proto | 17 +- 71 files changed, 2701 insertions(+), 1928 deletions(-) create mode 100644 streaming/java/streaming-api/src/main/java/org/ray/streaming/api/Language.java delete mode 100644 streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java create mode 100644 streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/GraphPbBuilder.java create mode 100644 streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/MsgPackSerializer.java create mode 100644 
streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/PythonGateway.java create mode 100644 streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/ReflectionUtils.java create mode 100644 streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/MsgPackSerializerTest.java create mode 100644 streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/PythonGatewayTest.java create mode 100644 streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/util/ReflectionUtilsTest.java delete mode 100644 streaming/python/README.rst create mode 100644 streaming/python/collector.py delete mode 100644 streaming/python/communication.py create mode 100644 streaming/python/context.py create mode 100644 streaming/python/datastream.py delete mode 100644 streaming/python/examples/key_selectors.py delete mode 100644 streaming/python/examples/simple.py delete mode 100644 streaming/python/examples/toy.txt create mode 100644 streaming/python/function.py delete mode 100644 streaming/python/jobworker.py create mode 100644 streaming/python/message.py create mode 100644 streaming/python/partition.py delete mode 100644 streaming/python/processor.py create mode 100644 streaming/python/runtime/gateway_client.py create mode 100644 streaming/python/runtime/graph.py create mode 100644 streaming/python/runtime/processor.py create mode 100644 streaming/python/runtime/task.py create mode 100644 streaming/python/runtime/worker.py delete mode 100644 streaming/python/streaming.py create mode 100644 streaming/python/tests/test_function.py delete mode 100644 streaming/python/tests/test_logical_graph.py create mode 100644 streaming/python/tests/test_operator.py create mode 100644 streaming/src/protobuf/remote_call.proto diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index e5991eabb..c2ff89d76 100755 --- a/ci/travis/install-dependencies.sh +++ 
b/ci/travis/install-dependencies.sh @@ -89,6 +89,11 @@ if [[ "$RLLIB_TESTING" == "1" ]]; then gym[atari] atari_py smart_open lz4 fi +# Additional streaming dependencies. +if [[ "$RAY_CI_STREAMING_PYTHON_AFFECTED" == "1" ]]; then + pip install -q "msgpack>=0.6.2" +fi + if [[ "$PYTHON" == "3.6" ]] || [[ "$MAC_WHEELS" == "1" ]]; then # Install the latest version of Node.js in order to build the dashboard. source "$HOME/.nvm/nvm.sh" diff --git a/java/BUILD.bazel b/java/BUILD.bazel index cddf543ef..9baaa3f27 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -231,6 +231,7 @@ genrule( srcs = [ "//java:ray_dist_deploy.jar", "//java:gen_maven_deps", + "//streaming/java:gen_maven_deps", ], outs = ["ray_java_pkg.out"], cmd = """ diff --git a/python/setup.py b/python/setup.py index 445dd4242..5a014fa88 100644 --- a/python/setup.py +++ b/python/setup.py @@ -86,6 +86,8 @@ extras["rllib"] = extras["tune"] + [ "scipy", ] +extras["streaming"] = ["msgpack >= 0.6.2"] + extras["all"] = list(set(chain.from_iterable(extras.values()))) diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel index 875a2f913..c458389a3 100644 --- a/streaming/BUILD.bazel +++ b/streaming/BUILD.bazel @@ -10,16 +10,23 @@ proto_library( visibility = ["//visibility:public"], ) -cc_proto_library( - name = "streaming_cc_proto", - deps = [":streaming_proto"], -) - proto_library( name = "streaming_queue_proto", srcs = ["src/protobuf/streaming_queue.proto"], ) +proto_library( + name = "remote_call_proto", + srcs = ["src/protobuf/remote_call.proto"], + visibility = ["//visibility:public"], + deps = ["streaming_proto"], +) + +cc_proto_library( + name = "streaming_cc_proto", + deps = [":streaming_proto"], +) + cc_proto_library( name = "streaming_queue_cc_proto", deps = ["streaming_queue_proto"], @@ -231,10 +238,23 @@ python_proto_compile( deps = ["//streaming:streaming_proto"], ) +python_proto_compile( + name = "remote_call_py_proto", + deps = ["//streaming:remote_call_proto"], +) + +filegroup( + name = 
"all_py_proto", + srcs = [ + ":remote_call_py_proto", + ":streaming_py_proto", + ], +) + genrule( name = "copy_streaming_py_proto", srcs = [ - ":streaming_py_proto", + ":all_py_proto", ], outs = [ "copy_streaming_py_proto.out", @@ -248,9 +268,10 @@ genrule( rm -rf "$$GENERATED_DIR" mkdir -p "$$GENERATED_DIR" touch "$$GENERATED_DIR/__init__.py" - for f in $(locations //streaming:streaming_py_proto); do - cp "$$f" "$$GENERATED_DIR" + for f in $(locations //streaming:all_py_proto); do + cp -f "$$f" "$$GENERATED_DIR" done + sed -i -E 's/from streaming.src.protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" date > $@ """, local = 1, diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel index c740546e7..53331edcd 100644 --- a/streaming/java/BUILD.bazel +++ b/streaming/java/BUILD.bazel @@ -102,6 +102,7 @@ define_java_module( ":org_ray_ray_streaming-runtime", "@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:de_ruedigermoeller_fst", + "@ray_streaming_maven//:org_msgpack_msgpack_core", "@ray_streaming_maven//:org_aeonbits_owner_owner", "@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", @@ -117,6 +118,7 @@ define_java_module( "@ray_streaming_maven//:com_google_protobuf_protobuf_java", "@ray_streaming_maven//:de_ruedigermoeller_fst", "@ray_streaming_maven//:org_aeonbits_owner_owner", + "@ray_streaming_maven//:org_msgpack_msgpack_core", "@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", ], @@ -143,9 +145,15 @@ java_proto_compile( deps = ["//streaming:streaming_proto"], ) +java_proto_compile( + name = "remote_call_java_proto", + deps = ["//streaming:remote_call_proto"], +) + filegroup( name = "all_java_proto", srcs = [ + ":remote_call_java_proto", ":streaming_java_proto", ], ) @@ -183,7 +191,7 @@ genrule( mkdir -p "$$GENERATED_DIR" # Copy protobuf-generated files. 
for f in $(locations //streaming/java:all_java_proto); do - unzip "$$f" -x META-INF/MANIFEST.MF -d "$$WORK_DIR/streaming/java/streaming-runtime/src/main/java" + unzip -o "$$f" -x META-INF/MANIFEST.MF -d "$$WORK_DIR/streaming/java/streaming-runtime/src/main/java" done date > $@ """, @@ -214,4 +222,5 @@ genrule( """, local = 1, tags = ["no-cache"], + visibility = ["//visibility:public"], ) diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl index b3b327a27..cfcdab4c4 100644 --- a/streaming/java/dependencies.bzl +++ b/streaming/java/dependencies.bzl @@ -14,6 +14,7 @@ def gen_streaming_java_deps(): "org.slf4j:slf4j-log4j12:1.7.25", "org.apache.logging.log4j:log4j-core:2.8.2", "org.testng:testng:6.9.10", + "org.msgpack:msgpack-core:0.8.20", ], repositories = [ "https://repo1.maven.org/maven2/", diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/Language.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/Language.java new file mode 100644 index 000000000..80f254f43 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/Language.java @@ -0,0 +1,6 @@ +package org.ray.streaming.api; + +public enum Language { + JAVA, + PYTHON +} diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java index 20aa57d60..cee30fddb 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java @@ -70,6 +70,10 @@ public class StreamingContext implements Serializable { streamSinks.add(streamSink); } + public List getStreamSinks() { + return streamSinks; + } + public void withConfig(Map jobConfig) { this.jobConfig = jobConfig; } diff --git 
a/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobGraphBuilder.java index c190a4317..aac840dd4 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobGraphBuilder.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobGraphBuilder.java @@ -27,7 +27,7 @@ public class JobGraphBuilder { } public JobGraphBuilder(List streamSinkList, String jobName, - Map jobConfig) { + Map jobConfig) { this.jobGraph = new JobGraph(jobName, jobConfig); this.streamSinkList = streamSinkList; this.edgeIdGenerator = new AtomicInteger(0); @@ -63,6 +63,8 @@ public class JobGraphBuilder { JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition()); this.jobGraph.addEdge(jobEdge); processStream(parentStream); + } else { + throw new UnsupportedOperationException("Unsupported stream: " + stream); } this.jobGraph.addVertex(jobVertex); } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobVertex.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobVertex.java index 6fbb6aaec..3fd04608d 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobVertex.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/JobVertex.java @@ -2,6 +2,7 @@ package org.ray.streaming.jobgraph; import com.google.common.base.MoreObjects; import java.io.Serializable; +import org.ray.streaming.api.Language; import org.ray.streaming.operator.StreamOperator; /** @@ -12,6 +13,7 @@ public class JobVertex implements Serializable { private int vertexId; private int parallelism; private VertexType vertexType; + private Language language; private StreamOperator streamOperator; public JobVertex(int vertexId, int parallelism, VertexType vertexType, @@ -20,6 +22,7 @@ public class JobVertex implements Serializable { 
this.parallelism = parallelism; this.vertexType = vertexType; this.streamOperator = streamOperator; + this.language = streamOperator.getLanguage(); } public int getVertexId() { @@ -38,12 +41,17 @@ public class JobVertex implements Serializable { return vertexType; } + public Language getLanguage() { + return language; + } + @Override public String toString() { return MoreObjects.toStringHelper(this) .add("vertexId", vertexId) .add("parallelism", parallelism) .add("vertexType", vertexType) + .add("language", language) .add("streamOperator", streamOperator) .toString(); } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/VertexType.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/VertexType.java index 664b835af..396fe498a 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/VertexType.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/jobgraph/VertexType.java @@ -4,7 +4,6 @@ package org.ray.streaming.jobgraph; * Different roles for a node. 
*/ public enum VertexType { - MASTER, SOURCE, TRANSFORMATION, SINK, diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java index 39542def8..51f365ddd 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java @@ -2,8 +2,10 @@ package org.ray.streaming.operator; import java.io.Serializable; import java.util.List; +import org.ray.streaming.api.Language; import org.ray.streaming.api.collector.Collector; import org.ray.streaming.api.context.RuntimeContext; +import org.ray.streaming.api.function.Function; public interface Operator extends Serializable { @@ -13,5 +15,9 @@ public interface Operator extends Serializable { void close(); + Function getFunction(); + + Language getLanguage(); + OperatorType getOpType(); } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java index 840372ad1..be3916846 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java @@ -1,6 +1,5 @@ package org.ray.streaming.operator; - public enum OperatorType { SOURCE, ONE_INPUT, diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java index bdab36a28..dcf89b5ca 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java @@ -1,6 +1,7 @@ package org.ray.streaming.operator; import java.util.List; +import 
org.ray.streaming.api.Language; import org.ray.streaming.api.collector.Collector; import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.api.function.Function; @@ -8,7 +9,6 @@ import org.ray.streaming.message.KeyRecord; import org.ray.streaming.message.Record; public abstract class StreamOperator implements Operator { - protected String name; protected F function; protected List collectorList; @@ -35,6 +35,16 @@ public abstract class StreamOperator implements Operator { } + @Override + public Function getFunction() { + return function; + } + + @Override + public Language getLanguage() { + return Language.JAVA; + } + protected void collect(Record record) { for (Collector collector : this.collectorList) { collector.collect(record); diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonFunction.java index 9751d7176..f446e9738 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonFunction.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonFunction.java @@ -20,16 +20,19 @@ import org.ray.streaming.api.function.Function; */ public class PythonFunction implements Function { public enum FunctionInterface { - SOURCE_FUNCTION("ray.streaming.function.SourceFunction"), - MAP_FUNCTION("ray.streaming.function.MapFunction"), - FLAT_MAP_FUNCTION("ray.streaming.function.FlatMapFunction"), - FILTER_FUNCTION("ray.streaming.function.FilterFunction"), - KEY_FUNCTION("ray.streaming.function.KeyFunction"), - REDUCE_FUNCTION("ray.streaming.function.ReduceFunction"), - SINK_FUNCTION("ray.streaming.function.SinkFunction"); + SOURCE_FUNCTION("SourceFunction"), + MAP_FUNCTION("MapFunction"), + FLAT_MAP_FUNCTION("FlatMapFunction"), + FILTER_FUNCTION("FilterFunction"), + KEY_FUNCTION("KeyFunction"), + REDUCE_FUNCTION("ReduceFunction"), + SINK_FUNCTION("SinkFunction"); private String 
functionInterface; + /** + * @param functionInterface function class name in `ray.streaming.function` module. + */ FunctionInterface(String functionInterface) { this.functionInterface = functionInterface; } @@ -59,6 +62,26 @@ public class PythonFunction implements Function { this.functionInterface = functionInterface.functionInterface; } + public byte[] getFunction() { + return function; + } + + public String getModuleName() { + return moduleName; + } + + public String getClassName() { + return className; + } + + public String getFunctionName() { + return functionName; + } + + public String getFunctionInterface() { + return functionInterface; + } + /** * Create a {@link PythonFunction} using python serialized function * diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonOperator.java index 179d6ef2e..1a0a127af 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonOperator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonOperator.java @@ -1,6 +1,7 @@ package org.ray.streaming.python; import java.util.List; +import org.ray.streaming.api.Language; import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.operator.OperatorType; import org.ray.streaming.operator.StreamOperator; @@ -39,5 +40,8 @@ public class PythonOperator extends StreamOperator { throw new UnsupportedOperationException(msg); } - + @Override + public Language getLanguage() { + return Language.PYTHON; + } } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonPartition.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonPartition.java index 8067855da..58275f4b6 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonPartition.java +++ 
b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/PythonPartition.java @@ -45,4 +45,19 @@ public class PythonPartition implements Partition { throw new UnsupportedOperationException(msg); } + public byte[] getPartition() { + return partition; + } + + public String getModuleName() { + return moduleName; + } + + public String getClassName() { + return className; + } + + public String getFunctionName() { + return functionName; + } } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonDataStream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonDataStream.java index e45afdf5e..f200ab69f 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonDataStream.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonDataStream.java @@ -17,6 +17,10 @@ public class PythonDataStream extends Stream implements PythonStream { super(streamingContext, pythonOperator); } + public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) { + super(input, pythonOperator); + } + protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) { super(inputStream, pythonOperator); } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonKeyDataStream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonKeyDataStream.java index 2042e478c..49bb9b11f 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonKeyDataStream.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonKeyDataStream.java @@ -1,7 +1,5 @@ package org.ray.streaming.python.stream; -import org.ray.streaming.api.stream.Stream; -import org.ray.streaming.operator.StreamOperator; import org.ray.streaming.python.PythonFunction; import 
org.ray.streaming.python.PythonFunction.FunctionInterface; import org.ray.streaming.python.PythonOperator; @@ -10,10 +8,10 @@ import org.ray.streaming.python.PythonPartition; /** * Represents a python DataStream returned by a key-by operation. */ -public class PythonKeyDataStream extends Stream implements PythonStream { +public class PythonKeyDataStream extends PythonDataStream implements PythonStream { - public PythonKeyDataStream(PythonDataStream input, StreamOperator streamOperator) { - super(input, streamOperator); + public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) { + super(input, pythonOperator); this.partition = PythonPartition.KeyPartition; } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonStreamSink.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonStreamSink.java index ef691cf05..e23b96a58 100644 --- a/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonStreamSink.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/python/stream/PythonStreamSink.java @@ -8,7 +8,7 @@ import org.ray.streaming.python.PythonOperator; */ public class PythonStreamSink extends StreamSink implements PythonStream { public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) { - super(input, null); + super(input, sinkOperator); this.streamingContext.addSink(this); } diff --git a/streaming/java/streaming-runtime/pom.xml b/streaming/java/streaming-runtime/pom.xml index ff0841f1c..827cb8b9a 100755 --- a/streaming/java/streaming-runtime/pom.xml +++ b/streaming/java/streaming-runtime/pom.xml @@ -56,6 +56,11 @@ owner 1.0.10 + + org.msgpack + msgpack-core + 0.8.20 + org.slf4j slf4j-api diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java deleted 
file mode 100644 index 73fe0b621..000000000 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java +++ /dev/null @@ -1,23 +0,0 @@ -package org.ray.streaming.runtime.cluster; - -import java.util.ArrayList; -import java.util.List; -import org.ray.api.Ray; -import org.ray.api.RayActor; -import org.ray.streaming.runtime.worker.JobWorker; - -/** - * Resource-Manager is used to do the management of resources - */ -public class ResourceManager { - - public List> createWorkers(int workerNum) { - List> workers = new ArrayList<>(); - for (int i = 0; i < workerNum; i++) { - RayActor worker = Ray.createActor(JobWorker::new); - workers.add(worker); - } - return workers; - } - -} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java index 6dbd204d9..104ed2853 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java @@ -7,7 +7,6 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; import org.ray.api.RayActor; -import org.ray.streaming.runtime.worker.JobWorker; /** * Physical execution graph. 
@@ -19,19 +18,19 @@ import org.ray.streaming.runtime.worker.JobWorker; public class ExecutionGraph implements Serializable { private long buildTime; private List executionNodeList; - private List> sourceWorkers = new ArrayList<>(); - private List> sinkWorkers = new ArrayList<>(); + private List sourceWorkers = new ArrayList<>(); + private List sinkWorkers = new ArrayList<>(); public ExecutionGraph(List executionNodes) { this.executionNodeList = executionNodes; for (ExecutionNode executionNode : executionNodeList) { if (executionNode.getNodeType() == ExecutionNode.NodeType.SOURCE) { - List> actors = executionNode.getExecutionTasks().stream() + List actors = executionNode.getExecutionTasks().stream() .map(ExecutionTask::getWorker).collect(Collectors.toList()); sourceWorkers.addAll(actors); } if (executionNode.getNodeType() == ExecutionNode.NodeType.SINK) { - List> actors = executionNode.getExecutionTasks().stream() + List actors = executionNode.getExecutionTasks().stream() .map(ExecutionTask::getWorker).collect(Collectors.toList()); sinkWorkers.addAll(actors); } @@ -39,11 +38,11 @@ public class ExecutionGraph implements Serializable { buildTime = System.currentTimeMillis(); } - public List> getSourceWorkers() { + public List getSourceWorkers() { return sourceWorkers; } - public List> getSinkWorkers() { + public List getSinkWorkers() { return sinkWorkers; } @@ -82,10 +81,10 @@ public class ExecutionGraph implements Serializable { throw new RuntimeException("Task " + taskId + " does not exist!"); } - public Map> getTaskId2WorkerByNodeId(int nodeId) { + public Map getTaskId2WorkerByNodeId(int nodeId) { for (ExecutionNode executionNode : executionNodeList) { if (executionNode.getNodeId() == nodeId) { - Map> taskId2Worker = new HashMap<>(); + Map taskId2Worker = new HashMap<>(); for (ExecutionTask executionTask : executionNode.getExecutionTasks()) { taskId2Worker.put(executionTask.getTaskId(), executionTask.getWorker()); } diff --git 
a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java index 7c550d885..83595cc26 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java @@ -3,6 +3,7 @@ package org.ray.streaming.runtime.core.graph; import java.io.Serializable; import java.util.ArrayList; import java.util.List; +import org.ray.streaming.api.Language; import org.ray.streaming.jobgraph.VertexType; import org.ray.streaming.operator.StreamOperator; @@ -10,7 +11,6 @@ import org.ray.streaming.operator.StreamOperator; * A node in the physical execution graph. */ public class ExecutionNode implements Serializable { - private int nodeId; private int parallelism; private NodeType nodeType; @@ -59,7 +59,7 @@ public class ExecutionNode implements Serializable { this.outputEdges = outputEdges; } - public void addExecutionEdge(ExecutionEdge executionEdge) { + public void addOutputEdge(ExecutionEdge executionEdge) { this.outputEdges.add(executionEdge); } @@ -79,6 +79,10 @@ public class ExecutionNode implements Serializable { this.streamOperator = streamOperator; } + public Language getLanguage() { + return streamOperator.getLanguage(); + } + public NodeType getNodeType() { return nodeType; } @@ -92,7 +96,7 @@ public class ExecutionNode implements Serializable { this.nodeType = NodeType.SINK; break; default: - this.nodeType = NodeType.PROCESS; + this.nodeType = NodeType.TRANSFORM; } } @@ -109,7 +113,7 @@ public class ExecutionNode implements Serializable { public enum NodeType { SOURCE, - PROCESS, + TRANSFORM, SINK, } } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java 
b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java index afc831841..6337bee51 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java @@ -2,7 +2,6 @@ package org.ray.streaming.runtime.core.graph; import java.io.Serializable; import org.ray.api.RayActor; -import org.ray.streaming.runtime.worker.JobWorker; /** * ExecutionTask is minimal execution unit. @@ -12,9 +11,9 @@ import org.ray.streaming.runtime.worker.JobWorker; public class ExecutionTask implements Serializable { private int taskId; private int taskIndex; - private RayActor worker; + private RayActor worker; - public ExecutionTask(int taskId, int taskIndex, RayActor worker) { + public ExecutionTask(int taskId, int taskIndex, RayActor worker) { this.taskId = taskId; this.taskIndex = taskIndex; this.worker = worker; @@ -36,11 +35,11 @@ public class ExecutionTask implements Serializable { this.taskIndex = taskIndex; } - public RayActor getWorker() { + public RayActor getWorker() { return worker; } - public void setWorker(RayActor worker) { + public void setWorker(RayActor worker) { this.worker = worker; } } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/GraphPbBuilder.java new file mode 100644 index 000000000..c2d7393eb --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/GraphPbBuilder.java @@ -0,0 +1,101 @@ +package org.ray.streaming.runtime.python; + +import com.google.protobuf.ByteString; +import java.util.Arrays; +import org.ray.runtime.actor.NativeRayActor; +import org.ray.streaming.api.function.Function; +import org.ray.streaming.api.partition.Partition; +import 
org.ray.streaming.python.PythonFunction; +import org.ray.streaming.python.PythonPartition; +import org.ray.streaming.runtime.core.graph.ExecutionEdge; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionNode; +import org.ray.streaming.runtime.core.graph.ExecutionTask; +import org.ray.streaming.runtime.generated.RemoteCall; +import org.ray.streaming.runtime.generated.Streaming; + +public class GraphPbBuilder { + + private MsgPackSerializer serializer = new MsgPackSerializer(); + + /** + * For simple scenario, a single ExecutionNode is enough. But some cases may need + * sub-graph information, so we serialize entire graph. + */ + public RemoteCall.ExecutionGraph buildExecutionGraphPb(ExecutionGraph graph) { + RemoteCall.ExecutionGraph.Builder builder = RemoteCall.ExecutionGraph.newBuilder(); + builder.setBuildTime(graph.getBuildTime()); + for (ExecutionNode node : graph.getExecutionNodeList()) { + RemoteCall.ExecutionGraph.ExecutionNode.Builder nodeBuilder = + RemoteCall.ExecutionGraph.ExecutionNode.newBuilder(); + nodeBuilder.setNodeId(node.getNodeId()); + nodeBuilder.setParallelism(node.getParallelism()); + nodeBuilder.setNodeType( + Streaming.NodeType.valueOf(node.getNodeType().name())); + nodeBuilder.setLanguage(Streaming.Language.valueOf(node.getLanguage().name())); + byte[] functionBytes = serializeFunction(node.getStreamOperator().getFunction()); + nodeBuilder.setFunction(ByteString.copyFrom(functionBytes)); + + // build tasks + for (ExecutionTask task : node.getExecutionTasks()) { + RemoteCall.ExecutionGraph.ExecutionTask.Builder taskBuilder = + RemoteCall.ExecutionGraph.ExecutionTask.newBuilder(); + byte[] serializedActorHandle = ((NativeRayActor) task.getWorker()).toBytes(); + taskBuilder + .setTaskId(task.getTaskId()) + .setTaskIndex(task.getTaskIndex()) + .setWorkerActor(ByteString.copyFrom(serializedActorHandle)); + nodeBuilder.addExecutionTasks(taskBuilder.build()); + } + + // build edges + 
for (ExecutionEdge edge : node.getInputsEdges()) { + nodeBuilder.addInputEdges(buildEdgePb(edge)); + } + for (ExecutionEdge edge : node.getOutputEdges()) { + nodeBuilder.addOutputEdges(buildEdgePb(edge)); + } + + builder.addExecutionNodes(nodeBuilder.build()); + } + + return builder.build(); + } + + private RemoteCall.ExecutionGraph.ExecutionEdge buildEdgePb(ExecutionEdge edge) { + RemoteCall.ExecutionGraph.ExecutionEdge.Builder edgeBuilder = + RemoteCall.ExecutionGraph.ExecutionEdge.newBuilder(); + edgeBuilder.setSrcNodeId(edge.getSrcNodeId()); + edgeBuilder.setTargetNodeId(edge.getTargetNodeId()); + edgeBuilder.setPartition(ByteString.copyFrom(serializePartition(edge.getPartition()))); + return edgeBuilder.build(); + } + + private byte[] serializeFunction(Function function) { + if (function instanceof PythonFunction) { + PythonFunction pyFunc = (PythonFunction) function; + // function_bytes, module_name, class_name, function_name, function_interface + return serializer.serialize(Arrays.asList( + pyFunc.getFunction(), pyFunc.getModuleName(), + pyFunc.getClassName(), pyFunc.getFunctionName(), + pyFunc.getFunctionInterface() + )); + } else { + return new byte[0]; + } + } + + private byte[] serializePartition(Partition partition) { + if (partition instanceof PythonPartition) { + PythonPartition pythonPartition = (PythonPartition) partition; + // partition_bytes, module_name, class_name, function_name + return serializer.serialize(Arrays.asList( + pythonPartition.getPartition(), pythonPartition.getModuleName(), + pythonPartition.getClassName(), pythonPartition.getFunctionName() + )); + } else { + return new byte[0]; + } + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/MsgPackSerializer.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/MsgPackSerializer.java new file mode 100644 index 000000000..1e3b48cb7 --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/MsgPackSerializer.java @@ -0,0 +1,119 @@ +package org.ray.streaming.runtime.python; + +import com.google.common.io.BaseEncoding; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.msgpack.core.MessageUnpacker; +import org.msgpack.value.ArrayValue; +import org.msgpack.value.FloatValue; +import org.msgpack.value.IntegerValue; +import org.msgpack.value.MapValue; +import org.msgpack.value.Value; + +public class MsgPackSerializer { + + public byte[] serialize(Object obj) { + MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + serialize(obj, packer); + return packer.toByteArray(); + } + + private void serialize(Object obj, MessageBufferPacker packer) { + try { + if (obj == null) { + packer.packNil(); + } else { + Class clz = obj.getClass(); + if (clz == Boolean.class) { + packer.packBoolean((Boolean) obj); + } else if (clz == Integer.class) { + packer.packInt((Integer) obj); + } else if (clz == Long.class) { + packer.packLong((Long) obj); + } else if (clz == Double.class) { + packer.packDouble((Double) obj); + } else if (clz == byte[].class) { + byte[] bytes = (byte[]) obj; + packer.packBinaryHeader(bytes.length); + packer.writePayload(bytes); + } else if (clz == String.class) { + packer.packString((String) obj); + } else if (obj instanceof Collection) { + Collection collection = (Collection) (obj); + packer.packArrayHeader(collection.size()); + for (Object o : collection) { + serialize(o, packer); + } + } else if (obj instanceof Map) { + Map map = (Map) (obj); + packer.packMapHeader(map.size()); + for (Object o : map.entrySet()) { + Map.Entry e = (Map.Entry) o; + serialize(e.getKey(), packer); + serialize(e.getValue(), packer); + } + } else { + throw new UnsupportedOperationException("Unsupported 
type " + clz); + } + } + } catch (Exception e) { + throw new RuntimeException("Serialize error for object " + obj, e); + } + } + + public Object deserialize(byte[] bytes) { + try { + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bytes); + return deserialize(unpacker.unpackValue()); + } catch (Exception e) { + String hex = BaseEncoding.base16().lowerCase().encode(bytes); + throw new RuntimeException("Deserialize error: " + hex, e); + } + } + + private Object deserialize(Value value) { + switch (value.getValueType()) { + case NIL: + return null; + case BOOLEAN: + return value.asBooleanValue().getBoolean(); + case INTEGER: + IntegerValue iv = value.asIntegerValue(); + if (iv.isInIntRange()) { + return iv.toInt(); + } else if (iv.isInLongRange()) { + return iv.toLong(); + } else { + return iv.toBigInteger(); + } + case FLOAT: + FloatValue fv = value.asFloatValue(); + return fv.toDouble(); + case STRING: + return value.asStringValue().asString(); + case BINARY: + return value.asBinaryValue().asByteArray(); + case ARRAY: + ArrayValue arrayValue = value.asArrayValue(); + List list = new ArrayList<>(arrayValue.size()); + for (Value elem : arrayValue) { + list.add(deserialize(elem)); + } + return list; + case MAP: + MapValue mapValue = value.asMapValue(); + Map map = new HashMap<>(); + for (Map.Entry entry : mapValue.entrySet()) { + map.put(deserialize(entry.getKey()), deserialize(entry.getValue())); + } + return map; + default: + throw new UnsupportedOperationException("Unsupported type " + value.getValueType()); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/PythonGateway.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/PythonGateway.java new file mode 100644 index 000000000..42872a93c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/python/PythonGateway.java @@ -0,0 +1,152 @@ +package org.ray.streaming.runtime.python; + 
+import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.msgpack.core.Preconditions; +import org.ray.api.annotation.RayRemote; +import org.ray.streaming.api.context.StreamingContext; +import org.ray.streaming.python.PythonFunction; +import org.ray.streaming.python.PythonPartition; +import org.ray.streaming.python.stream.PythonStreamSource; +import org.ray.streaming.runtime.util.ReflectionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Gateway for streaming python api. + * All calls on DataStream in python will be mapped to DataStream call in java by this + * PythonGateway using ray calls. + *

+ * Note: this class needs to be in sync with `GatewayClient` in + * `streaming/python/runtime/gateway_client.py` + */ +@SuppressWarnings("unchecked") +@RayRemote +public class PythonGateway { + private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class); + private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__"; + + private MsgPackSerializer serializer; + private Map referenceMap; + private StreamingContext streamingContext; + + public PythonGateway() { + serializer = new MsgPackSerializer(); + referenceMap = new HashMap<>(); + LOG.info("PythonGateway created"); + } + + public byte[] createStreamingContext() { + streamingContext = StreamingContext.buildContext(); + LOG.info("StreamingContext created"); + referenceMap.put(getReferenceId(streamingContext), streamingContext); + return serializer.serialize(getReferenceId(streamingContext)); + } + + public StreamingContext getStreamingContext() { + return streamingContext; + } + + public byte[] withConfig(byte[] confBytes) { + Preconditions.checkNotNull(streamingContext); + try { + Map config = (Map) serializer.deserialize(confBytes); + LOG.info("Set config {}", config); + streamingContext.withConfig(config); + // We can't use `return void`, that will make `ray.get()` hang forever. + // We can't using `return new byte[0]`, that will make `ray::CoreWorker::ExecuteTask` crash. + // So we `return new byte[1]` for method execution success. + // Same for other methods in this class which return new byte[1]. 
+ return new byte[1]; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public byte[] createPythonStreamSource(byte[] pySourceFunc) { + Preconditions.checkNotNull(streamingContext); + try { + PythonStreamSource pythonStreamSource = PythonStreamSource.from( + streamingContext, PythonFunction.fromFunction(pySourceFunc)); + referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource); + return serializer.serialize(getReferenceId(pythonStreamSource)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public byte[] execute(byte[] jobNameBytes) { + LOG.info("Starting executing"); + streamingContext.execute((String) serializer.deserialize(jobNameBytes)); + // see `withConfig` method. + return new byte[1]; + } + + public byte[] createPyFunc(byte[] pyFunc) { + PythonFunction function = PythonFunction.fromFunction(pyFunc); + referenceMap.put(getReferenceId(function), function); + return serializer.serialize(getReferenceId(function)); + } + + public byte[] createPyPartition(byte[] pyPartition) { + PythonPartition partition = new PythonPartition(pyPartition); + referenceMap.put(getReferenceId(partition), partition); + return serializer.serialize(getReferenceId(partition)); + } + + public byte[] callFunction(byte[] paramsBytes) { + try { + List params = (List) serializer.deserialize(paramsBytes); + params = processReferenceParameters(params); + LOG.info("callFunction params {}", params); + String className = (String) params.get(0); + String funcName = (String) params.get(1); + Class clz = Class.forName(className, true, this.getClass().getClassLoader()); + Method method = ReflectionUtils.findMethod(clz, funcName); + Object result = method.invoke(null, params.subList(2, params.size()).toArray()); + referenceMap.put(getReferenceId(result), result); + return serializer.serialize(getReferenceId(result)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public byte[] callMethod(byte[] paramsBytes) { + try { + 
List params = (List) serializer.deserialize(paramsBytes); + params = processReferenceParameters(params); + LOG.info("callMethod params {}", params); + Object obj = params.get(0); + String methodName = (String) params.get(1); + Method method = ReflectionUtils.findMethod(obj.getClass(), methodName); + Object result = method.invoke(obj, params.subList(2, params.size()).toArray()); + referenceMap.put(getReferenceId(result), result); + return serializer.serialize(getReferenceId(result)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private List processReferenceParameters(List params) { + return params.stream().map(this::processReferenceParameter) + .collect(Collectors.toList()); + } + + private Object processReferenceParameter(Object o) { + if (o instanceof String) { + Object value = referenceMap.get(o); + if (value != null) { + return value; + } + } + return o; + } + + private String getReferenceId(Object o) { + return REFERENCE_ID_PREFIX + System.identityHashCode(o); + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java index a6fb38bb9..45191befe 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java @@ -6,12 +6,14 @@ import java.util.Map; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; +import org.ray.api.RayPyActor; +import org.ray.streaming.api.Language; import org.ray.streaming.jobgraph.JobGraph; -import org.ray.streaming.jobgraph.JobVertex; -import org.ray.streaming.runtime.cluster.ResourceManager; import org.ray.streaming.runtime.core.graph.ExecutionGraph; import org.ray.streaming.runtime.core.graph.ExecutionNode; import 
org.ray.streaming.runtime.core.graph.ExecutionTask; +import org.ray.streaming.runtime.generated.RemoteCall; +import org.ray.streaming.runtime.python.GraphPbBuilder; import org.ray.streaming.runtime.worker.JobWorker; import org.ray.streaming.runtime.worker.context.WorkerContext; import org.ray.streaming.schedule.JobScheduler; @@ -23,43 +25,70 @@ import org.ray.streaming.schedule.JobScheduler; public class JobSchedulerImpl implements JobScheduler { private JobGraph jobGraph; private Map jobConfig; - private ResourceManager resourceManager; private TaskAssigner taskAssigner; public JobSchedulerImpl() { - this.resourceManager = new ResourceManager(); this.taskAssigner = new TaskAssignerImpl(); } /** * Schedule physical plan to execution graph, and call streaming worker to init and run. */ + @SuppressWarnings("unchecked") @Override public void schedule(JobGraph jobGraph, Map jobConfig) { this.jobConfig = jobConfig; this.jobGraph = jobGraph; - System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); - Ray.init(); - - List> workers = this.resourceManager.createWorkers(getPlanWorker()); - ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph, workers); + if (Ray.internal() == null) { + System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); + Ray.init(); + } + ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph); List executionNodes = executionGraph.getExecutionNodeList(); - List> waits = new ArrayList<>(); + boolean hasPythonNode = executionNodes.stream() + .allMatch(node -> node.getLanguage() == Language.PYTHON); + RemoteCall.ExecutionGraph executionGraphPb = null; + if (hasPythonNode) { + executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph); + } + List> waits = new ArrayList<>(); for (ExecutionNode executionNode : executionNodes) { List executionTasks = executionNode.getExecutionTasks(); for (ExecutionTask executionTask : executionTasks) { int taskId = 
executionTask.getTaskId(); - RayActor streamWorker = executionTask.getWorker(); - waits.add(Ray.call(JobWorker::init, streamWorker, - new WorkerContext(taskId, executionGraph, jobConfig))); + RayActor worker = executionTask.getWorker(); + switch (executionNode.getLanguage()) { + case JAVA: + RayActor jobWorker = (RayActor) worker; + waits.add(Ray.call(JobWorker::init, jobWorker, + new WorkerContext(taskId, executionGraph, jobConfig))); + break; + case PYTHON: + byte[] workerContextBytes = buildPythonWorkerContext( + taskId, executionGraphPb, jobConfig); + waits.add(Ray.callPy((RayPyActor) worker, + "init", workerContextBytes)); + break; + default: + throw new UnsupportedOperationException( + "Unsupported language " + executionNode.getLanguage()); + } } } Ray.wait(waits); } - private int getPlanWorker() { - List jobVertexList = jobGraph.getJobVertexList(); - return jobVertexList.stream().map(JobVertex::getParallelism).reduce(0, Integer::sum); + private byte[] buildPythonWorkerContext( + int taskId, + RemoteCall.ExecutionGraph executionGraphPb, + Map jobConfig) { + return RemoteCall.WorkerContext.newBuilder() + .setTaskId(taskId) + .putAllConf(jobConfig) + .setGraph(executionGraphPb) + .build() + .toByteArray(); } + } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssigner.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssigner.java index 9927b6ae6..82f7d5f6d 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssigner.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssigner.java @@ -1,11 +1,8 @@ package org.ray.streaming.runtime.schedule; import java.io.Serializable; -import java.util.List; -import org.ray.api.RayActor; import org.ray.streaming.jobgraph.JobGraph; import org.ray.streaming.runtime.core.graph.ExecutionGraph; -import org.ray.streaming.runtime.worker.JobWorker; 
/** * Interface of the task assigning strategy. @@ -15,6 +12,6 @@ public interface TaskAssigner extends Serializable { /** * Assign logical plan to physical execution graph. */ - ExecutionGraph assign(JobGraph jobGraph, List> workers); + ExecutionGraph assign(JobGraph jobGraph); } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignerImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignerImpl.java index 4e0d2e31d..1f2027948 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignerImpl.java @@ -4,7 +4,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; +import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.streaming.jobgraph.JobEdge; import org.ray.streaming.jobgraph.JobGraph; @@ -20,12 +20,11 @@ public class TaskAssignerImpl implements TaskAssigner { /** * Assign an optimized logical plan to execution graph. * - * @param jobGraph The logical plan. - * @param workers The worker actors. + * @param jobGraph The logical plan. * @return The physical execution graph. 
*/ @Override - public ExecutionGraph assign(JobGraph jobGraph, List> workers) { + public ExecutionGraph assign(JobGraph jobGraph) { List jobVertices = jobGraph.getJobVertexList(); List jobEdges = jobGraph.getJobEdgeList(); @@ -37,7 +36,7 @@ public class TaskAssignerImpl implements TaskAssigner { executionNode.setNodeType(jobVertex.getVertexType()); List vertexTasks = new ArrayList<>(); for (int taskIndex = 0; taskIndex < jobVertex.getParallelism(); taskIndex++) { - vertexTasks.add(new ExecutionTask(taskId, taskIndex, workers.get(taskId))); + vertexTasks.add(new ExecutionTask(taskId, taskIndex, createWorker(jobVertex))); taskId++; } executionNode.setExecutionTasks(vertexTasks); @@ -51,12 +50,25 @@ public class TaskAssignerImpl implements TaskAssigner { ExecutionEdge executionEdge = new ExecutionEdge(srcNodeId, targetNodeId, jobEdge.getPartition()); - idToExecutionNode.get(srcNodeId).addExecutionEdge(executionEdge); + idToExecutionNode.get(srcNodeId).addOutputEdge(executionEdge); idToExecutionNode.get(targetNodeId).addInputEdge(executionEdge); } - List executionNodes = idToExecutionNode.values().stream() - .collect(Collectors.toList()); + List executionNodes = new ArrayList<>(idToExecutionNode.values()); return new ExecutionGraph(executionNodes); } + + private RayActor createWorker(JobVertex jobVertex) { + switch (jobVertex.getLanguage()) { + case PYTHON: + return Ray.createPyActor( + "ray.streaming.runtime.worker", "JobWorker"); + case JAVA: + return Ray.createActor(JobWorker::new); + default: + throw new UnsupportedOperationException( + "Unsupported language " + jobVertex.getLanguage()); + + } + } } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/ReflectionUtils.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/ReflectionUtils.java new file mode 100644 index 000000000..89430e653 --- /dev/null +++ 
/**
 * Reflection helpers for looking up methods by name across a class, its superclasses and all
 * interfaces they implement.
 */
public class ReflectionUtils {

  /**
   * Finds the single method named {@code methodName} visible on {@code cls}.
   *
   * @param cls class to search (including its superclasses and interfaces)
   * @param methodName name of the method
   * @return the only method with that name
   * @throws IllegalArgumentException if there is no such method or the name is overloaded
   */
  public static Method findMethod(Class<?> cls, String methodName) {
    List<Method> methods = findMethods(cls, methodName);
    if (methods.size() != 1) {
      throw new IllegalArgumentException(String.format(
          "Expected exactly one method named %s in %s but found %d: %s",
          methodName, cls, methods.size(), methods));
    }
    return methods.get(0);
  }

  /**
   * For covariant return type, return the most specific method.
   *
   * <p>Declarations are grouped by parameter types. Within a group the declaration found first
   * wins (the class itself is searched before its superclasses, and those before the
   * interfaces), unless a later declaration has a strictly more specific return type.
   *
   * @return all methods named by {@code methodName}, one per distinct parameter list
   */
  public static List<Method> findMethods(Class<?> cls, String methodName) {
    List<Class<?>> classes = new ArrayList<>();
    Class<?> clazz = cls;
    while (clazz != null) {
      classes.add(clazz);
      clazz = clazz.getSuperclass();
    }
    classes.addAll(getAllInterfaces(cls));
    // Interfaces have no superclass, but Object's methods can still be called on them.
    if (!classes.contains(Object.class)) {
      classes.add(Object.class);
    }

    LinkedHashMap<List<Class<?>>, Method> methods = new LinkedHashMap<>();
    for (Class<?> superClass : classes) {
      for (Method m : superClass.getDeclaredMethods()) {
        if (m.getName().equals(methodName)) {
          List<Class<?>> params = Arrays.asList(m.getParameterTypes());
          Method method = methods.get(params);
          if (method == null) {
            methods.put(params, m);
          } else {
            // For covariant return type, use the most specific method. The comparison must be
            // strict: with an identical return type the declaration found earlier belongs to
            // the more derived class and has to be kept, otherwise a superclass or interface
            // declaration would replace it (which selects the wrong method for hidden static
            // methods).
            Class<?> knownReturnType = method.getReturnType();
            Class<?> candidateReturnType = m.getReturnType();
            if (knownReturnType != candidateReturnType
                && knownReturnType.isAssignableFrom(candidateReturnType)) {
              methods.put(params, m);
            }
          }
        }
      }
    }
    return new ArrayList<>(methods.values());
  }

  /**
   * <p>Gets a <code>List</code> of all interfaces implemented by the given
   * class and its superclasses.</p>
   *
   * <p>The order is determined by looking through each interface in turn as
   * declared in the source file and following its hierarchy up.</p>
   *
   * @param cls the class to look up, may be {@code null}
   * @return the interfaces in discovery order, or {@code null} if {@code cls} is {@code null}
   */
  public static List<Class<?>> getAllInterfaces(Class<?> cls) {
    if (cls == null) {
      return null;
    }

    LinkedHashSet<Class<?>> interfacesFound = new LinkedHashSet<>();
    getAllInterfaces(cls, interfacesFound);
    return new ArrayList<>(interfacesFound);
  }

  /**
   * Collects the interfaces of {@code cls} and of all its superclasses into
   * {@code interfacesFound}, descending into super-interfaces depth first.
   */
  private static void getAllInterfaces(Class<?> cls, LinkedHashSet<Class<?>> interfacesFound) {
    while (cls != null) {
      for (Class<?> anInterface : cls.getInterfaces()) {
        // add() returns false for an interface that was already visited.
        if (interfacesFound.add(anInterface)) {
          getAllInterfaces(anInterface, interfacesFound);
        }
      }

      cls = cls.getSuperclass();
    }
  }

}
taskId2Worker = executionGraph .getTaskId2WorkerByNodeId(edge.getTargetNodeId()); taskId2Worker.forEach((targetTaskId, targetActor) -> { String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime()); @@ -91,7 +91,7 @@ public abstract class StreamTask implements Runnable { List inputEdges = executionNode.getInputsEdges(); Map inputActorIds = new HashMap<>(); for (ExecutionEdge edge : inputEdges) { - Map> taskId2Worker = executionGraph + Map taskId2Worker = executionGraph .getTaskId2WorkerByNodeId(edge.getSrcNodeId()); taskId2Worker.forEach((srcTaskId, srcActor) -> { String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime()); diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/MsgPackSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/MsgPackSerializerTest.java new file mode 100644 index 000000000..930e8556e --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/MsgPackSerializerTest.java @@ -0,0 +1,39 @@ +package org.ray.streaming.runtime.python; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testng.annotations.Test; + +@SuppressWarnings("unchecked") +public class MsgPackSerializerTest { + + @Test + public void testSerialize() { + MsgPackSerializer serializer = new MsgPackSerializer(); + + Map map = new HashMap(); + List list = new ArrayList<>(); + list.add(null); + list.add(true); + list.add(1); + list.add(1.0d); + list.add("str"); + map.put("k1", "value1"); + map.put("k2", 2); + map.put("k3", list); + byte[] bytes = serializer.serialize(map); + Object o = serializer.deserialize(bytes); + assertEquals(o, map); + + byte[] binary = {1, 2, 3, 4}; + assertTrue(Arrays.equals( + binary, 
(byte[]) (serializer.deserialize(serializer.serialize(binary))))); + } + +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/PythonGatewayTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/PythonGatewayTest.java new file mode 100644 index 000000000..c8a1a5ea4 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/python/PythonGatewayTest.java @@ -0,0 +1,48 @@ +package org.ray.streaming.runtime.python; + +import static org.testng.Assert.assertEquals; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.ray.streaming.api.stream.StreamSink; +import org.ray.streaming.jobgraph.JobGraph; +import org.ray.streaming.jobgraph.JobGraphBuilder; +import org.testng.annotations.Test; + +public class PythonGatewayTest { + + @Test + public void testPythonGateway() { + MsgPackSerializer serializer = new MsgPackSerializer(); + PythonGateway gateway = new PythonGateway(); + gateway.createStreamingContext(); + Map config = new HashMap<>(); + config.put("k1", "v1"); + gateway.withConfig(serializer.serialize(config)); + byte[] mockPySource = new byte[0]; + Object source = serializer.deserialize( + gateway.createPythonStreamSource(mockPySource)); + byte[] mockPyFunc = new byte[0]; + Object mapPyFunc = serializer.deserialize(gateway.createPyFunc(mockPyFunc)); + Object mapStream = serializer.deserialize( + gateway.callMethod( + serializer.serialize(Arrays.asList(source, "map", mapPyFunc)))); + byte[] mockPyPartition = new byte[0]; + Object partition = serializer.deserialize( + gateway.createPyPartition(mockPyPartition)); + Object partitionedStream = serializer.deserialize( + gateway.callMethod( + serializer.serialize(Arrays.asList(mapStream, "partitionBy", partition)))); + byte[] mockSinkFunc = new byte[0]; + Object sinkPyFunc = 
serializer.deserialize(gateway.createPyFunc(mockSinkFunc)); + gateway.callMethod( + serializer.serialize(Arrays.asList(partitionedStream, "sink", sinkPyFunc))); + List streamSinks = gateway.getStreamingContext().getStreamSinks(); + assertEquals(streamSinks.size(), 1); + JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(streamSinks, "py_job"); + JobGraph jobGraph = jobGraphBuilder.build(); + jobGraph.printJobGraph(); + } +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignerImplTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignerImplTest.java index 94ef3c80a..14708cb91 100644 --- a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignerImplTest.java +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignerImplTest.java @@ -1,26 +1,20 @@ package org.ray.streaming.runtime.schedule; -import java.util.ArrayList; -import java.util.List; - import com.google.common.collect.Lists; -import org.ray.api.RayActor; -import org.ray.api.id.ActorId; -import org.ray.api.id.ObjectId; -import org.ray.runtime.actor.LocalModeRayActor; +import java.util.List; +import org.ray.api.Ray; import org.ray.streaming.api.context.StreamingContext; import org.ray.streaming.api.partition.impl.RoundRobinPartition; import org.ray.streaming.api.stream.DataStream; import org.ray.streaming.api.stream.DataStreamSink; import org.ray.streaming.api.stream.DataStreamSource; +import org.ray.streaming.jobgraph.JobGraph; +import org.ray.streaming.jobgraph.JobGraphBuilder; import org.ray.streaming.runtime.BaseUnitTest; import org.ray.streaming.runtime.core.graph.ExecutionEdge; import org.ray.streaming.runtime.core.graph.ExecutionGraph; import org.ray.streaming.runtime.core.graph.ExecutionNode; import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType; -import 
org.ray.streaming.runtime.worker.JobWorker; -import org.ray.streaming.jobgraph.JobGraph; -import org.ray.streaming.jobgraph.JobGraphBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -32,15 +26,11 @@ public class TaskAssignerImplTest extends BaseUnitTest { @Test public void testTaskAssignImpl() { + Ray.init(); JobGraph jobGraph = buildDataSyncPlan(); - List> workers = new ArrayList<>(); - for(int i = 0; i < jobGraph.getJobVertexList().size(); i++) { - workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom())); - } - TaskAssigner taskAssigner = new TaskAssignerImpl(); - ExecutionGraph executionGraph = taskAssigner.assign(jobGraph, workers); + ExecutionGraph executionGraph = taskAssigner.assign(jobGraph); List executionNodeList = executionGraph.getExecutionNodeList(); @@ -61,16 +51,17 @@ public class TaskAssignerImplTest extends BaseUnitTest { Assert.assertEquals(sinkNode.getNodeType(), NodeType.SINK); Assert.assertEquals(sinkNode.getExecutionTasks().size(), 1); Assert.assertEquals(sinkNode.getOutputEdges().size(), 0); + + Ray.shutdown(); } public JobGraph buildDataSyncPlan() { StreamingContext streamingContext = StreamingContext.buildContext(); DataStream dataStream = DataStreamSource.buildSource(streamingContext, Lists.newArrayList("a", "b", "c")); - DataStreamSink streamSink = dataStream.sink(x -> LOGGER.info(x)); + DataStreamSink streamSink = dataStream.sink(LOGGER::info); JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); - JobGraph jobGraph = jobGraphBuilder.build(); - return jobGraph; + return jobGraphBuilder.build(); } } diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/util/ReflectionUtilsTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/util/ReflectionUtilsTest.java new file mode 100644 index 000000000..6f42b1c0a --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/util/ReflectionUtilsTest.java @@ -0,0 +1,38 @@ +package org.ray.streaming.runtime.util; + +import static org.testng.Assert.assertEquals; + +import java.io.Serializable; +import java.util.Collections; +import org.testng.annotations.Test; + +public class ReflectionUtilsTest { + + static class Foo implements Serializable { + public void f1() { + } + + public void f2() { + } + + public void f2(boolean a1) { + } + } + + @Test + public void testFindMethod() throws NoSuchMethodException { + assertEquals(Foo.class.getDeclaredMethod("f1"), + ReflectionUtils.findMethod(Foo.class, "f1")); + } + + @Test + public void testFindMethods() { + assertEquals(ReflectionUtils.findMethods(Foo.class, "f2").size(), 2); + } + + @Test + public void testGetAllInterfaces() { + assertEquals(ReflectionUtils.getAllInterfaces(Foo.class), + Collections.singletonList(Serializable.class)); + } +} \ No newline at end of file diff --git a/streaming/java/test.sh b/streaming/java/test.sh index 30c73002e..d450b384e 100755 --- a/streaming/java/test.sh +++ b/streaming/java/test.sh @@ -23,8 +23,7 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_on echo "Running streaming tests." java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\ - org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml -exit_code=$? + org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml || exit_code=$? echo "Streaming TestNG results" cat /tmp/ray_streaming_java_test_output/testng-results.xml # exit_code == 2 means there are skipped tests. 
diff --git a/streaming/python/README.rst b/streaming/python/README.rst deleted file mode 100644 index 4daab3a8d..000000000 --- a/streaming/python/README.rst +++ /dev/null @@ -1,16 +0,0 @@ -Streaming Library -================= - -Dependencies: - -Install NetworkX: ``pip install networkx`` - -Examples: - -- simple.py: A simple example with stateless operators and different parallelism per stage. - -Run ``python simple.py --input-file toy.txt`` - -- wordcount.py: A streaming wordcount example with a stateful operator (rolling sum). - -Run ``python wordcount.py --titles-file articles.txt`` diff --git a/streaming/python/__init__.py b/streaming/python/__init__.py index 4126425aa..2eb090c6f 100644 --- a/streaming/python/__init__.py +++ b/streaming/python/__init__.py @@ -1,3 +1,6 @@ # flake8: noqa # Ray should be imported before streaming import ray +from ray.streaming.context import StreamingContext + +__all__ = ['StreamingContext'] diff --git a/streaming/python/collector.py b/streaming/python/collector.py new file mode 100644 index 000000000..cc803eaf4 --- /dev/null +++ b/streaming/python/collector.py @@ -0,0 +1,49 @@ +import logging +import pickle +import typing +from abc import ABC, abstractmethod + +from ray.streaming import message +from ray.streaming import partition +from ray.streaming.runtime.transfer import ChannelID, DataWriter + +logger = logging.getLogger(__name__) + + +class Collector(ABC): + """ + The collector that collects data from an upstream operator, + and emits data to downstream operators. 
+ """ + + @abstractmethod + def collect(self, record): + pass + + +class CollectionCollector(Collector): + def __init__(self, collector_list): + self._collector_list = collector_list + + def collect(self, value): + for collector in self._collector_list: + collector.collect(message.Record(value)) + + +class OutputCollector(Collector): + def __init__(self, channel_ids: typing.List[str], writer: DataWriter, + partition_func: partition.Partition): + self._channel_ids = [ChannelID(id_str) for id_str in channel_ids] + self._writer = writer + self._partition_func = partition_func + logger.info( + "Create OutputCollector, channel_ids {}, partition_func {}".format( + channel_ids, partition_func)) + + def collect(self, record): + partitions = self._partition_func.partition(record, + len(self._channel_ids)) + serialized_message = pickle.dumps(record) + for partition_index in partitions: + self._writer.write(self._channel_ids[partition_index], + serialized_message) diff --git a/streaming/python/communication.py b/streaming/python/communication.py deleted file mode 100644 index 576305186..000000000 --- a/streaming/python/communication.py +++ /dev/null @@ -1,279 +0,0 @@ -import hashlib -import logging -import pickle -import sys -import time - -import ray -import ray.streaming.runtime.transfer as transfer -from ray.streaming.config import Config -from ray.streaming.operator import PStrategy -from ray.streaming.runtime.transfer import ChannelID - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -# Forward and broadcast stream partitioning strategies -forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast] - - -# Used to choose output channel in case of hash-based shuffling -def _hash(value): - if isinstance(value, int): - return value - try: - return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16) - except AttributeError: - return int(hashlib.sha1(value).hexdigest(), 16) - - -class DataChannel: - """A data channel for 
actor-to-actor communication. - - Attributes: - env (Environment): The environment the channel belongs to. - src_operator_id (UUID): The id of the source operator of the channel. - src_instance_index (int): The id of the source instance. - dst_operator_id (UUID): The id of the destination operator of the - channel. - dst_instance_index (int): The id of the destination instance. - """ - - def __init__(self, src_operator_id, src_instance_index, dst_operator_id, - dst_instance_index, str_qid): - self.src_operator_id = src_operator_id - self.src_instance_index = src_instance_index - self.dst_operator_id = dst_operator_id - self.dst_instance_index = dst_instance_index - self.str_qid = str_qid - self.qid = ChannelID(str_qid) - - def __repr__(self): - return "(src({},{}),dst({},{}), qid({}))".format( - self.src_operator_id, self.src_instance_index, - self.dst_operator_id, self.dst_instance_index, self.str_qid) - - -_CLOSE_FLAG = b" " - - -# Pulls and merges data from multiple input channels -class DataInput: - """An input gate of an operator instance. - - The input gate pulls records from all input channels in a round-robin - fashion. - - Attributes: - input_channels (list): The list of input channels. - channel_index (int): The index of the next channel to pull from. - max_index (int): The number of input channels. - closed (list): A list of flags indicating whether an input channel - has been marked as 'closed'. - all_closed (bool): Denotes whether all input channels have been - closed (True) or not (False). - """ - - def __init__(self, env, channels): - assert len(channels) > 0 - self.env = env - self.reader = None # created in `init` method - self.input_channels = channels - self.channel_index = 0 - self.max_index = len(channels) - # Tracks the channels that have been closed. 
qid: close status - self.closed = {} - - def init(self): - channels = [c.str_qid for c in self.input_channels] - input_actors = [] - for c in self.input_channels: - actor = self.env.execution_graph.get_actor(c.src_operator_id, - c.src_instance_index) - input_actors.append(actor) - logger.info("DataInput input_actors %s", input_actors) - conf = { - Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context() - .current_driver_id, - Config.CHANNEL_TYPE: self.env.config.channel_type - } - self.reader = transfer.DataReader(channels, input_actors, conf) - - def pull(self): - # pull from channel - item = self.reader.read(100) - while item is None: - time.sleep(0.001) - item = self.reader.read(100) - msg_data = item.body() - if msg_data == _CLOSE_FLAG: - self.closed[item.channel_id] = True - if len(self.closed) == len(self.input_channels): - return None - else: - return self.pull() - else: - return pickle.loads(msg_data) - - def close(self): - self.reader.stop() - - -# Selects output channel(s) and pushes data -class DataOutput: - """An output gate of an operator instance. - - The output gate pushes records to output channels according to the - user-defined partitioning scheme. - - Attributes: - partitioning_schemes (dict): A mapping from destination operator ids - to partitioning schemes (see: PScheme in operator.py). - forward_channels (list): A list of channels to forward records. - shuffle_channels (list(list)): A list of output channels to shuffle - records grouped by destination operator. - shuffle_key_channels (list(list)): A list of output channels to - shuffle records by a key grouped by destination operator. - shuffle_exists (bool): A flag indicating that there exists at least - one shuffle_channel. - shuffle_key_exists (bool): A flag indicating that there exists at - least one shuffle_key_channel. 
- """ - - def __init__(self, env, channels, partitioning_schemes): - assert len(channels) > 0 - self.env = env - self.writer = None # created in `init` method - self.channels = channels - self.key_selector = None - self.round_robin_indexes = [0] - self.partitioning_schemes = partitioning_schemes - # Prepare output -- collect channels by type - self.forward_channels = [] # Forward and broadcast channels - slots = sum(1 for scheme in self.partitioning_schemes.values() - if scheme.strategy == PStrategy.RoundRobin) - self.round_robin_channels = [[]] * slots # RoundRobin channels - self.round_robin_indexes = [-1] * slots - slots = sum(1 for scheme in self.partitioning_schemes.values() - if scheme.strategy == PStrategy.Shuffle) - # Flag used to avoid hashing when there is no shuffling - self.shuffle_exists = slots > 0 - self.shuffle_channels = [[]] * slots # Shuffle channels - slots = sum(1 for scheme in self.partitioning_schemes.values() - if scheme.strategy == PStrategy.ShuffleByKey) - # Flag used to avoid hashing when there is no shuffling by key - self.shuffle_key_exists = slots > 0 - self.shuffle_key_channels = [[]] * slots # Shuffle by key channels - # Distinct shuffle destinations - shuffle_destinations = {} - # Distinct shuffle by key destinations - shuffle_by_key_destinations = {} - # Distinct round robin destinations - round_robin_destinations = {} - index_1 = 0 - index_2 = 0 - index_3 = 0 - for channel in channels: - p_scheme = self.partitioning_schemes[channel.dst_operator_id] - strategy = p_scheme.strategy - if strategy in forward_broadcast_strategies: - self.forward_channels.append(channel) - elif strategy == PStrategy.Shuffle: - pos = shuffle_destinations.setdefault(channel.dst_operator_id, - index_1) - self.shuffle_channels[pos].append(channel) - if pos == index_1: - index_1 += 1 - elif strategy == PStrategy.ShuffleByKey: - pos = shuffle_by_key_destinations.setdefault( - channel.dst_operator_id, index_2) - self.shuffle_key_channels[pos].append(channel) - 
if pos == index_2: - index_2 += 1 - elif strategy == PStrategy.RoundRobin: - pos = round_robin_destinations.setdefault( - channel.dst_operator_id, index_3) - self.round_robin_channels[pos].append(channel) - if pos == index_3: - index_3 += 1 - else: # TODO (john): Add support for other strategies - sys.exit("Unrecognized or unsupported partitioning strategy.") - # A KeyedDataStream can only be shuffled by key - assert not (self.shuffle_exists and self.shuffle_key_exists) - - def init(self): - """init DataOutput which creates DataWriter""" - channel_ids = [c.str_qid for c in self.channels] - to_actors = [] - for c in self.channels: - actor = self.env.execution_graph.get_actor(c.dst_operator_id, - c.dst_instance_index) - to_actors.append(actor) - logger.info("DataOutput output_actors %s", to_actors) - - conf = { - Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context() - .current_driver_id, - Config.CHANNEL_TYPE: self.env.config.channel_type - } - self.writer = transfer.DataWriter(channel_ids, to_actors, conf) - - def close(self): - """Close the channel (True) by propagating _CLOSE_FLAG - - _CLOSE_FLAG is used as special type of record that is propagated from - sources to sink to notify that the end of data in a stream. 
- """ - for c in self.channels: - self.writer.write(c.qid, _CLOSE_FLAG) - # must ensure DataWriter send None flag to peer actor - self.writer.stop() - - def push(self, record): - target_channels = [] - # Forward record - for c in self.forward_channels: - logger.debug("[writer] Push record '{}' to channel {}".format( - record, c)) - target_channels.append(c) - # Forward record - index = 0 - for channels in self.round_robin_channels: - self.round_robin_indexes[index] += 1 - if self.round_robin_indexes[index] == len(channels): - self.round_robin_indexes[index] = 0 # Reset index - c = channels[self.round_robin_indexes[index]] - logger.debug("[writer] Push record '{}' to channel {}".format( - record, c)) - target_channels.append(c) - index += 1 - # Hash-based shuffling by key - if self.shuffle_key_exists: - key, _ = record - h = _hash(key) - for channels in self.shuffle_key_channels: - num_instances = len(channels) # Downstream instances - c = channels[h % num_instances] - logger.debug( - "[key_shuffle] Push record '{}' to channel {}".format( - record, c)) - target_channels.append(c) - elif self.shuffle_exists: # Hash-based shuffling per destination - h = _hash(record) - for channels in self.shuffle_channels: - num_instances = len(channels) # Downstream instances - c = channels[h % num_instances] - logger.debug("[shuffle] Push record '{}' to channel {}".format( - record, c)) - target_channels.append(c) - else: # TODO (john): Handle rescaling - pass - - msg_data = pickle.dumps(record) - for c in target_channels: - # send data to channel - self.writer.write(c.qid, msg_data) - - def push_all(self, records): - for record in records: - self.push(record) diff --git a/streaming/python/config.py b/streaming/python/config.py index e6d56488b..8f54463b6 100644 --- a/streaming/python/config.py +++ b/streaming/python/config.py @@ -13,7 +13,8 @@ class Config: # return from StreamingReader.getBundle if only empty message read in this # interval. 
TIMER_INTERVAL_MS = "timer_interval_ms" - + READ_TIMEOUT_MS = "read_timeout_ms" + DEFAULT_READ_TIMEOUT_MS = "10" STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity" # write an empty message if there is no data to be written in this # interval. diff --git a/streaming/python/context.py b/streaming/python/context.py new file mode 100644 index 000000000..e051d0894 --- /dev/null +++ b/streaming/python/context.py @@ -0,0 +1,168 @@ +from abc import ABC, abstractmethod + +from ray.streaming.datastream import StreamSource +from ray.streaming.function import LocalFileSourceFunction +from ray.streaming.function import CollectionSourceFunction +from ray.streaming.function import SourceFunction +from ray.streaming.runtime.gateway_client import GatewayClient + + +class StreamingContext: + """ + Main entry point for ray streaming functionality. + A StreamingContext is also a wrapper of java + `org.ray.streaming.api.context.StreamingContext` + """ + + class Builder: + def __init__(self): + self._options = {} + + def option(self, key=None, value=None, conf=None): + """ + Sets a config option. Options set using this method are + automatically propagated to :class:`StreamingContext`'s own + configuration. + + Args: + key: a key name string for configuration property + value: a value string for configuration property + conf: multi key-value pairs as a dict + + Returns: + self + """ + if key is not None: + assert value + self._options[key] = str(value) + if conf is not None: + for k, v in conf.items(): + self._options[k] = v + return self + + def build(self): + """ + Creates a StreamingContext based on the options set in this + builder. 
+ """ + ctx = StreamingContext() + ctx._gateway_client.with_config(self._options) + return ctx + + def __init__(self): + self.__gateway_client = GatewayClient() + self._j_ctx = self._gateway_client.create_streaming_context() + + def source(self, source_func: SourceFunction): + """Create an input data stream with a SourceFunction + + Args: + source_func: the SourceFunction used to create the data stream + + Returns: + The data stream constructed from the source_func + """ + return StreamSource.build_source(self, source_func) + + def from_values(self, *values): + """Creates a data stream from values + + Args: + values: The elements to create the data stream from. + + Returns: + The data stream representing the given values + """ + return self.from_collection(values) + + def from_collection(self, values): + """Creates a data stream from the given non-empty collection. + + Args: + values: The collection of elements to create the data stream from. + + Returns: + The data stream representing the given collection. + """ + assert values, "values shouldn't be None or empty" + func = CollectionSourceFunction(values) + return self.source(func) + + def read_text_file(self, filename: str): + """Reads the given file line-by-line and creates a data stream that + contains a string with the contents of each such line.""" + func = LocalFileSourceFunction(filename) + return self.source(func) + + def submit(self, job_name: str): + """Submit job for execution. + + Args: + job_name: name of the job + + Returns: + An JobSubmissionResult future + """ + self._gateway_client.execute(job_name) + # TODO return a JobSubmissionResult future + + def execute(self, job_name: str): + """Execute the job. This method will block until job finished. 
+ + Args: + job_name: name of the job + """ + # TODO support block to job finish + # job_submit_result = self.submit(job_name) + # job_submit_result.wait_finish() + raise Exception("Unsupported") + + @property + def _gateway_client(self): + return self.__gateway_client + + +class RuntimeContext(ABC): + @abstractmethod + def get_task_id(self): + """ + Returns: + Task id of the parallel task. + """ + pass + + @abstractmethod + def get_task_index(self): + """ + Gets the index of this parallel subtask. The index starts from 0 + and goes up to parallelism-1 (parallelism as returned by + `get_parallelism()`). + + Returns: + The index of the parallel subtask. + """ + pass + + @abstractmethod + def get_parallelism(self): + """ + Returns: + The parallelism with which the parallel task runs. + """ + pass + + +class RuntimeContextImpl(RuntimeContext): + def __init__(self, task_id, task_index, parallelism): + self.task_id = task_id + self.task_index = task_index + self.parallelism = parallelism + + def get_task_id(self): + return self.task_id + + def get_task_index(self): + return self.task_index + + def get_parallelism(self): + return self.parallelism diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py new file mode 100644 index 000000000..7dc897d35 --- /dev/null +++ b/streaming/python/datastream.py @@ -0,0 +1,284 @@ +from abc import ABC + +from ray.streaming import function +from ray.streaming import partition + + +class Stream(ABC): + """ + Abstract base class of all stream types. A Stream represents a stream of + elements of the same type. A Stream can be transformed into another Stream + by applying a transformation. 
+ """ + + def __init__(self, input_stream, j_stream, streaming_context=None): + self.input_stream = input_stream + self._j_stream = j_stream + if streaming_context is None: + assert input_stream is not None + self.streaming_context = input_stream.streaming_context + else: + self.streaming_context = streaming_context + self.parallelism = 1 + + def get_streaming_context(self): + return self.streaming_context + + def get_parallelism(self): + """ + Returns: + the parallelism of this transformation + """ + return self.parallelism + + def set_parallelism(self, parallelism: int): + """Sets the parallelism of this transformation + + Args: + parallelism: The new parallelism to set on this transformation + + Returns: + self + """ + self.parallelism = parallelism + self._gateway_client(). \ + call_method(self._j_stream, "setParallelism", parallelism) + return self + + def get_input_stream(self): + """ + Returns: + input stream of this stream + """ + return self.input_stream + + def get_id(self): + """ + Returns: + An unique id identifies this stream. + """ + return self._gateway_client(). \ + call_method(self._j_stream, "getId") + + def _gateway_client(self): + return self.get_streaming_context()._gateway_client + + +class DataStream(Stream): + """ + Represents a stream of data which applies a transformation executed by + python. It's also a wrapper of java + `org.ray.streaming.python.stream.PythonDataStream` + """ + + def __init__(self, input_stream, j_stream, streaming_context=None): + super().__init__( + input_stream, j_stream, streaming_context=streaming_context) + + def map(self, func): + """ + Applies a Map transformation on a :class:`DataStream`. + The transformation calls a :class:`ray.streaming.function.MapFunction` + for each element of the DataStream. + + Args: + func: The MapFunction that is called for each element of the + DataStream. If `func` is a python function instead of a subclass + of MapFunction, it will be wrapped as SimpleMapFunction. 
+ + Returns: + A new data stream transformed by the MapFunction. + """ + if not isinstance(func, function.MapFunction): + func = function.SimpleMapFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "map", j_func) + return DataStream(self, j_stream) + + def flat_map(self, func): + """ + Applies a FlatMap transformation on a :class:`DataStream`. The + transformation calls a :class:`ray.streaming.function.FlatMapFunction` + for each element of the DataStream. + Each FlatMapFunction call can return any number of elements including + none. + + Args: + func: The FlatMapFunction that is called for each element of the + DataStream. If `func` is a python function instead of a subclass + of FlatMapFunction, it will be wrapped as SimpleFlatMapFunction. + + Returns: + The transformed DataStream + """ + if not isinstance(func, function.FlatMapFunction): + func = function.SimpleFlatMapFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "flatMap", j_func) + return DataStream(self, j_stream) + + def filter(self, func): + """ + Applies a Filter transformation on a :class:`DataStream`. The + transformation calls a :class:`ray.streaming.function.FilterFunction` + for each element of the DataStream. + DataStream and retains only those element for which the function + returns True. + + Args: + func: The FilterFunction that is called for each element of the + DataStream. If `func` is a python function instead of a subclass of + FilterFunction, it will be wrapped as SimpleFilterFunction. + + Returns: + The filtered DataStream + """ + if not isinstance(func, function.FilterFunction): + func = function.SimpleFilterFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). 
\ + call_method(self._j_stream, "filter", j_func) + return DataStream(self, j_stream) + + def key_by(self, func): + """ + Creates a new :class:`KeyDataStream` that uses the provided key to + partition data stream by key. + + Args: + func: The KeyFunction that is used for extracting the key for + partitioning. If `func` is a python function instead of a subclass + of KeyFunction, it will be wrapped as SimpleKeyFunction. + + Returns: + A KeyDataStream + """ + if not isinstance(func, function.KeyFunction): + func = function.SimpleKeyFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "keyBy", j_func) + return KeyDataStream(self, j_stream) + + def broadcast(self): + """ + Sets the partitioning of the :class:`DataStream` so that the output + elements are broadcast to every parallel instance of the next + operation. + + Returns: + The DataStream with broadcast partitioning set. + """ + self._gateway_client().call_method(self._j_stream, "broadcast") + return self + + def partition_by(self, partition_func): + """ + Sets the partitioning of the :class:`DataStream` so that the elements + of stream are partitioned by specified partition function. + + Args: + partition_func: partition function. + If `func` is a python function instead of a subclass of Partition, + it will be wrapped as SimplePartition. + + Returns: + The DataStream with specified partitioning set. + """ + if not isinstance(partition_func, partition.Partition): + partition_func = partition.SimplePartition(partition_func) + j_partition = self._gateway_client().create_py_func( + partition.serialize(partition_func)) + self._gateway_client(). \ + call_method(self._j_stream, "partitionBy", j_partition) + return self + + def sink(self, func): + """ + Create a StreamSink with the given sink. + + Args: + func: sink function. + + Returns: + a StreamSink. 
+ """ + if not isinstance(func, function.SinkFunction): + func = function.SimpleSinkFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "sink", j_func) + return StreamSink(self, j_stream, func) + + +class KeyDataStream(Stream): + """Represents a DataStream returned by a key-by operation. + Wrapper of java org.ray.streaming.python.stream.PythonKeyDataStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def reduce(self, func): + """ + Applies a reduce transformation on the grouped data stream grouped on + by the given key function. + The :class:`ray.streaming.function.ReduceFunction` will receive input + values based on the key value. Only input values with the same key will + go to the same reducer. + + Args: + func: The ReduceFunction that will be called for every element of + the input values with the same key. If `func` is a python function + instead of a subclass of ReduceFunction, it will be wrapped as + SimpleReduceFunction. + + Returns: + A transformed DataStream. + """ + if not isinstance(func, function.ReduceFunction): + func = function.SimpleReduceFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "reduce", j_func) + return DataStream(self, j_stream) + + +class StreamSource(DataStream): + """Represents a source of the DataStream. + Wrapper of java org.ray.streaming.python.stream.PythonStreamSource + """ + + def __init__(self, j_stream, streaming_context, source_func): + super().__init__(None, j_stream, streaming_context=streaming_context) + self.source_func = source_func + + @staticmethod + def build_source(streaming_context, func): + """Build a StreamSource source from a collection. 
+ Args: + streaming_context: Stream context + func: A instance of `SourceFunction` + Returns: + A StreamSource + """ + j_stream = streaming_context._gateway_client. \ + create_py_stream_source(function.serialize(func)) + return StreamSource(j_stream, streaming_context, func) + + +class StreamSink(Stream): + """Represents a sink of the DataStream. + Wrapper of java org.ray.streaming.python.stream.PythonStreamSink + """ + + def __init__(self, input_stream, j_stream, func): + super().__init__(input_stream, j_stream) diff --git a/streaming/python/examples/key_selectors.py b/streaming/python/examples/key_selectors.py deleted file mode 100644 index 2c8e784ed..000000000 --- a/streaming/python/examples/key_selectors.py +++ /dev/null @@ -1,67 +0,0 @@ -import argparse -import logging -import time - -import ray -from ray.streaming.streaming import Environment - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -parser = argparse.ArgumentParser() -parser.add_argument("--input-file", required=True, help="the input text file") - - -# A class used to check attribute-based key selection -class Record: - def __init__(self, record): - k, _ = record - self.word = k - self.record = record - - -# Splits input line into words and outputs objects of type Record -# each one consisting of a key (word) and a tuple (word,1) -def splitter(line): - records = [] - words = line.split() - for w in words: - records.append(Record((w, 1))) - return records - - -# Receives an object of type Record and returns the actual tuple -def as_tuple(record): - return record.record - - -if __name__ == "__main__": - # Get program parameters - args = parser.parse_args() - input_file = str(args.input_file) - - ray.init() - ray.register_custom_serializer(Record, use_dict=True) - - # A Ray streaming environment with the default configuration - env = Environment() - env.set_parallelism(2) # Each operator will be executed by two actors - - # 'key_by("word")' physically partitions the 
stream of records - # based on the hash value of the 'word' attribute (see Record class above) - # 'map(as_tuple)' maps a record of type Record into a tuple - # 'sum(1)' sums the 2nd element of the tuple, i.e. the word count - stream = env.read_text_file(input_file) \ - .round_robin() \ - .flat_map(splitter) \ - .key_by("word") \ - .map(as_tuple) \ - .sum(1) \ - .inspect(print) # Prints the content of the - # stream to stdout - start = time.time() - env_handle = env.execute() # Deploys and executes the dataflow - ray.get(env_handle) # Stay alive until execution finishes - end = time.time() - logger.info("Elapsed time: {} secs".format(end - start)) - logger.debug("Output stream id: {}".format(stream.id)) diff --git a/streaming/python/examples/simple.py b/streaming/python/examples/simple.py deleted file mode 100644 index 43264abb4..000000000 --- a/streaming/python/examples/simple.py +++ /dev/null @@ -1,52 +0,0 @@ -import argparse -import logging -import time - -import ray -from ray.streaming.config import Config -from ray.streaming.streaming import Environment, Conf - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -parser = argparse.ArgumentParser() -parser.add_argument("--input-file", required=True, help="the input text file") - - -# Test functions -def splitter(line): - return line.split() - - -def filter_fn(word): - if "f" in word: - return True - return False - - -if __name__ == "__main__": - - args = parser.parse_args() - - ray.init(local_mode=False) - - # A Ray streaming environment with the default configuration - env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL)) - - # Stream represents the ouput of the filter and - # can be forked into other dataflows - stream = env.read_text_file(args.input_file) \ - .shuffle() \ - .flat_map(splitter) \ - .set_parallelism(2) \ - .filter(filter_fn) \ - .set_parallelism(2) \ - .inspect(lambda x: print("result", x)) # Prints the contents of the - # stream to stdout - start = 
time.time() - env_handle = env.execute() - ray.get(env_handle) # Stay alive until execution finishes - env.wait_finish() - end = time.time() - logger.info("Elapsed time: {} secs".format(end - start)) - logger.debug("Output stream id: {}".format(stream.id)) diff --git a/streaming/python/examples/toy.txt b/streaming/python/examples/toy.txt deleted file mode 100644 index fabe58790..000000000 --- a/streaming/python/examples/toy.txt +++ /dev/null @@ -1,5 +0,0 @@ -This is -a test file -to test if example -works -fine diff --git a/streaming/python/examples/wordcount.py b/streaming/python/examples/wordcount.py index d9ce4cd54..685716013 100644 --- a/streaming/python/examples/wordcount.py +++ b/streaming/python/examples/wordcount.py @@ -4,7 +4,8 @@ import time import ray import wikipedia -from ray.streaming.streaming import Environment +from ray.streaming import StreamingContext +from ray.streaming.config import Config logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -23,7 +24,6 @@ class Wikipedia: def __init__(self, title_file): # Titles in this file will be as queries self.title_file = title_file - # TODO (john): Handle possible exception here self.title_reader = iter(list(open(self.title_file, "r").readlines())) self.done = False self.article_done = True @@ -57,21 +57,7 @@ class Wikipedia: # Splits input line into words and # outputs records of the form (word,1) def splitter(line): - records = [] - words = line.split() - for w in words: - records.append((w, 1)) - return records - - -# Returns the first attribute of a tuple -def key_selector(tuple): - return tuple[0] - - -# Returns the second attribute of a tuple -def attribute_selector(tuple): - return tuple[1] + return [(word, 1) for word in line.split()] if __name__ == "__main__": @@ -79,27 +65,23 @@ if __name__ == "__main__": args = parser.parse_args() titles_file = str(args.titles_file) - ray.init() + ray.init(load_code_from_local=True, include_java=True) + ctx = 
StreamingContext.Builder() \ + .option(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL) \ + .build() # A Ray streaming environment with the default configuration - env = Environment() - env.set_parallelism(2) # Each operator will be executed by two actors + ctx.set_parallelism(1) # Each operator will be executed by two actors - # The following dataflow is a simple streaming wordcount - # with a rolling sum operator. - # It reads articles from wikipedia, splits them in words, - # shuffles words, and counts the occurences of each word. - stream = env.source(Wikipedia(titles_file)) \ - .round_robin() \ - .flat_map(splitter) \ - .key_by(key_selector) \ - .sum(attribute_selector) \ - .inspect(print) # Prints the contents of the - # stream to stdout + # Reads articles from wikipedia, splits them in words, + # shuffles words, and counts the occurrences of each word. + stream = ctx.source(Wikipedia(titles_file)) \ + .flat_map(splitter) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .sink(print) start = time.time() - env_handle = env.execute() # Deploys and executes the dataflow - ray.get(env_handle) # Stay alive until execution finishes - env.wait_finish() + ctx.execute("wordcount") end = time.time() logger.info("Elapsed time: {} secs".format(end - start)) - logger.debug("Output stream id: {}".format(stream.id)) diff --git a/streaming/python/function.py b/streaming/python/function.py new file mode 100644 index 000000000..e94535d0a --- /dev/null +++ b/streaming/python/function.py @@ -0,0 +1,315 @@ +import importlib +import inspect +import sys +from abc import ABC, abstractmethod +import typing + +import cloudpickle +from ray.streaming.runtime import gateway_client + + +class Function(ABC): + """The base interface for all user-defined functions.""" + + def open(self, conf: typing.Dict[str, str]): + pass + + def close(self): + pass + + +class SourceContext(ABC): + """ + Interface that source functions use to 
emit elements, and possibly + watermarks.""" + + @abstractmethod + def collect(self, element): + """Emits one element from the source, without attaching a timestamp.""" + pass + + +class SourceFunction(Function): + """Interface of Source functions.""" + + @abstractmethod + def init(self, parallel, index): + """ + Args: + parallel: parallelism of source function + index: task index of this function and goes up from 0 to + parallel-1. + """ + pass + + @abstractmethod + def run(self, ctx: SourceContext): + """Starts the source. Implementations can use the + :class:`SourceContext` to emit elements. + """ + pass + + def close(self): + pass + + +class MapFunction(Function): + """ + Base interface for Map functions. Map functions take elements and transform + them element wise. A Map function always produces a single result element + for each input element. + """ + + def map(self, value): + pass + + +class FlatMapFunction(Function): + """ + Base interface for flatMap functions. FlatMap functions take elements and + transform them into zero, one, or more elements. + """ + + def flat_map(self, value, collector): + """Takes an element from the input data set and transforms it into zero, + one, or more elements. + + Args: + value: The input value. + collector: The collector for returning result values. + """ + pass + + +class FilterFunction(Function): + """ + A filter function is a predicate applied individually to each record. + The predicate decides whether to keep the element, or to discard it. + """ + + def filter(self, value): + """The filter function that evaluates the predicate. + + Args: + value: The value to be filtered. + + Returns: + True for values that should be retained, false for values to be + filtered out. + """ + pass + + +class KeyFunction(Function): + """ + A key function is extractor which takes an object and returns the + deterministic key for that object. 
+ """ + + def key_by(self, value): + """User-defined function that deterministically extracts the key from + an object. + + Args: + value: The object to get the key from. + + Returns: + The extracted key. + """ + pass + + +class ReduceFunction(Function): + """ + Base interface for Reduce functions. Reduce functions combine groups of + elements to a single value, by taking always two elements and combining + them into one. + """ + + def reduce(self, old_value, new_value): + """ + The core method of ReduceFunction, combining two values into one value + of the same type. The reduce function is consecutively applied to all + values of a group until only a single value remains. + + Args: + old_value: The old value to combine. + new_value: The new input value to combine. + + Returns: + The combined value of both values. + """ + pass + + +class SinkFunction(Function): + """Interface for implementing user defined sink functionality.""" + + def sink(self, value): + """Writes the given value to the sink. This function is called for + every record.""" + pass + + +class CollectionSourceFunction(SourceFunction): + def __init__(self, values): + self.values = values + + def init(self, parallel, index): + pass + + def run(self, ctx: SourceContext): + for v in self.values: + ctx.collect(v) + + +class LocalFileSourceFunction(SourceFunction): + def __init__(self, filename): + self.filename = filename + + def init(self, parallel, index): + pass + + def run(self, ctx: SourceContext): + with open(self.filename, "r") as f: + line = f.readline() + while line != "": + ctx.collect(line[:-1]) + line = f.readline() + + +class SimpleMapFunction(MapFunction): + def __init__(self, func): + self.func = func + + def map(self, value): + return self.func(value) + + +class SimpleFlatMapFunction(FlatMapFunction): + """ + Wrap a python function as :class:`FlatMapFunction` + + >>> assert SimpleFlatMapFunction(lambda x: x.split()) + >>> def flat_func(x, collector): + ... for item in x.split(): + ... 
collector.collect(item) + >>> assert SimpleFlatMapFunction(flat_func) + """ + + def __init__(self, func): + """ + Args: + func: a python function which takes an element from input argument + and transforms it into zero, one, or more elements. + Or takes an element from input argument, and uses the provided collector + to collect zero, one, or more elements. + """ + self.func = func + self.process_func = None + sig = inspect.signature(func) + assert len(sig.parameters) <= 2,\ + "func should receive value [, collector] as arguments" + if len(sig.parameters) == 2: + + def process(value, collector): + func(value, collector) + + self.process_func = process + else: + + def process(value, collector): + for elem in func(value): + collector.collect(elem) + + self.process_func = process + + def flat_map(self, value, collector): + self.process_func(value, collector) + + +class SimpleFilterFunction(FilterFunction): + def __init__(self, func): + self.func = func + + def filter(self, value): + return self.func(value) + + +class SimpleKeyFunction(KeyFunction): + def __init__(self, func): + self.func = func + + def key_by(self, value): + return self.func(value) + + +class SimpleReduceFunction(ReduceFunction): + def __init__(self, func): + self.func = func + + def reduce(self, old_value, new_value): + return self.func(old_value, new_value) + + +class SimpleSinkFunction(SinkFunction): + def __init__(self, func): + self.func = func + + def sink(self, value): + return self.func(value) + + +def serialize(func: Function): + """Serialize a streaming :class:`Function`""" + return cloudpickle.dumps(func) + + +def deserialize(func_bytes): + """Deserialize a binary function serialized by `serialize` method.""" + return cloudpickle.loads(func_bytes) + + +def load_function(descriptor_func_bytes: bytes): + """ + Deserialize `descriptor_func_bytes` to get function info, then + get or load streaming function.
+ Note that this function must be kept in sync with + `org.ray.streaming.runtime.python.GraphPbBuilder.serializeFunction` + + Args: + descriptor_func_bytes: serialized function info + + Returns: + a streaming function + """ + function_bytes, module_name, class_name, function_name, function_interface\ + = gateway_client.deserialize(descriptor_func_bytes) + if function_bytes: + return deserialize(function_bytes) + else: + assert module_name + assert function_interface + function_interface = getattr(sys.modules[__name__], function_interface) + mod = importlib.import_module(module_name) + if class_name: + assert function_name is None + cls = getattr(mod, class_name) + assert issubclass(cls, function_interface) + return cls() + else: + assert function_name + func = getattr(mod, function_name) + simple_func_class = _get_simple_function_class(function_interface) + return simple_func_class(func) + + +def _get_simple_function_class(function_interface): + """Get the wrapper function for the given `function_interface`.""" + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and issubclass(obj, function_interface): + if obj is not function_interface and obj.__name__.startswith( + "Simple"): + return obj + raise Exception( + "SimpleFunction for %s doesn't exist".format(function_interface)) diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi index 7dcd91f28..b57f30f10 100644 --- a/streaming/python/includes/transfer.pxi +++ b/streaming/python/includes/transfer.pxi @@ -155,7 +155,7 @@ cdef class DataWriter: ctx.get().MarkMockTest() if config_bytes: config_data = config_bytes - channel_logger.info("load config, config bytes size: %s", config_data.nbytes) + channel_logger.info("DataWriter load config, config bytes size: %s", config_data.nbytes) ctx.get().SetConfig((&config_data[0]), config_data.nbytes) c_writer = new CDataWriter(ctx) cdef: @@ -235,7 +235,7 @@ cdef class DataReader: cdef 
shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() if config_bytes: config_data = config_bytes - channel_logger.info("load config, config bytes size: %s", config_data.nbytes) + channel_logger.info("DataReader load config, config bytes size: %s", config_data.nbytes) ctx.get().SetConfig((&(config_data[0])), config_data.nbytes) if is_mock: ctx.get().MarkMockTest() @@ -289,7 +289,7 @@ cdef class DataReader: msg_id = msg.get().GetMessageSeqId() msgs.append((msg_bytes, msg_id, timestamp, qid_bytes)) return msgs - elif bundle_type == libstreaming.BundleTypeEmpty: + elif bundle_type == libstreaming.BundleTypeEmpty: return [] else: raise Exception("Unsupported bundle type {}".format(bundle_type)) diff --git a/streaming/python/jobworker.py b/streaming/python/jobworker.py deleted file mode 100644 index 07cbd0fb8..000000000 --- a/streaming/python/jobworker.py +++ /dev/null @@ -1,120 +0,0 @@ -import logging -import pickle -import threading - -import ray -import ray.streaming._streaming as _streaming -from ray.streaming.config import Config -from ray._raylet import PythonFunctionDescriptor -from ray.streaming.communication import DataInput, DataOutput - -logger = logging.getLogger(__name__) - - -@ray.remote -class JobWorker: - """A streaming job worker. - - Attributes: - worker_id: The id of the instance. - input_channels: The input gate that manages input channels of - the instance (see: DataInput in communication.py). - output_channels (DataOutput): The output gate that manages output - channels of the instance (see: DataOutput in communication.py). - the operator instance. 
- """ - - def __init__(self, worker_id, operator, input_channels, output_channels): - self.env = None - self.worker_id = worker_id - self.operator = operator - processor_name = operator.processor_class.__name__ - processor_instance = operator.processor_class(operator) - self.processor_name = processor_name - self.processor_instance = processor_instance - self.input_channels = input_channels - self.output_channels = output_channels - self.input_gate = None - self.output_gate = None - self.reader_client = None - self.writer_client = None - - def init(self, env): - """init streaming actor""" - env = pickle.loads(env) - self.env = env - logger.info("init operator instance %s", self.processor_name) - - if env.config.channel_type == Config.NATIVE_CHANNEL: - core_worker = ray.worker.global_worker.core_worker - reader_async_func = PythonFunctionDescriptor( - __name__, self.on_reader_message.__name__, - self.__class__.__name__) - reader_sync_func = PythonFunctionDescriptor( - __name__, self.on_reader_message_sync.__name__, - self.__class__.__name__) - self.reader_client = _streaming.ReaderClient( - core_worker, reader_async_func, reader_sync_func) - writer_async_func = PythonFunctionDescriptor( - __name__, self.on_writer_message.__name__, - self.__class__.__name__) - writer_sync_func = PythonFunctionDescriptor( - __name__, self.on_writer_message_sync.__name__, - self.__class__.__name__) - self.writer_client = _streaming.WriterClient( - core_worker, writer_async_func, writer_sync_func) - if len(self.input_channels) > 0: - self.input_gate = DataInput(env, self.input_channels) - self.input_gate.init() - if len(self.output_channels) > 0: - self.output_gate = DataOutput( - env, self.output_channels, - self.operator.partitioning_strategies) - self.output_gate.init() - logger.info("init operator instance %s succeed", self.processor_name) - return True - - # Starts the actor - def start(self): - self.t = threading.Thread(target=self.run, daemon=True) - self.t.start() - actor_id = 
ray.worker.global_worker.actor_id - logger.info("%s %s started, actor id %s", self.__class__.__name__, - self.processor_name, actor_id) - - def run(self): - logger.info("%s start running", self.processor_name) - self.processor_instance.run(self.input_gate, self.output_gate) - logger.info("%s finished running", self.processor_name) - self.close() - - def close(self): - if self.input_gate: - self.input_gate.close() - if self.output_gate: - self.output_gate.close() - - def is_finished(self): - return not self.t.is_alive() - - def on_reader_message(self, buffer: bytes): - """used in direct call mode""" - self.reader_client.on_reader_message(buffer) - - def on_reader_message_sync(self, buffer: bytes): - """used in direct call mode""" - if self.reader_client is None: - return b" " * 4 # special flag to indicate this actor not ready - result = self.reader_client.on_reader_message_sync(buffer) - return result.to_pybytes() - - def on_writer_message(self, buffer: bytes): - """used in direct call mode""" - self.writer_client.on_writer_message(buffer) - - def on_writer_message_sync(self, buffer: bytes): - """used in direct call mode""" - if self.writer_client is None: - return b" " * 4 # special flag to indicate this actor not ready - result = self.writer_client.on_writer_message_sync(buffer) - return result.to_pybytes() diff --git a/streaming/python/message.py b/streaming/python/message.py new file mode 100644 index 000000000..fab29d4bf --- /dev/null +++ b/streaming/python/message.py @@ -0,0 +1,17 @@ +class Record: + """Data record in data stream""" + + def __init__(self, value): + self.value = value + self.stream = None + + def __repr__(self): + return "Record(%s)".format(self.value) + + +class KeyRecord(Record): + """Data record in a keyed data stream""" + + def __init__(self, key, value): + super().__init__(value) + self.key = key diff --git a/streaming/python/operator.py b/streaming/python/operator.py index 036f00d4a..e03be9707 100644 --- a/streaming/python/operator.py 
+++ b/streaming/python/operator.py @@ -1,109 +1,243 @@ +from abc import ABC, abstractmethod import enum -import logging - -import cloudpickle - -logger = logging.getLogger(__name__) -logger.setLevel("DEBUG") +from ray import streaming +from ray.streaming import function +from ray.streaming import message -# Stream partitioning schemes -class PScheme: - def __init__(self, strategy, partition_fn=None): - self.strategy = strategy - self.partition_fn = partition_fn - - def __repr__(self): - return "({},{})".format(self.strategy, self.partition_fn) +class OperatorType(enum.Enum): + SOURCE = 0 # Sources are where your program reads its input from + ONE_INPUT = 1 # This operator has one data stream as its input stream. + TWO_INPUT = 2 # This operator has two data streams as its input streams. -# Partitioning strategies -class PStrategy(enum.Enum): - Forward = 0 # Default - Shuffle = 1 - Rescale = 2 - RoundRobin = 3 - Broadcast = 4 - Custom = 5 - ShuffleByKey = 6 - # ... +class Operator(ABC): + """ + Abstract base class for all operators. + An operator is used to run a :class:`function.Function`. + """ + + @abstractmethod + def open(self, collectors, runtime_context): + pass + + @abstractmethod + def finish(self): + pass + + @abstractmethod + def close(self): + pass + + @abstractmethod + def operator_type(self) -> OperatorType: + pass -# Operator types -class OpType(enum.Enum): - Source = 0 - Map = 1 - FlatMap = 2 - Filter = 3 - TimeWindow = 4 - KeyBy = 5 - Sink = 6 - WindowJoin = 7 - Inspect = 8 - ReadTextFile = 9 - Reduce = 10 - Sum = 11 - # ...
+class OneInputOperator(Operator, ABC): + """Interface for stream operators with one input.""" + + @abstractmethod + def process_element(self, record): + pass + + def operator_type(self): + return OperatorType.ONE_INPUT -# A logical dataflow operator -class Operator: - def __init__(self, - id, - op_type, - processor_class, - name="", - logic=None, - num_instances=1, - other=None, - state_actor=None): - self.id = id - self.type = op_type - self.processor_class = processor_class - self.name = name - self._logic = cloudpickle.dumps(logic) # The operator's logic - self.num_instances = num_instances - # One partitioning strategy per downstream operator (default: forward) - self.partitioning_strategies = {} - self.other_args = other # Depends on the type of the operator - self.state_actor = state_actor # Actor to query state +class TwoInputOperator(Operator, ABC): + """Interface for stream operators with two input""" - # Sets the partitioning scheme for an output stream of the operator - def _set_partition_strategy(self, - stream_id, - partitioning_scheme, - dest_operator=None): - self.partitioning_strategies[stream_id] = (partitioning_scheme, - dest_operator) + @abstractmethod + def process_element(self, record1, record2): + pass - # Retrieves the partitioning scheme for the given - # output stream of the operator - # Returns None is no strategy has been defined for the particular stream - def _get_partition_strategy(self, stream_id): - return self.partitioning_strategies.get(stream_id) + def operator_type(self): + return OperatorType.TWO_INPUT - # Cleans metatada from all partitioning strategies that lack a - # destination operator - # Valid entries are re-organized as - # 'destination operator id -> partitioning scheme' - # Should be called only after the logical dataflow has been constructed - def _clean(self): - strategies = {} - for _, v in self.partitioning_strategies.items(): - strategy, destination_operator = v - if destination_operator is not None: - 
strategies.setdefault(destination_operator, strategy) - self.partitioning_strategies = strategies - def print(self): - log = "Operator<\nID = {}\nName = {}\nprocessor_class = {}\n" - log += "Logic = {}\nNumber_of_Instances = {}\n" - log += "Partitioning_Scheme = {}\nOther_Args = {}>\n" - logger.debug( - log.format(self.id, self.name, self.processor_class, self.logic, - self.num_instances, self.partitioning_strategies, - self.other_args)) +class StreamOperator(Operator, ABC): + """ + Basic interface for stream operators. Implementers would implement one of + :class:`OneInputOperator` or :class:`TwoInputOperator` to create + operators that process elements. + """ - @property - def logic(self): - return cloudpickle.loads(self._logic) + def __init__(self, func): + self.func = func + self.collectors = None + self.runtime_context = None + + def open(self, collectors, runtime_context): + self.collectors = collectors + self.runtime_context = runtime_context + + def finish(self): + pass + + def close(self): + pass + + def collect(self, record): + for collector in self.collectors: + collector.collect(record) + + +class SourceOperator(StreamOperator): + """ + Operator to run a :class:`function.SourceFunction` + """ + + class SourceContextImpl(function.SourceContext): + def __init__(self, collectors): + self.collectors = collectors + + def collect(self, value): + for collector in self.collectors: + collector.collect(message.Record(value)) + + def __init__(self, func): + assert isinstance(func, function.SourceFunction) + super().__init__(func) + self.source_context = None + + def open(self, collectors, runtime_context): + super().open(collectors, runtime_context) + self.source_context = SourceOperator.SourceContextImpl(collectors) + self.func.init(runtime_context.get_parallelism(), + runtime_context.get_task_index()) + + def run(self): + self.func.run(self.source_context) + + def operator_type(self): + return OperatorType.SOURCE + + +class MapOperator(StreamOperator,
OneInputOperator): + """ + Operator to run a :class:`function.MapFunction` + """ + + def __init__(self, map_func: function.MapFunction): + assert isinstance(map_func, function.MapFunction) + super().__init__(map_func) + + def process_element(self, record): + self.collect(message.Record(self.func.map(record.value))) + + +class FlatMapOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.FlatMapFunction` + """ + + def __init__(self, flat_map_func: function.FlatMapFunction): + assert isinstance(flat_map_func, function.FlatMapFunction) + super().__init__(flat_map_func) + self.collection_collector = None + + def open(self, collectors, runtime_context): + super().open(collectors, runtime_context) + self.collection_collector = streaming.collector.CollectionCollector( + collectors) + + def process_element(self, record): + self.func.flat_map(record.value, self.collection_collector) + + +class FilterOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.FilterFunction` + """ + + def __init__(self, filter_func: function.FilterFunction): + assert isinstance(filter_func, function.FilterFunction) + super().__init__(filter_func) + + def process_element(self, record): + if self.func.filter(record.value): + self.collect(record) + + +class KeyByOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.KeyFunction` + """ + + def __init__(self, key_func: function.KeyFunction): + assert isinstance(key_func, function.KeyFunction) + super().__init__(key_func) + + def process_element(self, record): + key = self.func.key_by(record.value) + self.collect(message.KeyRecord(key, record.value)) + + +class ReduceOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.ReduceFunction` + """ + + def __init__(self, reduce_func: function.ReduceFunction): + assert isinstance(reduce_func, function.ReduceFunction) + super().__init__(reduce_func) + self.reduce_state = {} + + def 
open(self, collectors, runtime_context): + super().open(collectors, runtime_context) + + def process_element(self, record: message.KeyRecord): + key = record.key + value = record.value + if key in self.reduce_state: + old_value = self.reduce_state[key] + new_value = self.func.reduce(old_value, value) + self.reduce_state[key] = new_value + self.collect(message.Record(new_value)) + else: + self.reduce_state[key] = value + self.collect(record) + + +class SinkOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.SinkFunction` + """ + + def __init__(self, sink_func: function.SinkFunction): + assert isinstance(sink_func, function.SinkFunction) + super().__init__(sink_func) + + def process_element(self, record): + self.func.sink(record.value) + + +_function_to_operator = { + function.SourceFunction: SourceOperator, + function.MapFunction: MapOperator, + function.FlatMapFunction: FlatMapOperator, + function.FilterFunction: FilterOperator, + function.KeyFunction: KeyByOperator, + function.ReduceFunction: ReduceOperator, + function.SinkFunction: SinkOperator, +} + + +def create_operator(func: function.Function): + """Create an operator according to a :class:`function.Function` + + Args: + func: a subclass of function.Function + + Returns: + an operator + """ + operator_class = None + super_classes = func.__class__.mro() + for super_class in super_classes: + operator_class = _function_to_operator.get(super_class, None) + if operator_class is not None: + break + assert operator_class is not None + return operator_class(func) diff --git a/streaming/python/partition.py b/streaming/python/partition.py new file mode 100644 index 000000000..a9f9d92d6 --- /dev/null +++ b/streaming/python/partition.py @@ -0,0 +1,117 @@ +import importlib +from abc import ABC, abstractmethod + +import cloudpickle +from ray.streaming.runtime import gateway_client + + +class Partition(ABC): + """Interface of the partitioning strategy.""" + + @abstractmethod + def 
partition(self, record, num_partition: int): + """Given a record and downstream partitions, determine which partition(s) + should receive the record. + + Args: + record: The record. + num_partition: num of partitions + Returns: + IDs of the downstream partitions that should receive the record. + """ + pass + + +class BroadcastPartition(Partition): + """Broadcast the record to all downstream partitions.""" + + def __init__(self): + self.__partitions = [] + + def partition(self, record, num_partition: int): + if len(self.__partitions) != num_partition: + self.__partitions = list(range(num_partition)) + return self.__partitions + + +class KeyPartition(Partition): + """Partition the record by the key.""" + + def __init__(self): + self.__partitions = [-1] + + def partition(self, key_record, num_partition: int): + # TODO support key group + self.__partitions[0] = abs(hash(key_record.key)) % num_partition + return self.__partitions + + +class RoundRobinPartition(Partition): + """Partition record to downstream tasks in a round-robin manner.""" + + def __init__(self): + self.__partitions = [-1] + self.seq = 0 + + def partition(self, key_record, num_partition: int): + self.seq = (self.seq + 1) % num_partition + self.__partitions[0] = self.seq + return self.__partitions + + +class SimplePartition(Partition): + """Wrap a python function as a subclass of :class:`Partition`""" + + def __init__(self, func): + self.func = func + + def partition(self, record, num_partition: int): + return self.func(record, num_partition) + + +def serialize(partition_func): + """ + Serialize the partition function so that it can be deserialized by + :func:`deserialize` + """ + return cloudpickle.dumps(partition_func) + + +def deserialize(partition_bytes): + """Deserialize the binary partition function serialized by + :func:`serialize`""" + return cloudpickle.loads(partition_bytes) + + +def load_partition(descriptor_partition_bytes: bytes): + """ + Deserialize `descriptor_partition_bytes` to get
partition info, then + get or load partition function. + Note that this function must be kept in sync with + `org.ray.streaming.runtime.python.GraphPbBuilder.serializePartition` + + Args: + descriptor_partition_bytes: serialized partition info + + Returns: + partition function + """ + partition_bytes, module_name, class_name, function_name =\ + gateway_client.deserialize(descriptor_partition_bytes) + if partition_bytes: + return deserialize(partition_bytes) + else: + assert module_name + mod = importlib.import_module(module_name) + # If class_name is not None, user partition is a sub class + # of Partition. + # If function_name is not None, user partition is a simple python + # function, which will be wrapped as a SimplePartition. + if class_name: + assert function_name is None + cls = getattr(mod, class_name) + return cls() + else: + assert function_name + func = getattr(mod, function_name) + return SimplePartition(func) diff --git a/streaming/python/processor.py b/streaming/python/processor.py deleted file mode 100644 index 96da54245..000000000 --- a/streaming/python/processor.py +++ /dev/null @@ -1,222 +0,0 @@ -import logging -import sys -import time -import types - -logger = logging.getLogger(__name__) -logger.setLevel("INFO") - - -def _identity(element): - return element - - -class ReadTextFile: - """A source operator instance that reads a text file line by line. - - Attributes: - filepath (string): The path to the input file. 
- """ - - def __init__(self, operator): - self.filepath = operator.other_args - # TODO (john): Handle possible exception here - self.reader = open(self.filepath, "r") - - # Read input file line by line - def run(self, input_gate, output_gate): - while True: - record = self.reader.readline() - # Reader returns empty string ('') on EOF - if not record: - self.reader.close() - return - output_gate.push( - record[:-1]) # Push after removing newline characters - - -class Map: - """A map operator instance that applies a user-defined - stream transformation. - - A map produces exactly one output record for each record in - the input stream. - - """ - - def __init__(self, operator): - self.map_fn = operator.logic - - # Applies the mapper each record of the input stream(s) - # and pushes resulting records to the output stream(s) - def run(self, input_gate, output_gate): - elements = 0 - while True: - record = input_gate.pull() - if record is None: - return - output_gate.push(self.map_fn(record)) - elements += 1 - - -class FlatMap: - """A map operator instance that applies a user-defined - stream transformation. - - A flatmap produces one or more output records for each record in - the input stream. - - Attributes: - flatmap_fn (function): The user-defined function. - """ - - def __init__(self, operator): - self.flatmap_fn = operator.logic - - # Applies the splitter to the records of the input stream(s) - # and pushes resulting records to the output stream(s) - def run(self, input_gate, output_gate): - while True: - record = input_gate.pull() - if record is None: - return - output_gate.push_all(self.flatmap_fn(record)) - - -class Filter: - """A filter operator instance that applies a user-defined filter to - each record of the stream. - - Output records are those that pass the filter, i.e. those for which - the filter function returns True. - - Attributes: - filter_fn (function): The user-defined boolean function. 
- """ - - def __init__(self, operator): - self.filter_fn = operator.logic - - # Applies the filter to the records of the input stream(s) - # and pushes resulting records to the output stream(s) - def run(self, input_gate, output_gate): - while True: - record = input_gate.pull() - if record is None: - return - if self.filter_fn(record): - output_gate.push(record) - - -class Inspect: - """A inspect operator instance that inspects the content of the stream. - Inspect is useful for printing the records in the stream. - """ - - def __init__(self, operator): - self.inspect_fn = operator.logic - - def run(self, input_gate, output_gate): - # Applies the inspect logic (e.g. print) to the records of - # the input stream(s) - # and leaves stream unaffected by simply pushing the records to - # the output stream(s) - while True: - record = input_gate.pull() - if record is None: - return - if output_gate: - output_gate.push(record) - self.inspect_fn(record) - - -class Reduce: - """A reduce operator instance that combines a new value for a key - with the last reduced one according to a user-defined logic. - """ - - def __init__(self, operator): - self.reduce_fn = operator.logic - # Set the attribute selector - self.attribute_selector = operator.other_args - if self.attribute_selector is None: - self.attribute_selector = _identity - elif isinstance(self.attribute_selector, int): - self.key_index = self.attribute_selector - self.attribute_selector =\ - lambda record: record[self.attribute_selector] - elif isinstance(self.attribute_selector, str): - self.attribute_selector =\ - lambda record: vars(record)[self.attribute_selector] - elif not isinstance(self.attribute_selector, types.FunctionType): - sys.exit("Unrecognized or unsupported key selector.") - self.state = {} # key -> value - - # Combines the input value for a key with the last reduced - # value for that key to produce a new value. 
- # Outputs the result as (key,new value) - def run(self, input_gate, output_gate): - while True: - record = input_gate.pull() - if record is None: - return - key, rest = record - new_value = self.attribute_selector(rest) - # TODO (john): Is there a way to update state with - # a single dictionary lookup? - try: - old_value = self.state[key] - new_value = self.reduce_fn(old_value, new_value) - self.state[key] = new_value - except KeyError: # Key does not exist in state - self.state.setdefault(key, new_value) - output_gate.push((key, new_value)) - - # Returns the state of the actor - def get_state(self): - return self.state - - -class KeyBy: - """A key_by operator instance that physically partitions the - stream based on a key. - """ - - def __init__(self, operator): - # Set the key selector - self.key_selector = operator.other_args - if isinstance(self.key_selector, int): - self.key_selector = lambda r: r[self.key_selector] - elif isinstance(self.key_selector, str): - self.key_selector = lambda record: vars(record)[self.key_selector] - elif not isinstance(self.key_selector, types.FunctionType): - sys.exit("Unrecognized or unsupported key selector.") - - # The actual partitioning is done by the output gate - def run(self, input_gate, output_gate): - while True: - record = input_gate.pull() - if record is None: - return - key = self.key_selector(record) - output_gate.push((key, record)) - - -# A custom source actor -class Source: - def __init__(self, operator): - # The user-defined source with a get_next() method - self.source = operator.logic - - # Starts the source by calling get_next() repeatedly - def run(self, input_gate, output_gate): - start = time.time() - elements = 0 - while True: - record = self.source.get_next() - if not record: - logger.debug("[writer] puts per second: {}".format( - elements / (time.time() - start))) - return - output_gate.push(record) - elements += 1 diff --git a/streaming/python/runtime/gateway_client.py 
class GatewayClient:
    """Client-side proxy for the `PythonGateway` java actor.

    Each method forwards its arguments to the matching remote method of the
    actor and blocks on `ray.get` until that remote call has finished.
    """

    _PYTHON_GATEWAY_CLASSNAME = \
        b"org.ray.streaming.runtime.python.PythonGateway"

    def __init__(self):
        gateway_class = ray.java_actor_class(
            GatewayClient._PYTHON_GATEWAY_CLASSNAME)
        self._python_gateway_actor = gateway_class.remote()

    def create_streaming_context(self):
        """Call `createStreamingContext` and return the decoded reply."""
        gateway = self._python_gateway_actor
        return deserialize(ray.get(gateway.createStreamingContext.remote()))

    def with_config(self, conf):
        """Send the msgpack-encoded `conf` to `withConfig` and wait."""
        packed_conf = serialize(conf)
        ray.get(self._python_gateway_actor.withConfig.remote(packed_conf))

    def execute(self, job_name):
        """Send the msgpack-encoded `job_name` to `execute` and wait."""
        packed_name = serialize(job_name)
        ray.get(self._python_gateway_actor.execute.remote(packed_name))

    def create_py_stream_source(self, serialized_func):
        """Pass the raw function bytes to `createPythonStreamSource`."""
        assert isinstance(serialized_func, bytes)
        gateway = self._python_gateway_actor
        pending = gateway.createPythonStreamSource.remote(serialized_func)
        return deserialize(ray.get(pending))

    def create_py_func(self, serialized_func):
        """Pass the raw function bytes to `createPyFunc`."""
        assert isinstance(serialized_func, bytes)
        gateway = self._python_gateway_actor
        pending = gateway.createPyFunc.remote(serialized_func)
        return deserialize(ray.get(pending))

    def create_py_partition(self, serialized_partition):
        """Pass the raw partition bytes to `createPyPartition`."""
        assert isinstance(serialized_partition, bytes)
        gateway = self._python_gateway_actor
        pending = gateway.createPyPartition.remote(serialized_partition)
        return deserialize(ray.get(pending))

    def call_function(self, java_class, java_function, *args):
        """Invoke `callFunction` with `[java_class, java_function, *args]`."""
        packed_call = serialize([java_class, java_function, *args])
        pending = self._python_gateway_actor.callFunction.remote(packed_call)
        return deserialize(ray.get(pending))

    def call_method(self, java_object, java_method, *args):
        """Invoke `callMethod` with `[java_object, java_method, *args]`."""
        packed_call = serialize([java_object, java_method, *args])
        pending = self._python_gateway_actor.callMethod.remote(packed_call)
        return deserialize(ray.get(pending))
class ExecutionGraph:
    """Python view over an `ExecutionGraph` protobuf message.

    Attributes:
        execution_nodes (list): One `ExecutionNode` per entry of
            `graph_pb.execution_nodes`. This is a plain attribute; the
            original accessor method of the same name was shadowed by it on
            every instance and could never be called through one, so it has
            been dropped.
    """

    def __init__(self, graph_pb: "remote_call_pb.ExecutionGraph"):
        self._graph_pb = graph_pb
        self.execution_nodes = [
            ExecutionNode(node) for node in graph_pb.execution_nodes
        ]

    def build_time(self):
        """Return the `build_time` field of the wrapped protobuf message."""
        return self._graph_pb.build_time

    def _find_node_and_task(self, task_id):
        """Return the `(node, task)` pair owning `task_id`.

        Raises:
            Exception: if no execution task carries `task_id`.
        """
        for execution_node in self.execution_nodes:
            for task in execution_node.execution_tasks:
                if task.task_id == task_id:
                    return execution_node, task
        # The original message mixed a `%s` placeholder with `str.format`,
        # so the offending id was never interpolated.
        raise Exception("Task {} does not exist!".format(task_id))

    def get_execution_task_by_task_id(self, task_id):
        """Return the execution task whose `task_id` matches.

        Raises:
            Exception: if no such task exists.
        """
        return self._find_node_and_task(task_id)[1]

    def get_execution_node_by_task_id(self, task_id):
        """Return the execution node that contains the task `task_id`.

        Raises:
            Exception: if no such task exists.
        """
        return self._find_node_and_task(task_id)[0]

    def get_task_id2_worker_by_node_id(self, node_id):
        """Return a `{task_id: worker_actor}` dict for the node `node_id`.

        Raises:
            Exception: if no execution node carries `node_id`.
        """
        for execution_node in self.execution_nodes:
            if execution_node.node_id == node_id:
                return {
                    task.task_id: task.worker_actor
                    for task in execution_node.execution_tasks
                }
        raise Exception("Node {} does not exist!".format(node_id))
class TwoInputProcessor(StreamingProcessor):
    """Processor for stream operator with two inputs.

    A record whose `stream` equals `left_stream` is handed to the operator as
    the first argument of `process_element`; any other record is handed over
    as the second argument.
    """

    def __init__(self, operator):
        super().__init__(operator)
        # Backing fields for the properties below. The original assigned
        # through the properties here, and the `right_stream` setter wrote to
        # itself, so constructing the processor recursed without end.
        self._left_stream = None
        self._right_stream = None

    def process(self, record: message.Record):
        if record.stream == self.left_stream:
            self.operator.process_element(record, None)
        else:
            self.operator.process_element(None, record)

    @property
    def left_stream(self):
        """Stream whose records are routed to the operator's first input."""
        return self._left_stream

    @left_stream.setter
    def left_stream(self, value):
        self._left_stream = value

    @property
    def right_stream(self):
        """Stream recorded as the second input; not consulted by `process`."""
        return self._right_stream

    @right_stream.setter
    def right_stream(self, value):
        self._right_stream = value
class StreamTask(ABC):
    """Base class for all streaming tasks. Each task runs a processor.

    Construction wires up the data channels (`prepare_task`) and creates, but
    does not start, the daemon thread that will execute `run`; call `start`
    to launch it.
    """

    def __init__(self, task_id, processor, worker):
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.reader = None  # DataReader
        self.writers = {}  # ExecutionEdge -> DataWriter
        # Defined before prepare_task() so that cancel_task(), which the
        # exit handler registered there may call, always finds the field.
        self.thread = None
        self.prepare_task()
        self.thread = threading.Thread(target=self.run, daemon=True)

    def prepare_task(self):
        """Create writers/reader for this task's edges and open the processor.
        """
        channel_conf = dict(self.worker.config)
        channel_size = int(
            self.worker.config.get(Config.CHANNEL_SIZE,
                                   Config.CHANNEL_SIZE_DEFAULT))
        channel_conf[Config.CHANNEL_SIZE] = channel_size
        channel_conf[Config.TASK_JOB_ID] = ray.runtime_context.\
            _get_runtime_context().current_driver_id
        channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
            .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)

        execution_graph = self.worker.execution_graph
        execution_node = self.worker.execution_node
        # writers: one DataWriter (and one collector) per output edge
        collectors = []
        for edge in execution_node.output_edges:
            output_actor_ids = {}
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.target_node_id)
            for target_task_id, target_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                                execution_graph.build_time())
                output_actor_ids[channel_name] = target_actor
            if len(output_actor_ids) > 0:
                channel_ids = list(output_actor_ids.keys())
                to_actor_ids = list(output_actor_ids.values())
                writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
                logger.info("Create DataWriter succeed.")
                self.writers[edge] = writer
                collectors.append(
                    OutputCollector(channel_ids, writer, edge.partition))

        # readers: a single DataReader over the channels of all input edges
        input_actor_ids = {}
        for edge in execution_node.input_edges:
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.src_node_id)
            for src_task_id, src_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(src_task_id, self.task_id,
                                                execution_graph.build_time())
                input_actor_ids[channel_name] = src_actor
        if len(input_actor_ids) > 0:
            channel_ids = list(input_actor_ids.keys())
            from_actor_ids = list(input_actor_ids.values())
            logger.info("Create DataReader, channels {}.".format(channel_ids))
            self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)

            # NOTE(review): the handler is registered only when a reader was
            # created - confirm against the upstream file's indentation.
            def exit_handler():
                # Make DataReader stop read data when MockQueue destructor
                # gets called to avoid crash
                self.cancel_task()

            import atexit
            atexit.register(exit_handler)

        runtime_context = RuntimeContextImpl(
            self.worker.execution_task.task_id,
            self.worker.execution_task.task_index, execution_node.parallelism)
        logger.info("open Processor {}".format(self.processor))
        self.processor.open(collectors, runtime_context)

    @abstractmethod
    def init(self):
        pass

    def start(self):
        """Start the thread that executes `run`."""
        self.thread.start()

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def cancel_task(self):
        pass


class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor` """

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
        self.running = True
        self.stopped = False
        self.read_timeout_millis = \
            int(worker.config.get(Config.READ_TIMEOUT_MS,
                                  Config.DEFAULT_READ_TIMEOUT_MS))

    def init(self):
        pass

    def run(self):
        """Read, unpickle and process items until `running` is cleared.

        `stopped` is set on every exit path, including when the reader or the
        processor raises, so a later `cancel_task` cannot wait forever on a
        loop that has already died.
        """
        try:
            while self.running:
                item = self.reader.read(self.read_timeout_millis)
                if item is not None:
                    msg_data = item.body()
                    # NOTE(review): pickle.loads executes whatever the payload
                    # encodes; this assumes channel data comes only from
                    # trusted upstream tasks - confirm.
                    msg = pickle.loads(msg_data)
                    self.processor.process(msg)
        finally:
            self.stopped = True

    def cancel_task(self):
        """Ask the read loop to stop and wait until it has done so.

        Returns immediately when there is nothing to wait for: the thread was
        never started, has already finished, or is the calling thread itself.
        The original spun on `stopped` unconditionally and hung in all three
        cases.
        """
        self.running = False
        task_thread = self.thread
        if task_thread is None or task_thread is threading.current_thread():
            return
        while not self.stopped and task_thread.is_alive():
            # Bounded join instead of a busy `pass` loop.
            task_thread.join(timeout=0.01)
@ray.remote
class JobWorker(object):
    """A streaming job worker is used to execute user-defined function and
    interact with `JobMaster`"""

    def __init__(self):
        self.worker_context = None
        self.task_id = None
        self.config = None
        self.execution_graph = None
        self.execution_task = None
        self.execution_node = None
        self.stream_processor = None
        self.task = None
        self.reader_client = None
        self.writer_client = None

    def init(self, worker_context_bytes):
        """Initialize the worker from a serialized `WorkerContext`.

        Parses the protobuf bytes, looks up this worker's task and node in
        the execution graph, builds the processor for the node's operator,
        creates the reader/writer clients and starts the stream task.

        Returns:
            True once the task has been started.
        """
        worker_context = remote_call_pb.WorkerContext()
        worker_context.ParseFromString(worker_context_bytes)
        self.worker_context = worker_context
        self.task_id = worker_context.task_id
        self.config = worker_context.conf
        execution_graph = ExecutionGraph(worker_context.graph)
        self.execution_graph = execution_graph
        self.execution_task = self.execution_graph. \
            get_execution_task_by_task_id(self.task_id)
        self.execution_node = self.execution_graph. \
            get_execution_node_by_task_id(self.task_id)
        operator = self.execution_node.stream_operator
        self.stream_processor = processor.build_processor(operator)
        logger.info(
            "Initializing JobWorker, task_id: {}, operator: {}.".format(
                self.task_id, self.stream_processor))

        # NOTE(review): any non-empty channel type is truthy, so this branch
        # is taken for every configured type - confirm whether a comparison
        # with Config.NATIVE_CHANNEL was intended.
        if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
            core_worker = ray.worker.global_worker.core_worker
            reader_async_func = PythonFunctionDescriptor(
                __name__, self.on_reader_message.__name__,
                self.__class__.__name__)
            reader_sync_func = PythonFunctionDescriptor(
                __name__, self.on_reader_message_sync.__name__,
                self.__class__.__name__)
            self.reader_client = _streaming.ReaderClient(
                core_worker, reader_async_func, reader_sync_func)
            writer_async_func = PythonFunctionDescriptor(
                __name__, self.on_writer_message.__name__,
                self.__class__.__name__)
            writer_sync_func = PythonFunctionDescriptor(
                __name__, self.on_writer_message_sync.__name__,
                self.__class__.__name__)
            self.writer_client = _streaming.WriterClient(
                core_worker, writer_async_func, writer_sync_func)

        self.task = self.create_stream_task()
        self.task.start()
        logger.info("JobWorker init succeed")
        return True

    def create_stream_task(self):
        """Create the stream task matching the type of `stream_processor`.

        Raises:
            Exception: if the processor is neither a `SourceProcessor` nor a
                `OneInputProcessor`.
        """
        if isinstance(self.stream_processor, processor.SourceProcessor):
            return SourceStreamTask(self.task_id, self.stream_processor, self)
        elif isinstance(self.stream_processor, processor.OneInputProcessor):
            return OneInputStreamTask(self.task_id, self.stream_processor,
                                      self)
        else:
            # The original concatenated a str with a type object, which
            # raised TypeError and hid the intended message.
            raise Exception("Unsupported processor type: {}".format(
                type(self.stream_processor)))

    def on_reader_message(self, buffer: bytes):
        """used in direct call mode"""
        self.reader_client.on_reader_message(buffer)

    def on_reader_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.reader_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        result = self.reader_client.on_reader_message_sync(buffer)
        return result.to_pybytes()

    def on_writer_message(self, buffer: bytes):
        """used in direct call mode"""
        self.writer_client.on_writer_message(buffer)

    def on_writer_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.writer_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        result = self.writer_client.on_writer_message_sync(buffer)
        return result.to_pybytes()
- - -class ExecutionGraph: - def __init__(self, env): - self.env = env - self.physical_topo = nx.DiGraph() # DAG - # Handles to all actors in the physical dataflow - self.actor_handles = [] - # (op_id, op_instance_index) -> ActorID - self.actors_map = {} - # execution graph build time: milliseconds since epoch - self.build_time = 0 - self.task_id_counter = 0 - self.task_ids = {} - self.input_channels = {} # operator id -> input channels - self.output_channels = {} # operator id -> output channels - - # Constructs and deploys a Ray actor of a specific type - # TODO (john): Actor placement information should be specified in - # the environment's configuration - def __generate_actor(self, instance_index, operator, input_channels, - output_channels): - """Generates an actor that will execute a particular instance of - the logical operator - - Attributes: - instance_index: The index of the instance the actor will execute. - operator: The metadata of the logical operator. - input_channels: The input channels of the instance. - output_channels The output channels of the instance. - """ - worker_id = (operator.id, instance_index) - # Record the physical dataflow graph (for debugging purposes) - self.__add_channel(worker_id, output_channels) - # Note direct_call only support pass by value - return JobWorker._remote( - args=[worker_id, operator, input_channels, output_channels], - is_direct_call=True) - - # Constructs and deploys a Ray actor for each instance of - # the given operator - def __generate_actors(self, operator, upstream_channels, - downstream_channels): - """Generates one actor for each instance of the given logical - operator. - - Attributes: - operator (Operator): The logical operator metadata. - upstream_channels (list): A list of all upstream channels for - all instances of the operator. - downstream_channels (list): A list of all downstream channels - for all instances of the operator. 
- """ - num_instances = operator.num_instances - logger.info("Generating {} actors of type {}...".format( - num_instances, operator.type)) - handles = [] - for i in range(num_instances): - # Collect input and output channels for the particular instance - ip = [c for c in upstream_channels if c.dst_instance_index == i] - op = [c for c in downstream_channels if c.src_instance_index == i] - log = "Constructed {} input and {} output channels " - log += "for the {}-th instance of the {} operator." - logger.debug(log.format(len(ip), len(op), i, operator.type)) - handle = self.__generate_actor(i, operator, ip, op) - if handle: - handles.append(handle) - self.actors_map[(operator.id, i)] = handle - return handles - - # Adds a channel/edge to the physical dataflow graph - def __add_channel(self, actor_id, output_channels): - for c in output_channels: - dest_actor_id = (c.dst_operator_id, c.dst_instance_index) - self.physical_topo.add_edge(actor_id, dest_actor_id) - - # Generates all required data channels between an operator - # and its downstream operators - def _generate_channels(self, operator): - """Generates all output data channels - (see: DataChannel in communication.py) for all instances of - the given logical operator. - - The function constructs one data channel for each pair of - communicating operator instances (instance_1,instance_2), - where instance_1 is an instance of the given operator and instance_2 - is an instance of a direct downstream operator. - - The number of total channels generated depends on the partitioning - strategy specified by the user. 
- """ - channels = {} # destination operator id -> channels - strategies = operator.partitioning_strategies - for dst_operator, p_scheme in strategies.items(): - num_dest_instances = self.env.operators[dst_operator].num_instances - entry = channels.setdefault(dst_operator, []) - if p_scheme.strategy == PStrategy.Forward: - for i in range(operator.num_instances): - # ID of destination instance to connect - id = i % num_dest_instances - qid = self._gen_str_qid(operator.id, i, dst_operator, id) - c = DataChannel(operator.id, i, dst_operator, id, qid) - entry.append(c) - elif p_scheme.strategy in all_to_all_strategies: - for i in range(operator.num_instances): - for j in range(num_dest_instances): - qid = self._gen_str_qid(operator.id, i, dst_operator, - j) - c = DataChannel(operator.id, i, dst_operator, j, qid) - entry.append(c) - else: - # TODO (john): Add support for other partitioning strategies - sys.exit("Unrecognized or unsupported partitioning strategy.") - return channels - - def _gen_str_qid(self, src_operator_id, src_instance_index, - dst_operator_id, dst_instance_index): - from_task_id = self.env.execution_graph.get_task_id( - src_operator_id, src_instance_index) - to_task_id = self.env.execution_graph.get_task_id( - dst_operator_id, dst_instance_index) - return transfer.ChannelID.gen_id(from_task_id, to_task_id, - self.build_time) - - def _gen_task_id(self): - task_id = self.task_id_counter - self.task_id_counter += 1 - return task_id - - def get_task_id(self, op_id, op_instance_id): - return self.task_ids[(op_id, op_instance_id)] - - def get_actor(self, op_id, op_instance_id): - return self.actors_map[(op_id, op_instance_id)] - - # Prints the physical dataflow graph - def print_physical_graph(self): - logger.info("===================================") - logger.info("======Physical Dataflow Graph======") - logger.info("===================================") - # Print all data channels between operator instances - log = "(Source Operator ID,Source Operator 
Name,Source Instance ID)" - log += " --> " - log += "(Destination Operator ID,Destination Operator Name," - log += "Destination Instance ID)" - logger.info(log) - for src_actor_id, dst_actor_id in self.physical_topo.edges: - src_operator_id, src_instance_index = src_actor_id - dst_operator_id, dst_instance_index = dst_actor_id - logger.info("({},{},{}) --> ({},{},{})".format( - src_operator_id, self.env.operators[src_operator_id].name, - src_instance_index, dst_operator_id, - self.env.operators[dst_operator_id].name, dst_instance_index)) - - def build_graph(self): - self.build_channels() - - # to support cyclic reference serialization - try: - ray.register_custom_serializer(Environment, use_pickle=True) - ray.register_custom_serializer(ExecutionGraph, use_pickle=True) - ray.register_custom_serializer(OpType, use_pickle=True) - ray.register_custom_serializer(PStrategy, use_pickle=True) - except Exception: - # local mode can't use pickle - pass - - # Each operator instance is implemented as a Ray actor - # Actors are deployed in topological order, as we traverse the - # logical dataflow from sources to sinks. 
- for node in nx.topological_sort(self.env.logical_topo): - operator = self.env.operators[node] - # Instantiate Ray actors - handles = self.__generate_actors( - operator, self.input_channels.get(node, []), - self.output_channels.get(node, [])) - if handles: - self.actor_handles.extend(handles) - - def build_channels(self): - self.build_time = int(time.time() * 1000) - # gen auto-incremented unique task id for every operator instance - for node in nx.topological_sort(self.env.logical_topo): - operator = self.env.operators[node] - for i in range(operator.num_instances): - operator_instance_id = (operator.id, i) - self.task_ids[operator_instance_id] = self._gen_task_id() - channels = {} - for node in nx.topological_sort(self.env.logical_topo): - operator = self.env.operators[node] - # Generate downstream data channels - downstream_channels = self._generate_channels(operator) - channels[node] = downstream_channels - # op_id -> channels - input_channels = {} - output_channels = {} - for op_id, all_downstream_channels in channels.items(): - for dst_op_channels in all_downstream_channels.values(): - for c in dst_op_channels: - dst = input_channels.setdefault(c.dst_operator_id, []) - dst.append(c) - src = output_channels.setdefault(c.src_operator_id, []) - src.append(c) - self.input_channels = input_channels - self.output_channels = output_channels - - -# The execution environment for a streaming job -class Environment: - """A streaming environment. - - This class is responsible for constructing the logical and the - physical dataflow. - - Attributes: - logical_topo (DiGraph): The user-defined logical topology in - NetworkX DiGRaph format. - (See: https://networkx.github.io) - physical_topo (DiGraph): The physical topology in NetworkX - DiGRaph format. The physical dataflow is constructed by the - environment based on logical_topo. - operators (dict): A mapping from operator ids to operator metadata - (See: Operator in operator.py). 
- config (Config): The environment's configuration. - topo_cleaned (bool): A flag that indicates whether the logical - topology is garbage collected (True) or not (False). - actor_handles (list): A list of all Ray actor handles that execute - the streaming dataflow. - """ - - def __init__(self, config=Conf()): - self.logical_topo = nx.DiGraph() # DAG - self.operators = {} # operator id --> operator object - self.config = config # Environment's configuration - self.topo_cleaned = False - self.operator_id_counter = 0 - self.execution_graph = None # set when executed - - def gen_operator_id(self): - op_id = self.operator_id_counter - self.operator_id_counter += 1 - return op_id - - # An edge denotes a flow of data between logical operators - # and may correspond to multiple data channels in the physical dataflow - def _add_edge(self, source, destination): - self.logical_topo.add_edge(source, destination) - - # Cleans the logical dataflow graph to construct and - # deploy the physical dataflow - def _collect_garbage(self): - if self.topo_cleaned is True: - return - for node in self.logical_topo: - self.operators[node]._clean() - self.topo_cleaned = True - - # Sets the level of parallelism for a registered operator - # Overwrites the environment parallelism (if set) - def _set_parallelism(self, operator_id, level_of_parallelism): - self.operators[operator_id].num_instances = level_of_parallelism - - # Sets the same level of parallelism for all operators in the environment - def set_parallelism(self, parallelism): - self.config.parallelism = parallelism - - # Creates and registers a user-defined data source - # TODO (john): There should be different types of sources, e.g. sources - # reading from Kafka, text files, etc. 
- # TODO (john): Handle case where environment parallelism is set - def source(self, source): - source_id = self.gen_operator_id() - source_stream = DataStream(self, source_id) - self.operators[source_id] = Operator( - source_id, OpType.Source, processor.Source, "Source", logic=source) - return source_stream - - # Creates and registers a new data source that reads a - # text file line by line - # TODO (john): There should be different types of sources, - # e.g. sources reading from Kafka, text files, etc. - # TODO (john): Handle case where environment parallelism is set - def read_text_file(self, filepath): - source_id = self.gen_operator_id() - source_stream = DataStream(self, source_id) - self.operators[source_id] = Operator( - source_id, - OpType.ReadTextFile, - processor.ReadTextFile, - "Read Text File", - other=filepath) - return source_stream - - # Constructs and deploys the physical dataflow - def execute(self): - """Deploys and executes the physical dataflow.""" - self._collect_garbage() # Make sure everything is clean - # TODO (john): Check if dataflow has any 'logical inconsistencies' - # For example, if there is a forward partitioning strategy but - # the number of downstream instances is larger than the number of - # upstream instances, some of the downstream instances will not be - # used at all - - self.execution_graph = ExecutionGraph(self) - self.execution_graph.build_graph() - logger.info("init...") - # init - init_waits = [] - for actor_handle in self.execution_graph.actor_handles: - init_waits.append(actor_handle.init.remote(pickle.dumps(self))) - for wait in init_waits: - assert ray.get(wait) is True - logger.info("running...") - # start - exec_handles = [] - for actor_handle in self.execution_graph.actor_handles: - exec_handles.append(actor_handle.start.remote()) - - return exec_handles - - def wait_finish(self): - for actor_handle in self.execution_graph.actor_handles: - while not ray.get(actor_handle.is_finished.remote()): - time.sleep(1) - - 
# Prints the logical dataflow graph - def print_logical_graph(self): - self._collect_garbage() - logger.info("==================================") - logger.info("======Logical Dataflow Graph======") - logger.info("==================================") - # Print operators in topological order - for node in nx.topological_sort(self.logical_topo): - downstream_neighbors = list(self.logical_topo.neighbors(node)) - logger.info("======Current Operator======") - operator = self.operators[node] - operator.print() - logger.info("======Downstream Operators======") - if len(downstream_neighbors) == 0: - logger.info("None\n") - for downstream_node in downstream_neighbors: - self.operators[downstream_node].print() - - -# TODO (john): We also need KeyedDataStream and WindowedDataStream as -# subclasses of DataStream to prevent ill-defined logical dataflows - - -# A DataStream corresponds to an edge in the logical dataflow -class DataStream: - """A data stream. - - This class contains all information about a logical stream, i.e. an edge - in the logical topology. It is the main class exposed to the user. - - Attributes: - id (UUID): The id of the stream - env (Environment): The environment the stream belongs to. - src_operator_id (UUID): The id of the source operator of the stream. - dst_operator_id (UUID): The id of the destination operator of the - stream. - is_partitioned (bool): Denotes if there is a partitioning strategy - (e.g. shuffle) for the stream or not (default stategy: Forward). 
- """ - stream_id_counter = 0 - - def __init__(self, - environment, - source_id=None, - dest_id=None, - is_partitioned=False): - self.env = environment - self.id = DataStream.stream_id_counter - DataStream.stream_id_counter += 1 - self.src_operator_id = source_id - self.dst_operator_id = dest_id - # True if a partitioning strategy for this stream exists, - # false otherwise - self.is_partitioned = is_partitioned - - # Generates a new stream after a data transformation is applied - def __expand(self): - stream = DataStream(self.env) - assert (self.dst_operator_id is not None) - stream.src_operator_id = self.dst_operator_id - stream.dst_operator_id = None - return stream - - # Assigns the partitioning strategy to a new 'open-ended' stream - # and returns the stream. At this point, the partitioning strategy - # is not associated with any destination operator. We expect this to - # be done later, as we continue assembling the dataflow graph - def __partition(self, strategy, partition_fn=None): - scheme = PScheme(strategy, partition_fn) - source_operator = self.env.operators[self.src_operator_id] - new_stream = DataStream( - self.env, source_id=source_operator.id, is_partitioned=True) - source_operator._set_partition_strategy(new_stream.id, scheme) - return new_stream - - # Registers the operator to the environment and returns a new - # 'open-ended' stream. The registered operator serves as the destination - # of the previously 'open' stream - def __register(self, operator): - """Registers the given logical operator to the environment and - connects it to its upstream operator (if any). - - A call to this function adds a new edge to the logical topology. - - Attributes: - operator (Operator): The metadata of the logical operator. 
- """ - self.env.operators[operator.id] = operator - self.dst_operator_id = operator.id - logger.debug("Adding new dataflow edge ({},{}) --> ({},{})".format( - self.src_operator_id, - self.env.operators[self.src_operator_id].name, - self.dst_operator_id, - self.env.operators[self.dst_operator_id].name)) - # Update logical dataflow graphs - self.env._add_edge(self.src_operator_id, self.dst_operator_id) - # Keep track of the partitioning strategy and the destination operator - src_operator = self.env.operators[self.src_operator_id] - if self.is_partitioned is True: - partitioning, _ = src_operator._get_partition_strategy(self.id) - src_operator._set_partition_strategy(self.id, partitioning, - operator.id) - elif src_operator.type == OpType.KeyBy: - # Set the output partitioning strategy to shuffle by key - partitioning = PScheme(PStrategy.ShuffleByKey) - src_operator._set_partition_strategy(self.id, partitioning, - operator.id) - else: # No partitioning strategy has been defined - set default - partitioning = PScheme(PStrategy.Forward) - src_operator._set_partition_strategy(self.id, partitioning, - operator.id) - return self.__expand() - - # Sets the level of parallelism for an operator, i.e. its total - # number of instances. Each operator instance corresponds to an actor - # in the physical dataflow - def set_parallelism(self, num_instances): - """Sets the number of instances for the source operator of the stream. - - Attributes: - num_instances (int): The level of parallelism for the source - operator of the stream. 
- """ - assert (num_instances > 0) - self.env._set_parallelism(self.src_operator_id, num_instances) - return self - - # Stream Partitioning Strategies # - # TODO (john): Currently, only forward (default), shuffle, - # and broadcast are supported - - # Hash-based record shuffling - def shuffle(self): - """Registers a shuffling partitioning strategy for the stream.""" - return self.__partition(PStrategy.Shuffle) - - # Broadcasts each record to all downstream instances - def broadcast(self): - """Registers a broadcast partitioning strategy for the stream.""" - return self.__partition(PStrategy.Broadcast) - - # Rescales load to downstream instances - def rescale(self): - """Registers a rescale partitioning strategy for the stream. - - Same as Flink's rescale (see: https://ci.apache.org/projects/flink/ - flink-docs-stable/dev/stream/operators/#physical-partitioning). - """ - return self.__partition(PStrategy.Rescale) - - # Round-robin partitioning - def round_robin(self): - """Registers a round-robin partitioning strategy for the stream.""" - return self.__partition(PStrategy.RoundRobin) - - # User-defined partitioning - def partition(self, partition_fn): - """Registers a user-defined partitioning strategy for the stream. - - Attributes: - partition_fn (function): The user-defined partitioning function. - """ - return self.__partition(PStrategy.Custom, partition_fn) - - # Data Trasnformations # - # TODO (john): Expand set of supported operators. - # TODO (john): To support event-time windows we need a mechanism for - # generating and processing watermarks - - # Registers map operator to the environment - def map(self, map_fn, name="Map"): - """Applies a map operator to the stream. - - Attributes: - map_fn (function): The user-defined logic of the map. 
- """ - op = Operator( - self.env.gen_operator_id(), - OpType.Map, - processor.Map, - name, - map_fn, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers flatmap operator to the environment - def flat_map(self, flatmap_fn): - """Applies a flatmap operator to the stream. - - Attributes: - flatmap_fn (function): The user-defined logic of the flatmap - (e.g. split()). - """ - op = Operator( - self.env.gen_operator_id(), - OpType.FlatMap, - processor.FlatMap, - "FlatMap", - flatmap_fn, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers keyBy operator to the environment - # TODO (john): This should returned a KeyedDataStream - def key_by(self, key_selector): - """Applies a key_by operator to the stream. - - Attributes: - key_attribute_index (int): The index of the key attributed - (assuming tuple records). - """ - op = Operator( - self.env.gen_operator_id(), - OpType.KeyBy, - processor.KeyBy, - "KeyBy", - other=key_selector, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers Reduce operator to the environment - def reduce(self, reduce_fn): - """Applies a rolling sum operator to the stream. - - Attributes: - sum_attribute_index (int): The index of the attribute to sum - (assuming tuple records). - """ - op = Operator( - self.env.gen_operator_id(), - OpType.Reduce, - processor.Reduce, - "Sum", - reduce_fn, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers Sum operator to the environment - def sum(self, attribute_selector, state_keeper=None): - """Applies a rolling sum operator to the stream. - - Attributes: - sum_attribute_index (int): The index of the attribute to sum - (assuming tuple records). 
- """ - op = Operator( - self.env.gen_operator_id(), - OpType.Sum, - processor.Reduce, - "Sum", - _sum, - other=attribute_selector, - state_actor=state_keeper, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers window operator to the environment. - # This is a system time window - # TODO (john): This should return a WindowedDataStream - def time_window(self, window_width_ms): - """Applies a system time window to the stream. - - Attributes: - window_width_ms (int): The length of the window in ms. - """ - raise Exception("time_window is unsupported") - - # Registers filter operator to the environment - def filter(self, filter_fn): - """Applies a filter to the stream. - - Attributes: - filter_fn (function): The user-defined filter function. - """ - op = Operator( - self.env.gen_operator_id(), - OpType.Filter, - processor.Filter, - "Filter", - filter_fn, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # TODO (john): Registers window join operator to the environment - def window_join(self, other_stream, join_attribute, window_width): - op = Operator( - self.env.gen_operator_id(), - OpType.WindowJoin, - processor.WindowJoin, - "WindowJoin", - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers inspect operator to the environment - def inspect(self, inspect_logic): - """Inspects the content of the stream. - - Attributes: - inspect_logic (function): The user-defined inspect function. 
- """ - op = Operator( - self.env.gen_operator_id(), - OpType.Inspect, - processor.Inspect, - "Inspect", - inspect_logic, - num_instances=self.env.config.parallelism) - return self.__register(op) - - # Registers sink operator to the environment - # TODO (john): A sink now just drops records but it should be able to - # export data to other systems - def sink(self): - """Closes the stream with a sink operator.""" - op = Operator( - self.env.gen_operator_id(), - OpType.Sink, - processor.Sink, - "Sink", - num_instances=self.env.config.parallelism) - return self.__register(op) diff --git a/streaming/python/tests/test_function.py b/streaming/python/tests/test_function.py new file mode 100644 index 000000000..3564a1698 --- /dev/null +++ b/streaming/python/tests/test_function.py @@ -0,0 +1,22 @@ +from ray.streaming import function +from ray.streaming.runtime import gateway_client + + +def test_get_simple_function_class(): + simple_map_func_class = function._get_simple_function_class( + function.MapFunction) + assert simple_map_func_class is function.SimpleMapFunction + + +class MapFunc(function.MapFunction): + def map(self, value): + return str(value) + + +def test_load_function(): + # function_bytes, module_name, class_name, function_name, + # function_interface + descriptor_func_bytes = gateway_client.serialize( + [None, __name__, MapFunc.__name__, None, "MapFunction"]) + func = function.load_function(descriptor_func_bytes) + assert type(func) is MapFunc diff --git a/streaming/python/tests/test_logical_graph.py b/streaming/python/tests/test_logical_graph.py deleted file mode 100644 index 2e13723df..000000000 --- a/streaming/python/tests/test_logical_graph.py +++ /dev/null @@ -1,206 +0,0 @@ -from ray.streaming.streaming import Environment, ExecutionGraph -from ray.streaming.operator import OpType, PStrategy - - -def test_parallelism(): - """Tests operator parallelism.""" - env = Environment() - # Try setting a common parallelism for all operators - env.set_parallelism(2) 
- stream = env.source(None).map(None).filter(None).flat_map(None) - env._collect_garbage() - for operator in env.operators.values(): - if operator.type == OpType.Source: - # TODO (john): Currently each source has only one instance - assert operator.num_instances == 1, (operator.num_instances, 1) - else: - assert operator.num_instances == 2, (operator.num_instances, 2) - # Check again after adding an operator with different parallelism - stream.map(None, "Map1").shuffle().set_parallelism(3).map( - None, "Map2").set_parallelism(4) - env._collect_garbage() - for operator in env.operators.values(): - if operator.type == OpType.Source: - assert operator.num_instances == 1, (operator.num_instances, 1) - elif operator.name != "Map1" and operator.name != "Map2": - assert operator.num_instances == 2, (operator.num_instances, 2) - elif operator.name != "Map2": - assert operator.num_instances == 3, (operator.num_instances, 3) - else: - assert operator.num_instances == 4, (operator.num_instances, 4) - - -def test_partitioning(): - """Tests stream partitioning.""" - env = Environment() - # Try defining multiple partitioning strategies for the same stream - _ = env.source(None).shuffle().rescale().broadcast().map( - None).broadcast().shuffle() - env._collect_garbage() - for operator in env.operators.values(): - p_schemes = operator.partitioning_strategies - for scheme in p_schemes.values(): - # Only last defined strategy should be kept - if operator.type == OpType.Source: - assert scheme.strategy == PStrategy.Broadcast, ( - scheme.strategy, PStrategy.Broadcast) - else: - assert scheme.strategy == PStrategy.Shuffle, ( - scheme.strategy, PStrategy.Shuffle) - - -def test_forking(): - """Tests stream forking.""" - env = Environment() - # Try forking a stream - stream = env.source(None).map(None).set_parallelism(2) - # First branch with a shuffle partitioning strategy - _ = stream.shuffle().key_by(0).sum(1) - # Second branch with the default partitioning strategy - _ = 
stream.key_by(1).sum(2) - env._collect_garbage() - # Operator ids - source_id = None - map_id = None - keyby1_id = None - keyby2_id = None - sum1_id = None - sum2_id = None - # Collect ids - for id, operator in env.operators.items(): - if operator.type == OpType.Source: - source_id = id - elif operator.type == OpType.Map: - map_id = id - elif operator.type == OpType.KeyBy: - if operator.other_args == 0: - keyby1_id = id - else: - assert operator.other_args == 1, (operator.other_args, 1) - keyby2_id = id - elif operator.type == OpType.Sum: - if operator.other_args == 1: - sum1_id = id - else: - assert operator.other_args == 2, (operator.other_args, 2) - sum2_id = id - # Check generated streams and their partitioning - for source, destination in env.logical_topo.edges: - operator = env.operators[source] - if source == source_id: - assert destination == map_id, (destination, map_id) - elif source == map_id: - p_scheme = operator.partitioning_strategies[destination] - strategy = p_scheme.strategy - key_index = env.operators[destination].other_args - if key_index == 0: # This must be the first branch - assert strategy == PStrategy.Shuffle, (strategy, - PStrategy.Shuffle) - assert destination == keyby1_id, (destination, keyby1_id) - else: # This must be the second branch - assert key_index == 1, (key_index, 1) - assert strategy == PStrategy.Forward, (strategy, - PStrategy.Forward) - assert destination == keyby2_id, (destination, keyby2_id) - elif source == keyby1_id or source == keyby2_id: - p_scheme = operator.partitioning_strategies[destination] - strategy = p_scheme.strategy - key_index = env.operators[destination].other_args - if key_index == 1: # This must be the first branch - assert strategy == PStrategy.ShuffleByKey, ( - strategy, PStrategy.ShuffleByKey) - assert destination == sum1_id, (destination, sum1_id) - else: # This must be the second branch - assert key_index == 2, (key_index, 2) - assert strategy == PStrategy.ShuffleByKey, ( - strategy, 
PStrategy.ShuffleByKey) - assert destination == sum2_id, (destination, sum2_id) - else: # This must be a sum operator - assert operator.type == OpType.Sum, (operator.type, OpType.Sum) - - -def _test_shuffle_channels(): - """Tests shuffling connectivity.""" - env = Environment() - # Try defining a shuffle - _ = env.source(None).shuffle().map(None).set_parallelism(4) - expected = [(0, 0), (0, 1), (0, 2), (0, 3)] - _test_channels(env, expected) - - -def _test_forward_channels(): - """Tests forward connectivity.""" - env = Environment() - # Try the default partitioning strategy - _ = env.source(None).set_parallelism(4).map(None).set_parallelism(2) - expected = [(0, 0), (1, 1), (2, 0), (3, 1)] - _test_channels(env, expected) - - -def _test_broadcast_channels(): - """Tests broadcast connectivity.""" - env = Environment() - # Try broadcasting - _ = env.source(None).set_parallelism(4).broadcast().map( - None).set_parallelism(2) - expected = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)] - _test_channels(env, expected) - - -def _test_round_robin_channels(): - """Tests round-robin connectivity.""" - env = Environment() - # Try broadcasting - _ = env.source(None).round_robin().map(None).set_parallelism(2) - expected = [(0, 0), (0, 1)] - _test_channels(env, expected) - - -def _test_channels(environment, expected_channels): - """Tests operator connectivity.""" - environment._collect_garbage() - map_id = None - # Get id - for id, operator in environment.operators.items(): - if operator.type == OpType.Map: - map_id = id - # Collect channels - environment.execution_graph = ExecutionGraph(environment) - environment.execution_graph.build_channels() - channels_per_destination = [] - for operator in environment.operators.values(): - channels_per_destination.append( - environment.execution_graph._generate_channels(operator)) - # Check actual connectivity - actual = [] - for destination in channels_per_destination: - for channels in destination.values(): - for channel 
in channels: - src_instance_index = channel.src_instance_index - dst_instance_index = channel.dst_instance_index - connection = (src_instance_index, dst_instance_index) - assert channel.dst_operator_id == map_id, ( - channel.dst_operator_id, map_id) - actual.append(connection) - # Make sure connections are as expected - set_1 = set(expected_channels) - set_2 = set(actual) - assert set_1 == set_2, (set_1, set_2) - - -def test_channel_generation(): - """Tests data channel generation.""" - _test_shuffle_channels() - _test_broadcast_channels() - _test_round_robin_channels() - _test_forward_channels() - - -# TODO (john): Add simple wordcount test -def test_wordcount(): - """Tests a simple streaming wordcount.""" - pass - - -if __name__ == "__main__": - test_channel_generation() diff --git a/streaming/python/tests/test_operator.py b/streaming/python/tests/test_operator.py new file mode 100644 index 000000000..2a72aa253 --- /dev/null +++ b/streaming/python/tests/test_operator.py @@ -0,0 +1,8 @@ +from ray.streaming import operator +from ray.streaming import function + + +def test_create_operator(): + map_func = function.SimpleMapFunction(lambda x: x) + map_operator = operator.create_operator(map_func) + assert type(map_operator) is operator.MapOperator diff --git a/streaming/python/tests/test_word_count.py b/streaming/python/tests/test_word_count.py index 9fc2f2e11..5758b9336 100644 --- a/streaming/python/tests/test_word_count.py +++ b/streaming/python/tests/test_word_count.py @@ -1,18 +1,23 @@ import ray -from ray.streaming.config import Config -from ray.streaming.streaming import Environment, Conf +from ray.streaming import StreamingContext def test_word_count(): - ray.init() - env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL)) - env.read_text_file(__file__) \ + ray.init(load_code_from_local=True, include_java=True) + ctx = StreamingContext.Builder() \ + .build() + ctx.read_text_file(__file__) \ .set_parallelism(1) \ - .filter(lambda x: "word" in x) \ - 
.inspect(lambda x: print("result", x)) - env_handle = env.execute() - ray.get(env_handle) # Stay alive until execution finishes - env.wait_finish() + .flat_map(lambda x: x.split()) \ + .map(lambda x: (x, 1)) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .filter(lambda x: "ray" not in x) \ + .sink(lambda x: print("result", x)) + ctx.submit("word_count") + import time + time.sleep(3) ray.shutdown() diff --git a/streaming/src/config/streaming_config.cc b/streaming/src/config/streaming_config.cc index 8fa21dc0c..c668c9b79 100644 --- a/streaming/src/config/streaming_config.cc +++ b/streaming/src/config/streaming_config.cc @@ -28,8 +28,8 @@ void StreamingConfig::FromProto(const uint8_t *data, uint32_t size) { if (!config.op_name().empty()) { SetOpName(config.op_name()); } - if (config.role() != proto::OperatorType::UNKNOWN) { - SetOperatorType(config.role()); + if (config.role() != proto::NodeType::UNKNOWN) { + SetNodeType(config.role()); } if (config.ring_buffer_capacity() != 0) { SetRingBufferCapacity(config.ring_buffer_capacity()); diff --git a/streaming/src/config/streaming_config.h b/streaming/src/config/streaming_config.h index add9a8d56..474d7571e 100644 --- a/streaming/src/config/streaming_config.h +++ b/streaming/src/config/streaming_config.h @@ -22,8 +22,7 @@ class StreamingConfig { uint32_t empty_message_time_interval_ = DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; - streaming::proto::OperatorType operator_type_ = - streaming::proto::OperatorType::TRANSFORM; + streaming::proto::NodeType node_type_ = streaming::proto::NodeType::TRANSFORM; std::string job_name_ = "DEFAULT_JOB_NAME"; @@ -55,7 +54,7 @@ class StreamingConfig { DECL_GET_SET_PROPERTY(const std::string &, WorkerName, worker_name_) DECL_GET_SET_PROPERTY(const std::string &, OpName, op_name_) DECL_GET_SET_PROPERTY(uint32_t, EmptyMessageTimeInterval, empty_message_time_interval_) - DECL_GET_SET_PROPERTY(streaming::proto::OperatorType, 
OperatorType, operator_type_) + DECL_GET_SET_PROPERTY(streaming::proto::NodeType, NodeType, node_type_) DECL_GET_SET_PROPERTY(const std::string &, JobName, job_name_) DECL_GET_SET_PROPERTY(uint32_t, WriterConsumedStep, writer_consumed_step_) DECL_GET_SET_PROPERTY(uint32_t, ReaderConsumedStep, reader_consumed_step_) diff --git a/streaming/src/protobuf/remote_call.proto b/streaming/src/protobuf/remote_call.proto new file mode 100644 index 000000000..bb952b624 --- /dev/null +++ b/streaming/src/protobuf/remote_call.proto @@ -0,0 +1,59 @@ +syntax = "proto3"; + +package ray.streaming.proto; + +import "streaming/src/protobuf/streaming.proto"; + +option java_package = "org.ray.streaming.runtime.generated"; + +// Streaming execution graph +message ExecutionGraph { + // A parallel operation consisting of multiple execution tasks + message ExecutionNode { + int32 node_id = 1; + int32 parallelism = 2; + NodeType node_type = 3; + Language language = 4; + // serialized user function + bytes function = 5; + repeated ExecutionTask execution_tasks = 6; + repeated ExecutionEdge input_edges = 7; + repeated ExecutionEdge output_edges = 8; + } + + // execution edge + message ExecutionEdge { + // upstream execution node id + int32 src_node_id = 1; + // downstream execution node id + int32 target_node_id = 2; + // serialized partition between src/target node + bytes partition = 3; + } + + // a parallel subtask of the execution + message ExecutionTask { + // unique execution task id + int32 task_id = 1; + // an ordered task index range from 0 to parallelism - 1 + int32 task_index = 2; + // serialized actor handle + bytes worker_actor = 3; + } + + // graph build time + uint64 build_time = 1; + repeated ExecutionNode execution_nodes = 2; +} + +// Streaming worker context +message WorkerContext { + // job name + string job_name = 1; + // unique execution task id + int32 task_id = 2; + // job config + map<string, string> conf = 3; + // execution graph + ExecutionGraph graph = 4; +} diff --git
a/streaming/src/protobuf/streaming.proto b/streaming/src/protobuf/streaming.proto index caf2cc865..3c68c51e6 100644 --- a/streaming/src/protobuf/streaming.proto +++ b/streaming/src/protobuf/streaming.proto @@ -4,10 +4,19 @@ package ray.streaming.proto; option java_package = "org.ray.streaming.runtime.generated"; -enum OperatorType { +enum Language { + JAVA = 0; + PYTHON = 1; +} + +enum NodeType { UNKNOWN = 0; - TRANSFORM = 1; - SOURCE = 2; + // Sources are where your program reads its input from + SOURCE = 1; + // Transform one or more DataStreams into a new DataStream. + TRANSFORM = 2; + // Sinks consume DataStreams and forward them to files, sockets, external + // systems, or print them. SINK = 3; } @@ -23,7 +32,7 @@ message StreamingConfig { string task_job_id = 2; string worker_name = 3; string op_name = 4; - OperatorType role = 5; + NodeType role = 5; uint32 ring_buffer_capacity = 6; uint32 empty_message_interval = 7; FlowControlType flow_control_type = 8;