diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java index 5f1ab4d4e..c18fb7cef 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java @@ -58,6 +58,7 @@ public class StreamingContext implements Serializable { JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName); this.jobGraph = jobGraphBuilder.build(); jobGraph.printJobGraph(); + LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph()); if (Ray.internal() == null) { if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) { diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java index 3472da79e..94441076f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java @@ -37,4 +37,8 @@ public class Functions { } } + public static RichFunction emptyFunction() { + return new DefaultRichFunction(null); + } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java index fe0a3af1b..bd90b0e25 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java @@ -1,6 +1,5 @@ package io.ray.streaming.api.stream; - import io.ray.streaming.api.Language; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.function.impl.FilterFunction; @@ -17,9 +16,13 @@ import io.ray.streaming.operator.impl.KeyByOperator; import io.ray.streaming.operator.impl.MapOperator; import io.ray.streaming.operator.impl.SinkOperator; import io.ray.streaming.python.stream.PythonDataStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; /** * Represents a stream of data. + * *

This class defines all the streaming operations. * * @param Type of data in the stream. @@ -81,13 +84,36 @@ public class DataStream extends Stream, T> { } /** - * Apply a union transformation to this stream, with another stream. + * Apply union transformations to this stream by merging {@link DataStream} outputs of + * the same type with each other. * - * @param other Another stream. + * @param stream The DataStream to union output with. + * @param others The other DataStreams to union output with. * @return A new UnionStream. */ - public UnionStream union(DataStream other) { - return new UnionStream<>(this, null, other); + @SafeVarargs + public final DataStream union(DataStream stream, DataStream... others) { + List> streams = new ArrayList<>(); + streams.add(stream); + streams.addAll(Arrays.asList(others)); + return union(streams); + } + + /** + * Apply union transformations to this stream by merging {@link DataStream} outputs of + * the same type with each other. + * + * @param streams The DataStreams to union output with. + * @return A new UnionStream. + */ + public final DataStream union(List> streams) { + if (this instanceof UnionStream) { + UnionStream unionStream = (UnionStream) this; + streams.forEach(unionStream::addStream); + return unionStream; + } else { + return new UnionStream<>(this, streams); + } } /** diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java index 6dd559ce7..833cddaa8 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java @@ -1,22 +1,32 @@ package io.ray.streaming.api.stream; -import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.impl.UnionOperator; import java.util.ArrayList; import java.util.List; /** * Represents a union DataStream. * + *

This stream does not create a physical operation, it only affects how upstream data are + * connected to downstream data. + * * @param The type of union data. */ public class UnionStream extends DataStream { - private List> unionStreams; - public UnionStream(DataStream input, StreamOperator streamOperator, DataStream other) { - super(input, streamOperator); + public UnionStream(DataStream input, List> streams) { + super(input, new UnionOperator()); this.unionStreams = new ArrayList<>(); - this.unionStreams.add(other); + streams.forEach(this::addStream); + } + + void addStream(DataStream stream) { + if (stream instanceof UnionStream) { + this.unionStreams.addAll(((UnionStream) stream).getUnionStreams()); + } else { + this.unionStreams.add(stream); + } } public List> getUnionStreams() { diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java index 30e26ce9b..2a2d02ebf 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java @@ -5,8 +5,11 @@ import io.ray.streaming.api.stream.DataStream; import io.ray.streaming.api.stream.Stream; import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.api.stream.StreamSource; +import io.ray.streaming.api.stream.UnionStream; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonUnionStream; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,6 +47,7 @@ public class JobGraphBuilder { return this.jobGraph; } + @SuppressWarnings("unchecked") private void processStream(Stream stream) { while (stream.isProxyStream()) { // Proxy stream and original stream are the same logical stream, both refer to the @@ -74,6 +78,20 @@ public class JobGraphBuilder { JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition()); this.jobGraph.addEdge(jobEdge); processStream(parentStream); + + // process union stream + List streams = new ArrayList<>(); + if (stream instanceof UnionStream) { + streams.addAll(((UnionStream) stream).getUnionStreams()); + } + if (stream instanceof PythonUnionStream) { + streams.addAll(((PythonUnionStream) stream).getUnionStreams()); + } + for (Stream otherStream : streams) { + JobEdge otherEdge = new JobEdge(otherStream.getId(), vertexId, otherStream.getPartition()); + this.jobGraph.addEdge(otherEdge); + processStream(otherStream); + } } else { throw new UnsupportedOperationException("Unsupported stream: " + stream); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/UnionOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/UnionOperator.java new file mode 100644 index 000000000..c3467582f --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/UnionOperator.java @@ -0,0 +1,21 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.internal.Functions; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; + +public class UnionOperator extends StreamOperator implements + OneInputOperator { + + public UnionOperator() { + super(Functions.emptyFunction()); + } + + @Override + public void processElement(Record record) { + collect(record); + } + +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java index 21533de70..aac706d2f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java @@ -2,6 +2,7 @@ package io.ray.streaming.python; import com.google.common.base.Preconditions; import io.ray.streaming.api.function.Function; +import java.util.StringJoiner; import org.apache.commons.lang3.StringUtils; /** @@ -99,4 +100,17 @@ public class PythonFunction implements Function { return functionInterface; } + @Override + public String toString() { + StringJoiner stringJoiner = new StringJoiner(", ", + PythonFunction.class.getSimpleName() + "[", "]"); + if (function != null) { + stringJoiner.add("function=binary function"); + } else { + stringJoiner.add("moduleName='" + moduleName + "'") + .add("functionName='" + functionName + "'"); + } + stringJoiner.add("functionInterface='" + functionInterface + "'"); + return stringJoiner.toString(); + } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java index d0eec6e3b..045814c7e 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java @@ -5,15 +5,26 @@ import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.operator.OperatorType; import io.ray.streaming.operator.StreamOperator; import java.util.List; +import java.util.StringJoiner; /** * Represents a {@link StreamOperator} that wraps python {@link PythonFunction}. */ @SuppressWarnings("unchecked") public class PythonOperator extends StreamOperator { + private final String moduleName; + private final String className; + + public PythonOperator(String moduleName, String className) { + super(null); + this.moduleName = moduleName; + this.className = className; + } public PythonOperator(PythonFunction function) { super(function); + this.moduleName = null; + this.className = null; } @Override @@ -44,4 +55,25 @@ public class PythonOperator extends StreamOperator { public Language getLanguage() { return Language.PYTHON; } + + public String getModuleName() { + return moduleName; + } + + public String getClassName() { + return className; + } + + @Override + public String toString() { + StringJoiner stringJoiner = new StringJoiner(", ", + PythonOperator.class.getSimpleName() + "[", "]"); + if (function != null) { + stringJoiner.add("function='" + function + "'"); + } else { + stringJoiner.add("moduleName='" + moduleName + "'") + .add("className='" + className + "'"); + } + return stringJoiner.toString(); + } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java index 6d8de051f..9f3bcd7a1 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java @@ -2,6 +2,7 @@ package io.ray.streaming.python; import com.google.common.base.Preconditions; import io.ray.streaming.api.partition.Partition; +import java.util.StringJoiner; import org.apache.commons.lang3.StringUtils; /** @@ -35,6 +36,7 @@ public class PythonPartition implements Partition { /** * Create a python partition from a moduleName and partition function name + * * @param moduleName module name of python partition * @param functionName function/class name of the partition function. */ @@ -63,4 +65,18 @@ public class PythonPartition implements Partition { public String getFunctionName() { return functionName; } + + @Override + public String toString() { + StringJoiner stringJoiner = new StringJoiner(", ", + PythonPartition.class.getSimpleName() + "[", "]"); + if (partition != null) { + stringJoiner.add("partition=binary partition"); + } else { + stringJoiner.add("moduleName='" + moduleName + "'") + .add("functionName='" + functionName + "'"); + } + return stringJoiner.toString(); + } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java index 8911fde13..f7a3e228f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java @@ -9,6 +9,9 @@ import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction.FunctionInterface; import io.ray.streaming.python.PythonOperator; import io.ray.streaming.python.PythonPartition; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; /** * Represents a stream of data whose transformations will be executed in python. @@ -90,6 +93,38 @@ public class PythonDataStream extends Stream implement return new PythonDataStream(this, new PythonOperator(func)); } + /** + * Apply union transformations to this stream by merging {@link PythonDataStream} outputs of + * the same type with each other. + * + * @param stream The DataStream to union output with. + * @param others The other DataStreams to union output with. + * @return A new UnionStream. + */ + public final PythonDataStream union(PythonDataStream stream, PythonDataStream... others) { + List streams = new ArrayList<>(); + streams.add(stream); + streams.addAll(Arrays.asList(others)); + return union(streams); + } + + /** + * Apply union transformations to this stream by merging {@link PythonDataStream} outputs of + * the same type with each other. + * + * @param streams The DataStreams to union output with. + * @return A new UnionStream. + */ + public final PythonDataStream union(List streams) { + if (this instanceof PythonUnionStream) { + PythonUnionStream unionStream = (PythonUnionStream) this; + streams.forEach(unionStream::addStream); + return unionStream; + } else { + return new PythonUnionStream(this, streams); + } + } + public PythonStreamSink sink(String moduleName, String funcName) { return sink(new PythonFunction(moduleName, funcName)); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonUnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonUnionStream.java new file mode 100644 index 000000000..01f9087d5 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonUnionStream.java @@ -0,0 +1,34 @@ +package io.ray.streaming.python.stream; + +import io.ray.streaming.python.PythonOperator; +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a union DataStream. + * + *

This stream does not create a physical operation, it only affects how upstream data are + * connected to downstream data. + */ +public class PythonUnionStream extends PythonDataStream { + private List unionStreams; + + public PythonUnionStream(PythonDataStream input, List others) { + super(input, new PythonOperator( + "ray.streaming.operator", "UnionOperator")); + this.unionStreams = new ArrayList<>(); + others.forEach(this::addStream); + } + + void addStream(PythonDataStream stream) { + if (stream instanceof PythonUnionStream) { + this.unionStreams.addAll(((PythonUnionStream) stream).getUnionStreams()); + } else { + this.unionStreams.add(stream); + } + } + + public List getUnionStreams() { + return unionStreams; + } +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java index b79926490..4a39eadab 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java @@ -4,7 +4,9 @@ import com.google.protobuf.ByteString; import io.ray.runtime.actor.NativeRayActor; import io.ray.streaming.api.function.Function; import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.operator.Operator; import io.ray.streaming.python.PythonFunction; +import io.ray.streaming.python.PythonOperator; import io.ray.streaming.python.PythonPartition; import io.ray.streaming.runtime.core.graph.ExecutionEdge; import io.ray.streaming.runtime.core.graph.ExecutionGraph; @@ -34,8 +36,8 @@ public class GraphPbBuilder { nodeBuilder.setNodeType( Streaming.NodeType.valueOf(node.getNodeType().name())); nodeBuilder.setLanguage(Streaming.Language.valueOf(node.getLanguage().name())); - byte[] functionBytes = serializeFunction(node.getStreamOperator().getFunction()); - nodeBuilder.setFunction(ByteString.copyFrom(functionBytes)); + byte[] operatorBytes = serializeOperator(node.getStreamOperator()); + nodeBuilder.setOperator(ByteString.copyFrom(operatorBytes)); // build tasks for (ExecutionTask task : node.getExecutionTasks()) { @@ -72,6 +74,19 @@ public class GraphPbBuilder { return edgeBuilder.build(); } + private byte[] serializeOperator(Operator operator) { + if (operator instanceof PythonOperator) { + PythonOperator pythonOperator = (PythonOperator) operator; + return serializer.serialize(Arrays.asList( + serializeFunction(pythonOperator.getFunction()), + pythonOperator.getModuleName(), + pythonOperator.getClassName() + )); + } else { + return new byte[0]; + } + } + private byte[] serializeFunction(Function function) { if (function instanceof PythonFunction) { PythonFunction pyFunc = (PythonFunction) function; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java index b5ca58e78..2f74587de 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java @@ -3,8 +3,11 @@ package io.ray.streaming.runtime.python; import com.google.common.base.Preconditions; import com.google.common.primitives.Primitives; import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.Stream; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonPartition; +import io.ray.streaming.python.stream.PythonDataStream; import io.ray.streaming.python.stream.PythonStreamSource; import io.ray.streaming.runtime.serialization.MsgPackSerializer; import io.ray.streaming.runtime.util.ReflectionUtils; @@ -99,6 +102,26 @@ public class PythonGateway { return serializer.serialize(getReferenceId(partition)); } + public byte[] union(byte[] paramsBytes) { + List streams = (List) serializer.deserialize(paramsBytes); + streams = processParameters(streams); + LOG.info("Call union with streams {}", streams); + Preconditions.checkArgument(streams.size() >= 2, + "Union needs at least two streams"); + Stream unionStream; + Stream stream1 = (Stream) streams.get(0); + List otherStreams = streams.subList(1, streams.size()); + if (stream1 instanceof DataStream) { + DataStream dataStream = (DataStream) stream1; + unionStream = dataStream.union(otherStreams); + } else { + Preconditions.checkArgument(stream1 instanceof PythonDataStream); + PythonDataStream pythonDataStream = (PythonDataStream) stream1; + unionStream = pythonDataStream.union(otherStreams); + } + return serialize(unionStream); + } + public byte[] callFunction(byte[] paramsBytes) { try { List params = (List) serializer.deserialize(paramsBytes); @@ -111,12 +134,7 @@ public class PythonGateway { .map(Object::getClass).toArray(Class[]::new); Method method = findMethod(clz, funcName, paramsTypes); Object result = method.invoke(null, params.subList(2, params.size()).toArray()); - if (returnReference(result)) { - referenceMap.put(getReferenceId(result), result); - return serializer.serialize(getReferenceId(result)); - } else { - return serializer.serialize(result); - } + return serialize(result); } catch (Exception e) { throw new RuntimeException(e); } @@ -134,12 +152,7 @@ public class PythonGateway { .map(Object::getClass).toArray(Class[]::new); Method method = findMethod(clz, methodName, paramsTypes); Object result = method.invoke(obj, params.subList(2, params.size()).toArray()); - if (returnReference(result)) { - referenceMap.put(getReferenceId(result), result); - return serializer.serialize(getReferenceId(result)); - } else { - return serializer.serialize(result); - } + return serialize(result); } catch (Exception e) { throw new RuntimeException(e); } @@ -179,6 +192,15 @@ public class PythonGateway { return any.get(); } + private byte[] serialize(Object value) { + if (returnReference(value)) { + referenceMap.put(getReferenceId(value), value); + return serializer.serialize(getReferenceId(value)); + } else { + return serializer.serialize(value); + } + } + private static boolean returnReference(Object value) { if (isBasic(value)) { return false; diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java index 025f67e21..5bdb530f6 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java @@ -4,16 +4,22 @@ import io.ray.api.Ray; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.function.impl.FilterFunction; import io.ray.streaming.api.function.impl.MapFunction; +import io.ray.streaming.api.function.impl.SinkFunction; import io.ray.streaming.api.stream.DataStreamSource; import io.ray.streaming.runtime.BaseUnitTest; +import java.io.IOException; import java.io.Serializable; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; import java.util.Arrays; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testng.Assert; import org.testng.annotations.Test; -public class HybridStreamTest extends BaseUnitTest implements Serializable { +public class HybridStreamTest { private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class); public static class Mapper1 implements MapFunction { @@ -34,9 +40,12 @@ public class HybridStreamTest extends BaseUnitTest implements Serializable { } } - @Test - public void testHybridDataStream() throws InterruptedException { + @Test(timeOut = 60000) + public void testHybridDataStream() throws Exception { Ray.shutdown(); + String sinkFileName = "/tmp/testHybridDataStream.txt"; + Files.deleteIfExists(Paths.get(sinkFileName)); + StreamingContext context = StreamingContext.buildContext(); DataStreamSource streamSource = DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c")); @@ -46,9 +55,38 @@ public class HybridStreamTest extends BaseUnitTest implements Serializable { .map("ray.streaming.tests.test_hybrid_stream", "map_func1") .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1") .asJavaStream() - .sink(x -> System.out.println("HybridStreamTest: " + x)); + .sink((SinkFunction) value -> { + LOG.info("HybridStreamTest: {}", value); + try { + if (!Files.exists(Paths.get(sinkFileName))) { + Files.createFile(Paths.get(sinkFileName)); + } + Files.write(Paths.get(sinkFileName), value.toString().getBytes(), + StandardOpenOption.APPEND); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); context.execute("HybridStreamTestJob"); + int sleptTime = 0; TimeUnit.SECONDS.sleep(3); + while (true) { + if (Files.exists(Paths.get(sinkFileName))) { + TimeUnit.SECONDS.sleep(3); + String text = String.join(", ", Files.readAllLines(Paths.get(sinkFileName))); + Assert.assertTrue(text.contains("a")); + Assert.assertFalse(text.contains("b")); + Assert.assertTrue(text.contains("c")); + LOG.info("Execution succeed"); + break; + } + sleptTime += 1; + if (sleptTime >= 60) { + throw new RuntimeException("Execution not finished"); + } + LOG.info("Wait finish..."); + TimeUnit.SECONDS.sleep(1); + } context.stop(); LOG.info("HybridStreamTest succeed"); } diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java new file mode 100644 index 000000000..f3d46d826 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java @@ -0,0 +1,71 @@ +package io.ray.streaming.runtime.demo; + +import io.ray.api.Ray; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.SinkFunction; +import io.ray.streaming.api.stream.DataStreamSource; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class UnionStreamTest { + private static final Logger LOG = LoggerFactory.getLogger( UnionStreamTest.class ); + + @Test(timeOut = 60000) + public void testUnionStream() throws Exception { + Ray.shutdown(); + String sinkFileName = "/tmp/testUnionStream.txt"; + Files.deleteIfExists(Paths.get(sinkFileName)); + + StreamingContext context = StreamingContext.buildContext(); + DataStreamSource streamSource1 = + DataStreamSource.fromCollection(context, Arrays.asList(1, 1)); + DataStreamSource streamSource2 = + DataStreamSource.fromCollection(context, Arrays.asList(1, 1)); + DataStreamSource streamSource3 = + DataStreamSource.fromCollection(context, Arrays.asList(1, 1)); + streamSource1 + .union(streamSource2, streamSource3) + .sink((SinkFunction) value -> { + LOG.info("UnionStreamTest: {}", value); + try { + if (!Files.exists(Paths.get(sinkFileName))) { + Files.createFile(Paths.get(sinkFileName)); + } + Files.write(Paths.get(sinkFileName), value.toString().getBytes(), + StandardOpenOption.APPEND); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + context.execute("UnionStreamTest"); + int sleptTime = 0; + TimeUnit.SECONDS.sleep(3); + while (true) { + if (Files.exists(Paths.get(sinkFileName))) { + TimeUnit.SECONDS.sleep(3); + String text = String.join(", ", Files.readAllLines(Paths.get(sinkFileName))); + Assert.assertEquals(text, StringUtils.repeat("1", 6)); + LOG.info("Execution succeed"); + break; + } + sleptTime += 1; + if (sleptTime >= 60) { + throw new RuntimeException("Execution not finished"); + } + LOG.info("Wait finish..."); + TimeUnit.SECONDS.sleep(1); + } + context.stop(); + LOG.info("HybridStreamTest succeed"); + } + +} diff --git a/streaming/python/context.py b/streaming/python/context.py index 8c3b9a5f6..074f82524 100644 --- a/streaming/python/context.py +++ b/streaming/python/context.py @@ -33,7 +33,7 @@ class StreamingContext: self """ if key is not None: - assert value + assert value is not None self._options[key] = str(value) if conf is not None: for k, v in conf.items(): diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py index 21837e072..0fc50ea4b 100644 --- a/streaming/python/datastream.py +++ b/streaming/python/datastream.py @@ -71,7 +71,6 @@ class Stream(ABC): self """ if key is not None: - assert value assert type(key) is str assert type(value) is str self._gateway_client(). \ @@ -183,6 +182,21 @@ class DataStream(Stream): call_method(self._j_stream, "filter", j_func) return DataStream(self, j_stream) + def union(self, *streams): + """Apply union transformations to this stream by merging data stream + outputs of the same type with each other. + + Args: + *streams: The DataStreams to union output with. + + Returns: + A new UnionStream. + """ + assert len(streams) >= 1, "Need at least one stream to union with" + j_streams = [s._j_stream for s in streams] + j_stream = self._gateway_client().union(self._j_stream, *j_streams) + return UnionStream(self, j_stream) + def key_by(self, func): """ Creates a new :class:`KeyDataStream` that uses the provided key to @@ -308,6 +322,13 @@ class JavaDataStream(Stream): return JavaDataStream(self, self._unary_call("filter", java_func_class)) + def union(self, *streams): + """See io.ray.streaming.api.stream.DataStream.union""" + assert len(streams) >= 1, "Need at least one stream to union with" + j_streams = [s._j_stream for s in streams] + j_stream = self._gateway_client().union(self._j_stream, *j_streams) + return JavaUnionStream(self, j_stream) + def key_by(self, java_func_class): """See io.ray.streaming.api.stream.DataStream.keyBy""" self._check_partition_call() @@ -429,6 +450,30 @@ class JavaKeyDataStream(JavaDataStream): return KeyDataStream(self, j_stream) +class UnionStream(DataStream): + """Represents a union stream. + Wrapper of java io.ray.streaming.python.stream.PythonUnionStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.PYTHON + + +class JavaUnionStream(JavaDataStream): + """Represents a java union stream. + Wrapper of java io.ray.streaming.api.stream.UnionStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.JAVA + + class StreamSource(DataStream): """Represents a source of the DataStream. Wrapper of java io.ray.streaming.python.stream.PythonStreamSource diff --git a/streaming/python/function.py b/streaming/python/function.py index b4d4d2383..c9cfedcc6 100644 --- a/streaming/python/function.py +++ b/streaming/python/function.py @@ -23,6 +23,16 @@ class Function(ABC): pass +class EmptyFunction(Function): + """Default function which does nothing""" + + def open(self, runtime_context): + pass + + def close(self): + pass + + class SourceContext(ABC): """ Interface that source functions use to emit elements, and possibly @@ -216,7 +226,7 @@ class SimpleFlatMapFunction(FlatMapFunction): self.func = func self.process_func = None sig = inspect.signature(func) - assert len(sig.parameters) <= 2,\ + assert len(sig.parameters) <= 2, \ "func should receive value [, collector] as arguments" if len(sig.parameters) == 2: @@ -292,7 +302,7 @@ def load_function(descriptor_func_bytes: bytes): a streaming function """ assert len(descriptor_func_bytes) > 0 - function_bytes, module_name, function_name, function_interface\ + function_bytes, module_name, function_name, function_interface \ = gateway_client.deserialize(descriptor_func_bytes) if function_bytes: return deserialize(function_bytes) diff --git a/streaming/python/operator.py b/streaming/python/operator.py index d6937543f..d92821de7 100644 --- a/streaming/python/operator.py +++ b/streaming/python/operator.py @@ -1,8 +1,11 @@ -from abc import ABC, abstractmethod import enum +import importlib +from abc import ABC, abstractmethod + from ray import streaming from ray.streaming import function from ray.streaming import message +from ray.streaming.runtime import gateway_client class OperatorType(enum.Enum): @@ -214,6 +217,16 @@ class SinkOperator(StreamOperator, OneInputOperator): self.func.sink(record.value) +class UnionOperator(StreamOperator, OneInputOperator): + """Operator for union operation""" + + def __init__(self): + super().__init__(function.EmptyFunction()) + + def process_element(self, record): + self.collect(record) + + _function_to_operator = { function.SourceFunction: SourceOperator, function.MapFunction: MapOperator, @@ -225,7 +238,36 @@ _function_to_operator = { } -def create_operator(func: function.Function): +def load_operator(descriptor_operator_bytes: bytes): + """ + Deserialize `descriptor_operator_bytes` to get operator info, then + create streaming operator. + Note that this function must be kept in sync with + `io.ray.streaming.runtime.python.GraphPbBuilder.serializeOperator` + + Args: + descriptor_operator_bytes: serialized operator info + + Returns: + a streaming operator + """ + assert len(descriptor_operator_bytes) > 0 + function_desc_bytes, module_name, class_name \ + = gateway_client.deserialize(descriptor_operator_bytes) + if function_desc_bytes: + return create_operator_with_func( + function.load_function(function_desc_bytes)) + else: + assert module_name + assert class_name + mod = importlib.import_module(module_name) + cls = getattr(mod, class_name) + assert issubclass(cls, Operator) + print("cls", cls) + return cls() + + +def create_operator_with_func(func: function.Function): """Create an operator according to a :class:`function.Function` Args: diff --git a/streaming/python/runtime/gateway_client.py b/streaming/python/runtime/gateway_client.py index 493de2cf8..455ccf234 100644 --- a/streaming/python/runtime/gateway_client.py +++ b/streaming/python/runtime/gateway_client.py @@ -30,7 +30,7 @@ class GatewayClient: def create_py_stream_source(self, serialized_func): assert isinstance(serialized_func, bytes) - call = self._python_gateway_actor.createPythonStreamSource\ + call = self._python_gateway_actor.createPythonStreamSource \ .remote(serialized_func) return deserialize(ray.get(call)) @@ -41,10 +41,16 @@ class GatewayClient: def create_py_partition(self, serialized_partition): assert isinstance(serialized_partition, bytes) - call = self._python_gateway_actor.createPyPartition\ + call = self._python_gateway_actor.createPyPartition \ .remote(serialized_partition) return deserialize(ray.get(call)) + def union(self, *streams): + serialized_streams = serialize(streams) + call = self._python_gateway_actor.union \ + .remote(serialized_streams) + return deserialize(ray.get(call)) + def call_function(self, java_class, java_function, *args): java_params = serialize([java_class, java_function] + list(args)) call = self._python_gateway_actor.callFunction.remote(java_params) diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py index 645827601..721cc358b 100644 --- a/streaming/python/runtime/graph.py +++ b/streaming/python/runtime/graph.py @@ -5,7 +5,6 @@ import ray.streaming.generated.remote_call_pb2 as remote_call_pb import ray.streaming.generated.streaming_pb2 as streaming_pb import ray.streaming.operator as operator import ray.streaming.partition as partition -from ray.streaming import function from ray.streaming.generated.streaming_pb2 import Language @@ -32,9 +31,8 @@ class ExecutionNode: node_pb.node_type)] self.parallelism = node_pb.parallelism if node_pb.language == Language.PYTHON: - func_bytes = node_pb.function # python function descriptor - func = function.load_function(func_bytes) - self.stream_operator = operator.create_operator(func) + operator_bytes = node_pb.operator # python operator descriptor + self.stream_operator = operator.load_operator(operator_bytes) self.execution_tasks = [ ExecutionTask(task) for task in node_pb.execution_tasks ] diff --git a/streaming/python/tests/test_operator.py b/streaming/python/tests/test_operator.py index 2a72aa253..937822125 100644 --- a/streaming/python/tests/test_operator.py +++ b/streaming/python/tests/test_operator.py @@ -1,8 +1,37 @@ -from ray.streaming import operator from ray.streaming import function +from ray.streaming import operator +from ray.streaming.operator import OperatorType +from ray.streaming.runtime import gateway_client -def test_create_operator(): +def test_create_operator_with_func(): map_func = function.SimpleMapFunction(lambda x: x) - map_operator = operator.create_operator(map_func) + map_operator = operator.create_operator_with_func(map_func) assert type(map_operator) is operator.MapOperator + + +class MapFunc(function.MapFunction): + def map(self, value): + return str(value) + + +class EmptyOperator(operator.StreamOperator): + def __init__(self): + super().__init__(function.EmptyFunction()) + + def operator_type(self) -> OperatorType: + return OperatorType.ONE_INPUT + + +def test_load_operator(): + # function_bytes, module_name, class_name, + descriptor_func_bytes = gateway_client.serialize( + [None, __name__, MapFunc.__name__, "MapFunction"]) + descriptor_op_bytes = gateway_client.serialize( + [descriptor_func_bytes, "", ""]) + map_operator = operator.load_operator(descriptor_op_bytes) + assert type(map_operator) is operator.MapOperator + descriptor_op_bytes = gateway_client.serialize( + [None, __name__, EmptyOperator.__name__]) + test_operator = operator.load_operator(descriptor_op_bytes) + assert isinstance(test_operator, EmptyOperator) diff --git a/streaming/python/tests/test_union_stream.py b/streaming/python/tests/test_union_stream.py new file mode 100644 index 000000000..2050ab1b2 --- /dev/null +++ b/streaming/python/tests/test_union_stream.py @@ -0,0 +1,47 @@ +import os + +import ray +from ray.streaming import StreamingContext + + +def test_union_stream(): + ray.init(load_code_from_local=True, include_java=True) + ctx = StreamingContext.Builder() \ + .option("streaming.metrics.reporters", "") \ + .build() + sink_file = "/tmp/test_union_stream.txt" + if os.path.exists(sink_file): + os.remove(sink_file) + + def sink_func(x): + with open(sink_file, "a") as f: + print("sink_func", x) + f.write(str(x)) + + stream1 = ctx.from_values(1, 2) + stream2 = ctx.from_values(3, 4) + stream3 = ctx.from_values(5, 6) + stream1.union(stream2, stream3).sink(sink_func) + ctx.submit("test_union_stream") + import time + slept_time = 0 + while True: + if os.path.exists(sink_file): + time.sleep(3) + with open(sink_file, "r") as f: + result = f.read() + print("sink result", result) + assert set(result) == {"1", "2", "3", "4", "5", "6"} + print("Execution succeed") + break + if slept_time >= 60: + raise Exception("Execution not finished") + slept_time = slept_time + 1 + print("Wait finish...") + time.sleep(1) + + ray.shutdown() + + +if __name__ == "__main__": + test_union_stream() diff --git a/streaming/src/protobuf/remote_call.proto b/streaming/src/protobuf/remote_call.proto index 2138cf4ba..b331d3928 100644 --- a/streaming/src/protobuf/remote_call.proto +++ b/streaming/src/protobuf/remote_call.proto @@ -14,8 +14,8 @@ message ExecutionGraph { int32 parallelism = 2; NodeType node_type = 3; Language language = 4; - // serialized user function - bytes function = 5; + // serialized operator + bytes operator = 5; repeated ExecutionTask execution_tasks = 6; repeated ExecutionEdge input_edges = 7; repeated ExecutionEdge output_edges = 8;