diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java index b5cb0a931..63daebb73 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java @@ -6,6 +6,7 @@ import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.client.JobClient; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraphBuilder; +import io.ray.streaming.jobgraph.JobGraphOptimizer; import io.ray.streaming.util.Config; import java.io.Serializable; import java.util.ArrayList; @@ -56,7 +57,8 @@ public class StreamingContext implements Serializable { */ public void execute(String jobName) { JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName); - this.jobGraph = jobGraphBuilder.build(); + JobGraph originalJobGraph = jobGraphBuilder.build(); + this.jobGraph = new JobGraphOptimizer(originalJobGraph).optimize(); jobGraph.printJobGraph(); LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph()); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/ForwardPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/ForwardPartition.java new file mode 100644 index 000000000..a89bf8f92 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/ForwardPartition.java @@ -0,0 +1,19 @@ +package io.ray.streaming.api.partition.impl; + +import io.ray.streaming.api.partition.Partition; + +/** + * Default partition for operator if the operator can be chained with succeeding operators. + * Partition will be set to {@link RoundRobinPartition} if the operator can't be chiained with + * succeeding operators. + * + * @param Type of the input record. 
+ */ +public class ForwardPartition implements Partition { + private int[] partitions = new int[] {0}; + + @Override + public int[] partition(T record, int numPartition) { + return partitions; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java index 87ccb5eaf..baecdc339 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java @@ -3,8 +3,7 @@ package io.ray.streaming.api.stream; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.function.impl.SourceFunction; import io.ray.streaming.api.function.internal.CollectionSourceFunction; -import io.ray.streaming.api.partition.impl.RoundRobinPartition; -import io.ray.streaming.operator.impl.SourceOperator; +import io.ray.streaming.operator.impl.SourceOperatorImpl; import java.util.Collection; /** @@ -15,7 +14,7 @@ import java.util.Collection; public class DataStreamSource extends DataStream implements StreamSource { private DataStreamSource(StreamingContext streamingContext, SourceFunction sourceFunction) { - super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>()); + super(streamingContext, new SourceOperatorImpl<>(sourceFunction)); } public static DataStreamSource fromSource( diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java index ae33a8829..1f03e517e 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java @@ -2,6 +2,7 @@ package io.ray.streaming.api.stream; import 
io.ray.streaming.api.function.impl.JoinFunction; import io.ray.streaming.api.function.impl.KeyFunction; +import io.ray.streaming.operator.impl.JoinOperator; import java.io.Serializable; /** @@ -9,40 +10,42 @@ import java.io.Serializable; * * @param Type of the data in the left stream. * @param Type of the data in the right stream. - * @param Type of the data in the joined stream. + * @param Type of the data in the joined stream. */ -public class JoinStream extends DataStream { +public class JoinStream extends DataStream { + private final DataStream rightStream; public JoinStream(DataStream leftStream, DataStream rightStream) { - super(leftStream, null); + super(leftStream, new JoinOperator<>()); + this.rightStream = rightStream; + } + + public DataStream getRightStream() { + return rightStream; } /** * Apply key-by to the left join stream. */ - public Where where(KeyFunction keyFunction) { + public Where where(KeyFunction keyFunction) { return new Where<>(this, keyFunction); } /** * Where clause of the join transformation. * - * @param Type of the data in the left stream. - * @param Type of the data in the right stream. - * @param Type of the data in the joined stream. * @param Type of the join key. */ - class Where implements Serializable { - - private JoinStream joinStream; + class Where implements Serializable { + private JoinStream joinStream; private KeyFunction leftKeyByFunction; - public Where(JoinStream joinStream, KeyFunction leftKeyByFunction) { + Where(JoinStream joinStream, KeyFunction leftKeyByFunction) { this.joinStream = joinStream; this.leftKeyByFunction = leftKeyByFunction; } - public Equal equalLo(KeyFunction rightKeyFunction) { + public Equal equalTo(KeyFunction rightKeyFunction) { return new Equal<>(joinStream, leftKeyByFunction, rightKeyFunction); } } @@ -50,26 +53,25 @@ public class JoinStream extends DataStream { /** * Equal clause of the join transformation. * - * @param Type of the data in the left stream. 
- * @param Type of the data in the right stream. - * @param Type of the data in the joined stream. * @param Type of the join key. */ - class Equal implements Serializable { - - private JoinStream joinStream; + class Equal implements Serializable { + private JoinStream joinStream; private KeyFunction leftKeyByFunction; private KeyFunction rightKeyByFunction; - public Equal(JoinStream joinStream, KeyFunction leftKeyByFunction, - KeyFunction rightKeyByFunction) { + Equal(JoinStream joinStream, KeyFunction leftKeyByFunction, + KeyFunction rightKeyByFunction) { this.joinStream = joinStream; this.leftKeyByFunction = leftKeyByFunction; this.rightKeyByFunction = rightKeyByFunction; } - public DataStream with(JoinFunction joinFunction) { - return (DataStream) joinStream; + @SuppressWarnings("unchecked") + public DataStream with(JoinFunction joinFunction) { + JoinOperator joinOperator = (JoinOperator) joinStream.getOperator(); + joinOperator.setFunction(joinFunction); + return (DataStream) joinStream; } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java index 241cef4f5..987df79b0 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java @@ -4,7 +4,8 @@ import com.google.common.base.Preconditions; import io.ray.streaming.api.Language; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.partition.Partition; -import io.ray.streaming.api.partition.impl.RoundRobinPartition; +import io.ray.streaming.api.partition.impl.ForwardPartition; +import io.ray.streaming.operator.ChainStrategy; import io.ray.streaming.operator.Operator; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.python.PythonPartition; @@ -30,8 +31,7 @@ public abstract class Stream, T> private Stream 
originalStream; public Stream(StreamingContext streamingContext, StreamOperator streamOperator) { - this(streamingContext, null, streamOperator, - selectPartition(streamOperator)); + this(streamingContext, null, streamOperator, getForwardPartition(streamOperator)); } public Stream(StreamingContext streamingContext, @@ -42,7 +42,7 @@ public abstract class Stream, T> public Stream(Stream inputStream, StreamOperator streamOperator) { this(inputStream.getStreamingContext(), inputStream, streamOperator, - selectPartition(streamOperator)); + getForwardPartition(streamOperator)); } public Stream(Stream inputStream, StreamOperator streamOperator, Partition partition) { @@ -50,9 +50,9 @@ public abstract class Stream, T> } protected Stream(StreamingContext streamingContext, - Stream inputStream, - StreamOperator streamOperator, - Partition partition) { + Stream inputStream, + StreamOperator streamOperator, + Partition partition) { this.streamingContext = streamingContext; this.inputStream = inputStream; this.operator = streamOperator; @@ -73,15 +73,16 @@ public abstract class Stream, T> this.streamingContext = originalStream.getStreamingContext(); this.inputStream = originalStream.getInputStream(); this.operator = originalStream.getOperator(); + Preconditions.checkNotNull(operator); } @SuppressWarnings("unchecked") - private static Partition selectPartition(Operator operator) { + private static Partition getForwardPartition(Operator operator) { switch (operator.getLanguage()) { case PYTHON: - return (Partition) PythonPartition.RoundRobinPartition; + return (Partition) PythonPartition.ForwardPartition; case JAVA: - return new RoundRobinPartition<>(); + return new ForwardPartition<>(); default: throw new UnsupportedOperationException( "Unsupported language " + operator.getLanguage()); @@ -165,5 +166,29 @@ public abstract class Stream, T> return originalStream; } + /** + * Set chain strategy for this stream + */ + public S withChainStrategy(ChainStrategy chainStrategy) { + 
Preconditions.checkArgument(!isProxyStream()); + operator.setChainStrategy(chainStrategy); + return self(); + } + + /** + * Disable chain for this stream + */ + public S disableChain() { + return withChainStrategy(ChainStrategy.NEVER); + } + + /** + * Set the partition function of this {@link Stream} so that output elements are forwarded to + * next operator locally. + */ + public S forward() { + return setPartition(getForwardPartition(operator)); + } + public abstract Language getLanguage(); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java index 833cddaa8..c3854a6a0 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java @@ -16,6 +16,8 @@ public class UnionStream extends DataStream { private List> unionStreams; public UnionStream(DataStream input, List> streams) { + // Union stream does not create a physical operation, so we don't have to set partition + // function for it. 
super(input, new UnionOperator()); this.unionStreams = new ArrayList<>(); streams.forEach(this::addStream); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java index e670e5ea3..8711109a2 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java @@ -5,6 +5,8 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -17,15 +19,24 @@ public class JobGraph implements Serializable { private final String jobName; private final Map jobConfig; - private List jobVertexList; - private List jobEdgeList; + private List jobVertices; + private List jobEdges; private String digraph; public JobGraph(String jobName, Map jobConfig) { this.jobName = jobName; this.jobConfig = jobConfig; - this.jobVertexList = new ArrayList<>(); - this.jobEdgeList = new ArrayList<>(); + this.jobVertices = new ArrayList<>(); + this.jobEdges = new ArrayList<>(); + } + + public JobGraph(String jobName, Map jobConfig, + List jobVertices, List jobEdges) { + this.jobName = jobName; + this.jobConfig = jobConfig; + this.jobVertices = jobVertices; + this.jobEdges = jobEdges; + generateDigraph(); } /** @@ -36,12 +47,12 @@ public class JobGraph implements Serializable { */ public String generateDigraph() { StringBuilder digraph = new StringBuilder(); - digraph.append("digraph ").append(jobName + " ").append(" {"); + digraph.append("digraph ").append(jobName).append(" ").append(" {"); - for (JobEdge jobEdge : jobEdgeList) { + for (JobEdge jobEdge : jobEdges) { String srcNode = null; String targetNode = null; - for (JobVertex jobVertex : jobVertexList) { + for (JobVertex jobVertex : jobVertices) { if 
(jobEdge.getSrcVertexId() == jobVertex.getVertexId()) { srcNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName(); } else if (jobEdge.getTargetVertexId() == jobVertex.getVertexId()) { @@ -49,7 +60,7 @@ public class JobGraph implements Serializable { } } digraph.append(System.getProperty("line.separator")); - digraph.append(srcNode).append(" -> ").append(targetNode); + digraph.append(String.format(" \"%s\" -> \"%s\"", srcNode, targetNode)); } digraph.append(System.getProperty("line.separator")).append("}"); @@ -58,19 +69,47 @@ public class JobGraph implements Serializable { } public void addVertex(JobVertex vertex) { - this.jobVertexList.add(vertex); + this.jobVertices.add(vertex); } public void addEdge(JobEdge jobEdge) { - this.jobEdgeList.add(jobEdge); + this.jobEdges.add(jobEdge); } - public List getJobVertexList() { - return jobVertexList; + public List getJobVertices() { + return jobVertices; } - public List getJobEdgeList() { - return jobEdgeList; + public List getSourceVertices() { + return jobVertices.stream() + .filter(v -> v.getVertexType() == VertexType.SOURCE) + .collect(Collectors.toList()); + } + + public List getSinkVertices() { + return jobVertices.stream() + .filter(v -> v.getVertexType() == VertexType.SINK) + .collect(Collectors.toList()); + } + + public JobVertex getVertex(int vertexId) { + return jobVertices.stream().filter(v -> v.getVertexId() == vertexId).findFirst().get(); + } + + public List getJobEdges() { + return jobEdges; + } + + public Set getVertexInputEdges(int vertexId) { + return jobEdges.stream() + .filter(jobEdge -> jobEdge.getTargetVertexId() == vertexId) + .collect(Collectors.toSet()); + } + + public Set getVertexOutputEdges(int vertexId) { + return jobEdges.stream() + .filter(jobEdge -> jobEdge.getSrcVertexId() == vertexId) + .collect(Collectors.toSet()); } public String getDigraph() { @@ -90,17 +129,17 @@ public class JobGraph implements Serializable { return; } LOG.info("Printing job graph:"); - for 
(JobVertex jobVertex : jobVertexList) { + for (JobVertex jobVertex : jobVertices) { LOG.info(jobVertex.toString()); } - for (JobEdge jobEdge : jobEdgeList) { + for (JobEdge jobEdge : jobEdges) { LOG.info(jobEdge.toString()); } } public boolean isCrossLanguageGraph() { - Language language = jobVertexList.get(0).getLanguage(); - for (JobVertex jobVertex : jobVertexList) { + Language language = jobVertices.get(0).getLanguage(); + for (JobVertex jobVertex : jobVertices) { if (jobVertex.getLanguage() != language) { return true; } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java index 2a2d02ebf..2f6eee6c6 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java @@ -2,6 +2,7 @@ package io.ray.streaming.jobgraph; import com.google.common.base.Preconditions; import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.JoinStream; import io.ray.streaming.api.stream.Stream; import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.api.stream.StreamSource; @@ -26,7 +27,7 @@ public class JobGraphBuilder { private List streamSinkList; public JobGraphBuilder(List streamSinkList) { - this(streamSinkList, "job-" + System.currentTimeMillis()); + this(streamSinkList, "job_" + System.currentTimeMillis()); } public JobGraphBuilder(List streamSinkList, String jobName) { @@ -61,18 +62,20 @@ public class JobGraphBuilder { "Reference stream should be skipped."); int vertexId = stream.getId(); int parallelism = stream.getParallelism(); + Map config = stream.getConfig(); JobVertex jobVertex; if (stream instanceof StreamSink) { - jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator); + jobVertex = new JobVertex(vertexId, parallelism, 
VertexType.SINK, streamOperator, config); Stream parentStream = stream.getInputStream(); int inputVertexId = parentStream.getId(); JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition()); this.jobGraph.addEdge(jobEdge); processStream(parentStream); } else if (stream instanceof StreamSource) { - jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator); + jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator, config); } else if (stream instanceof DataStream || stream instanceof PythonDataStream) { - jobVertex = new JobVertex(vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator); + jobVertex = new JobVertex( + vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator, config); Stream parentStream = stream.getInputStream(); int inputVertexId = parentStream.getId(); JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition()); @@ -92,10 +95,17 @@ public class JobGraphBuilder { this.jobGraph.addEdge(otherEdge); processStream(otherStream); } + + // process join stream + if (stream instanceof JoinStream) { + DataStream rightStream = ((JoinStream) stream).getRightStream(); + this.jobGraph.addEdge( + new JobEdge(rightStream.getId(), vertexId, rightStream.getPartition())); + processStream(rightStream); + } } else { throw new UnsupportedOperationException("Unsupported stream: " + stream); } - jobVertex.setConfig(stream.getConfig()); this.jobGraph.addVertex(jobVertex); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphOptimizer.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphOptimizer.java new file mode 100644 index 000000000..c44d0be6a --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphOptimizer.java @@ -0,0 +1,187 @@ +package io.ray.streaming.jobgraph; + +import io.ray.streaming.api.Language; +import 
io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.partition.impl.ForwardPartition; +import io.ray.streaming.api.partition.impl.RoundRobinPartition; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.chain.ChainedOperator; +import io.ray.streaming.python.PythonOperator; +import io.ray.streaming.python.PythonOperator.ChainedPythonOperator; +import io.ray.streaming.python.PythonPartition; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; + +/** + * Optimize job graph by chaining some operators so that some operators can be run in the + * same thread. + */ +public class JobGraphOptimizer { + private final JobGraph jobGraph; + private Set visited = new HashSet<>(); + // vertex id -> vertex + private Map vertexMap; + private Map> outputEdgesMap; + // tail vertex id -> mergedVertex + private Map>> mergedVertexMap; + + public JobGraphOptimizer(JobGraph jobGraph) { + this.jobGraph = jobGraph; + vertexMap = jobGraph.getJobVertices().stream() + .collect(Collectors.toMap(JobVertex::getVertexId, Function.identity())); + outputEdgesMap = vertexMap.keySet().stream().collect(Collectors.toMap( + id -> vertexMap.get(id), id -> new HashSet<>(jobGraph.getVertexOutputEdges(id)))); + mergedVertexMap = new HashMap<>(); + } + + public JobGraph optimize() { + // Deep-first traverse nodes from source to sink to merge vertices that can be chained + // together. 
+ jobGraph.getSourceVertices().forEach(vertex -> { + List verticesToMerge = new ArrayList<>(); + verticesToMerge.add(vertex); + mergeVerticesRecursively(vertex, verticesToMerge); + }); + + List vertices = mergedVertexMap.values().stream() + .map(Pair::getLeft).collect(Collectors.toList()); + + return new JobGraph(jobGraph.getJobName(), jobGraph.getJobConfig(), vertices, createEdges()); + } + + private void mergeVerticesRecursively(JobVertex vertex, List verticesToMerge) { + if (!visited.contains(vertex)) { + visited.add(vertex); + Set outputEdges = outputEdgesMap.get(vertex); + if (outputEdges.isEmpty()) { + mergeAndAddVertex(verticesToMerge); + } else { + outputEdges.forEach(edge -> { + JobVertex succeedingVertex = vertexMap.get(edge.getTargetVertexId()); + if (canBeChained(vertex, succeedingVertex, edge)) { + verticesToMerge.add(succeedingVertex); + mergeVerticesRecursively(succeedingVertex, verticesToMerge); + } else { + mergeAndAddVertex(verticesToMerge); + List newMergedVertices = new ArrayList<>(); + newMergedVertices.add(succeedingVertex); + mergeVerticesRecursively(succeedingVertex, newMergedVertices); + } + }); + } + } + } + + private void mergeAndAddVertex(List verticesToMerge) { + JobVertex mergedVertex; + JobVertex headVertex = verticesToMerge.get(0); + Language language = headVertex.getLanguage(); + if (verticesToMerge.size() == 1) { + // no chain + mergedVertex = headVertex; + } else { + List operators = verticesToMerge.stream() + .map(v -> vertexMap.get(v.getVertexId()).getStreamOperator()) + .collect(Collectors.toList()); + List> configs = verticesToMerge.stream() + .map(v -> vertexMap.get(v.getVertexId()).getConfig()) + .collect(Collectors.toList()); + StreamOperator operator; + if (language == Language.JAVA) { + operator = ChainedOperator.newChainedOperator(operators, configs); + } else { + List pythonOperators = operators.stream() + .map(o -> (PythonOperator) o).collect(Collectors.toList()); + operator = new ChainedPythonOperator(pythonOperators, 
configs); + } + // chained operator config is placed into `ChainedOperator`. + mergedVertex = new JobVertex(headVertex.getVertexId(), headVertex.getParallelism(), + headVertex.getVertexType(), operator, new HashMap<>()); + } + + mergedVertexMap.put(mergedVertex.getVertexId(), Pair.of(mergedVertex, verticesToMerge)); + } + + private List createEdges() { + List edges = new ArrayList<>(); + mergedVertexMap.forEach((id, pair) -> { + JobVertex mergedVertex = pair.getLeft(); + List mergedVertices = pair.getRight(); + JobVertex tailVertex = mergedVertices.get(mergedVertices.size() - 1); + // input edge will be set up in input vertices + if (outputEdgesMap.containsKey(tailVertex)) { + outputEdgesMap.get(tailVertex).forEach(edge -> { + Pair> downstreamPair = + mergedVertexMap.get(edge.getTargetVertexId()); + // change ForwardPartition to RoundRobinPartition. + Partition partition = changePartition(edge.getPartition()); + JobEdge newEdge = new JobEdge( + mergedVertex.getVertexId(), + downstreamPair.getLeft().getVertexId(), + partition); + edges.add(newEdge); + }); + + } + }); + return edges; + } + + /** + * Change ForwardPartition to RoundRobinPartition. 
+ */ + private Partition changePartition(Partition partition) { + if (partition instanceof PythonPartition) { + PythonPartition pythonPartition = (PythonPartition) partition; + if (!pythonPartition.isConstructedFromBinary() && + pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS)) { + return PythonPartition.RoundRobinPartition; + } else { + return partition; + } + } else { + if (partition instanceof ForwardPartition) { + return new RoundRobinPartition(); + } else { + return partition; + } + } + } + + private boolean canBeChained(JobVertex precedingVertex, + JobVertex succeedingVertex, + JobEdge edge) { + if (jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 || + jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) { + return false; + } + if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) { + return false; + } + if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER + || succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER + || succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) { + return false; + } + if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) { + return false; + } + Partition partition = edge.getPartition(); + if (!(partition instanceof PythonPartition)) { + return partition instanceof ForwardPartition; + } else { + PythonPartition pythonPartition = (PythonPartition) partition; + return !pythonPartition.isConstructedFromBinary() && + pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS); + } + } + +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java index 98bb14b62..7fdb4efdd 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java +++ 
b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java @@ -10,6 +10,7 @@ import java.util.Map; * Job vertex is a cell node where logic is executed. */ public class JobVertex implements Serializable { + private int vertexId; private int parallelism; private VertexType vertexType; @@ -20,12 +21,14 @@ public class JobVertex implements Serializable { public JobVertex(int vertexId, int parallelism, VertexType vertexType, - StreamOperator streamOperator) { + StreamOperator streamOperator, + Map config) { this.vertexId = vertexId; this.parallelism = parallelism; this.vertexType = vertexType; this.streamOperator = streamOperator; this.language = streamOperator.getLanguage(); + this.config = config; } public int getVertexId() { @@ -67,4 +70,5 @@ public class JobVertex implements Serializable { .add("config", config) .toString(); } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/ChainStrategy.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/ChainStrategy.java new file mode 100644 index 000000000..9a0084b10 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/ChainStrategy.java @@ -0,0 +1,20 @@ +package io.ray.streaming.operator; + +/** + * Chain strategy for streaming operators. Chained operators are run in the same thread. + */ +public enum ChainStrategy { + /** + * The operator won't be chained with preceding operators, but maybe chained with succeeding + * operators. + */ + HEAD, + /** + * Operators will be chained together when possible. + */ + ALWAYS, + /** + * The operator won't be chained with any operator. 
+ */ + NEVER +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java index 58a0d5fa2..0bbb0d7a2 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java @@ -9,6 +9,8 @@ import java.util.List; public interface Operator extends Serializable { + String getName(); + void open(List collectors, RuntimeContext runtimeContext); void finish(); @@ -20,4 +22,7 @@ public interface Operator extends Serializable { Language getLanguage(); OperatorType getOpType(); + + ChainStrategy getChainStrategy(); + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java new file mode 100644 index 000000000..3cf9ab1d7 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java @@ -0,0 +1,14 @@ +package io.ray.streaming.operator; + +import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; + +public interface SourceOperator extends Operator { + + void run(); + + SourceContext getSourceContext(); + + default OperatorType getOpType() { + return OperatorType.SOURCE; + } +} \ No newline at end of file diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java index 4160c736f..4c9239966 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java @@ -12,13 +12,22 @@ import java.util.List; public abstract class StreamOperator implements Operator { protected final String name; - 
protected final F function; - protected final RichFunction richFunction; + protected F function; + protected RichFunction richFunction; protected List collectorList; protected RuntimeContext runtimeContext; + private ChainStrategy chainStrategy = ChainStrategy.ALWAYS; - public StreamOperator(F function) { + protected StreamOperator() { this.name = getClass().getSimpleName(); + } + + protected StreamOperator(F function) { + this(); + setFunction(function); + } + + public void setFunction(F function) { this.function = function; this.richFunction = Functions.wrap(function); } @@ -62,7 +71,17 @@ public abstract class StreamOperator implements Operator { } } + @Override public String getName() { return name; } + + public void setChainStrategy(ChainStrategy chainStrategy) { + this.chainStrategy = chainStrategy; + } + + @Override + public ChainStrategy getChainStrategy() { + return chainStrategy; + } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java new file mode 100644 index 000000000..ad44d8097 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java @@ -0,0 +1,169 @@ +package io.ray.streaming.operator.chain; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.Language; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.Operator; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.SourceOperator; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.TwoInputOperator; +import 
java.lang.reflect.Proxy; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Abstract base class for chained operators. + */ +public abstract class ChainedOperator extends StreamOperator { + protected final List operators; + protected final Operator headOperator; + protected final Operator tailOperator; + private final List> configs; + + public ChainedOperator(List operators, List> configs) { + Preconditions.checkArgument(operators.size() >= 2, + "Need at lease two operators to be chained together"); + operators.stream().skip(1) + .forEach(operator -> Preconditions.checkArgument(operator instanceof OneInputOperator)); + this.operators = operators; + this.configs = configs; + this.headOperator = operators.get(0); + this.tailOperator = operators.get(operators.size() - 1); + } + + @Override + public void open(List collectorList, RuntimeContext runtimeContext) { + // Dont' call super.open() as we `open` every operator separately. + List succeedingCollectors = operators.stream().skip(1) + .map(operator -> new ForwardCollector((OneInputOperator) operator)) + .collect(Collectors.toList()); + for (int i = 0; i < operators.size() - 1; i++) { + StreamOperator operator = operators.get(i); + List forwardCollectors = + Collections.singletonList(succeedingCollectors.get(i)); + operator.open(forwardCollectors, createRuntimeContext(runtimeContext, i)); + } + // tail operator send data to downstream using provided collectors. 
+ tailOperator.open(collectorList, createRuntimeContext(runtimeContext, operators.size() - 1)); + } + + @Override + public OperatorType getOpType() { + return headOperator.getOpType(); + } + + @Override + public Language getLanguage() { + return headOperator.getLanguage(); + } + + @Override + public String getName() { + return operators.stream().map(Operator::getName) + .collect(Collectors.joining(" -> ", "[", "]")); + } + + public List getOperators() { + return operators; + } + + public Operator getHeadOperator() { + return headOperator; + } + + public Operator getTailOperator() { + return tailOperator; + } + + private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) { + return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(), + new Class[] {RuntimeContext.class}, + (proxy, method, methodArgs) -> { + if (method.getName().equals("getConfig")) { + return configs.get(index); + } else { + return method.invoke(runtimeContext, methodArgs); + } + }); + } + + public static ChainedOperator newChainedOperator( + List operators, + List> configs) { + switch (operators.get(0).getOpType()) { + case SOURCE: + return new ChainedSourceOperator(operators, configs); + case ONE_INPUT: + return new ChainedOneInputOperator(operators, configs); + case TWO_INPUT: + return new ChainedTwoInputOperator(operators, configs); + default: + throw new IllegalArgumentException( + "Unsupported operator type " + operators.get(0).getOpType()); + } + } + + static class ChainedSourceOperator extends ChainedOperator + implements SourceOperator { + private final SourceOperator sourceOperator; + + @SuppressWarnings("unchecked") + ChainedSourceOperator(List operators, List> configs) { + super(operators, configs); + sourceOperator = (SourceOperator) headOperator; + } + + @Override + public void run() { + sourceOperator.run(); + } + + @Override + public SourceContext getSourceContext() { + return sourceOperator.getSourceContext(); + } + + } + + 
static class ChainedOneInputOperator extends ChainedOperator + implements OneInputOperator { + private final OneInputOperator inputOperator; + + @SuppressWarnings("unchecked") + ChainedOneInputOperator(List operators, List> configs) { + super(operators, configs); + inputOperator = (OneInputOperator) headOperator; + } + + @Override + public void processElement(Record record) throws Exception { + inputOperator.processElement(record); + } + + } + + static class ChainedTwoInputOperator extends ChainedOperator + implements TwoInputOperator { + private final TwoInputOperator inputOperator; + + @SuppressWarnings("unchecked") + ChainedTwoInputOperator(List operators, List> configs) { + super(operators, configs); + inputOperator = (TwoInputOperator) headOperator; + } + + @Override + public void processElement(Record record1, Record record2) { + inputOperator.processElement(record1, record2); + } + + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ForwardCollector.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ForwardCollector.java new file mode 100644 index 000000000..6d82fbc18 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ForwardCollector.java @@ -0,0 +1,23 @@ +package io.ray.streaming.operator.chain; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; + +class ForwardCollector implements Collector { + private final OneInputOperator succeedingOperator; + + ForwardCollector(OneInputOperator succeedingOperator) { + this.succeedingOperator = succeedingOperator; + } + + @SuppressWarnings("unchecked") + @Override + public void collect(Record record) { + try { + succeedingOperator.processElement(record); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git 
a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/JoinOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/JoinOperator.java new file mode 100644 index 000000000..4050d7879 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/JoinOperator.java @@ -0,0 +1,39 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.impl.JoinFunction; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.TwoInputOperator; + +/** + * Join operator + * + * @param Type of the data in the left stream. + * @param Type of the data in the right stream. + * @param Type of the data in the join key. + * @param Type of the data in the joined stream. + */ +public class JoinOperator extends StreamOperator> implements + TwoInputOperator { + public JoinOperator() { + + } + + public JoinOperator(JoinFunction function) { + super(function); + setChainStrategy(ChainStrategy.HEAD); + } + + @Override + public void processElement(Record record1, Record record2) { + + } + + @Override + public OperatorType getOpType() { + return OperatorType.TWO_INPUT; + } + +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java index 1cde66ff3..f7c4e7ce3 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java @@ -5,6 +5,7 @@ import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.api.function.impl.ReduceFunction; import io.ray.streaming.message.KeyRecord; import io.ray.streaming.message.Record; +import 
io.ray.streaming.operator.ChainStrategy; import io.ray.streaming.operator.OneInputOperator; import io.ray.streaming.operator.StreamOperator; import java.util.HashMap; @@ -18,6 +19,7 @@ public class ReduceOperator extends StreamOperator> impl public ReduceOperator(ReduceFunction reduceFunction) { super(reduceFunction); + setChainStrategy(ChainStrategy.HEAD); } @Override @@ -41,4 +43,5 @@ public class ReduceOperator extends StreamOperator> impl collect(record); } } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java similarity index 75% rename from streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperator.java rename to streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java index f85f27551..6b59e7779 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java @@ -5,15 +5,19 @@ import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.api.function.impl.SourceFunction; import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; import io.ray.streaming.message.Record; +import io.ray.streaming.operator.ChainStrategy; import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.operator.StreamOperator; import java.util.List; -public class SourceOperator extends StreamOperator> { +public class SourceOperatorImpl extends StreamOperator> + implements SourceOperator { private SourceContextImpl sourceContext; - public SourceOperator(SourceFunction function) { + public SourceOperatorImpl(SourceFunction function) { super(function); + setChainStrategy(ChainStrategy.HEAD); } @Override @@ -23,6 +27,7 @@ public class 
SourceOperator extends StreamOperator> { this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex()); } + @Override public void run() { try { this.function.run(this.sourceContext); @@ -31,13 +36,17 @@ public class SourceOperator extends StreamOperator> { } } + @Override + public SourceContext getSourceContext() { + return sourceContext; + } + @Override public OperatorType getOpType() { return OperatorType.SOURCE; } class SourceContextImpl implements SourceContext { - private List collectors; public SourceContextImpl(List collectors) { @@ -47,9 +56,10 @@ public class SourceOperator extends StreamOperator> { @Override public void collect(T t) throws Exception { for (Collector collector : collectors) { - collector.collect(new Record(t)); + collector.collect(new Record<>(t)); } } } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java index aac706d2f..5786f0af4 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java @@ -100,6 +100,14 @@ public class PythonFunction implements Function { return functionInterface; } + public String toSimpleString() { + if (function != null) { + return "binary function"; + } else { + return String.format("%s-%s.%s", functionInterface, moduleName, functionName); + } + } + @Override public String toString() { StringJoiner stringJoiner = new StringJoiner(", ", diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java index 045814c7e..81afb497c 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java +++ 
b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java @@ -1,11 +1,16 @@ package io.ray.streaming.python; +import com.google.common.base.Preconditions; import io.ray.streaming.api.Language; import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.operator.Operator; import io.ray.streaming.operator.OperatorType; import io.ray.streaming.operator.StreamOperator; import java.util.List; +import java.util.Map; import java.util.StringJoiner; +import java.util.stream.Collectors; /** * Represents a {@link StreamOperator} that wraps python {@link PythonFunction}. @@ -27,30 +32,6 @@ public class PythonOperator extends StreamOperator { this.className = null; } - @Override - public void open(List list, RuntimeContext runtimeContext) { - String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName()); - throw new UnsupportedOperationException(msg); - } - - @Override - public void finish() { - String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName()); - throw new UnsupportedOperationException(msg); - } - - @Override - public void close() { - String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName()); - throw new UnsupportedOperationException(msg); - } - - @Override - public OperatorType getOpType() { - String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName()); - throw new UnsupportedOperationException(msg); - } - @Override public Language getLanguage() { return Language.PYTHON; @@ -64,6 +45,48 @@ public class PythonOperator extends StreamOperator { return className; } + @Override + public void open(List list, RuntimeContext runtimeContext) { + throwUnsupportedException(); + } + + @Override + public void finish() { + throwUnsupportedException(); + } + + @Override + public void close() { + throwUnsupportedException(); + } + + void 
throwUnsupportedException() { + StackTraceElement[] trace = Thread.currentThread().getStackTrace(); + Preconditions.checkState(trace.length >= 2); + StackTraceElement traceElement = trace[2]; + String msg = String.format("Method %s.%s shouldn't be called.", + traceElement.getClassName(), traceElement.getMethodName()); + throw new UnsupportedOperationException(msg); + } + + @Override + public OperatorType getOpType() { + String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName()); + throw new UnsupportedOperationException(msg); + } + + @Override + public String getName() { + StringBuilder builder = new StringBuilder(); + builder.append(PythonOperator.class.getSimpleName()).append("["); + if (function != null) { + builder.append(((PythonFunction)function).toSimpleString()); + } else { + builder.append(moduleName).append(".").append(className); + } + return builder.append("]").toString(); + } + @Override public String toString() { StringJoiner stringJoiner = new StringJoiner(", ", @@ -76,4 +99,71 @@ public class PythonOperator extends StreamOperator { } return stringJoiner.toString(); } + + public static class ChainedPythonOperator extends PythonOperator { + private final List operators; + private final PythonOperator headOperator; + private final PythonOperator tailOperator; + private final List> configs; + + public ChainedPythonOperator( + List operators, List> configs) { + super(null); + Preconditions.checkArgument(!operators.isEmpty()); + this.operators = operators; + this.configs = configs; + this.headOperator = operators.get(0); + this.tailOperator = operators.get(operators.size() - 1); + } + + @Override + public OperatorType getOpType() { + return headOperator.getOpType(); + } + + @Override + public Language getLanguage() { + return Language.PYTHON; + } + + @Override + public String getName() { + return operators.stream().map(Operator::getName) + .collect(Collectors.joining(" -> ", "[", "]")); + } + + @Override + public String 
  /**
   * Returns whether this partition carries a binary payload, i.e. whether the
   * {@code partition} byte array is non-null.
   */
  public boolean isConstructedFromBinary() {
    return partition != null;
  }
a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java @@ -2,6 +2,7 @@ package io.ray.streaming.python.stream; import io.ray.streaming.api.stream.DataStream; import io.ray.streaming.api.stream.KeyDataStream; +import io.ray.streaming.operator.ChainStrategy; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction.FunctionInterface; import io.ray.streaming.python.PythonOperator; @@ -37,7 +38,9 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea */ public PythonDataStream reduce(PythonFunction func) { func.setFunctionInterface(FunctionInterface.REDUCE_FUNCTION); - return new PythonDataStream(this, new PythonOperator(func)); + PythonDataStream stream = new PythonDataStream(this, new PythonOperator(func)); + stream.withChainStrategy(ChainStrategy.HEAD); + return stream; } /** diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java index a123c5cf8..b33c07ec8 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java @@ -2,10 +2,10 @@ package io.ray.streaming.python.stream; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.stream.StreamSource; +import io.ray.streaming.operator.ChainStrategy; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction.FunctionInterface; import io.ray.streaming.python.PythonOperator; -import io.ray.streaming.python.PythonPartition; /** * Represents a source of the PythonStream. 
  /**
   * Creates a source stream backed by a {@link PythonOperator} that wraps
   * {@code sourceFunction}. No partition is passed to the super constructor, and the chain
   * strategy of the stream is set to {@link ChainStrategy#HEAD}.
   */
  private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
    super(streamingContext, new PythonOperator(sourceFunction));
    withChainStrategy(ChainStrategy.HEAD);
  }
super(input, new PythonOperator( "ray.streaming.operator", "UnionOperator")); this.unionStreams = new ArrayList<>(); diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java index 1f9c367df..37015585e 100644 --- a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java @@ -2,8 +2,8 @@ package io.ray.streaming.jobgraph; import com.google.common.collect.Lists; import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.partition.impl.ForwardPartition; import io.ray.streaming.api.partition.impl.KeyPartition; -import io.ray.streaming.api.partition.impl.RoundRobinPartition; import io.ray.streaming.api.stream.DataStream; import io.ray.streaming.api.stream.DataStreamSource; import io.ray.streaming.api.stream.StreamSink; @@ -20,14 +20,14 @@ public class JobGraphBuilderTest { @Test public void testDataSync() { JobGraph jobGraph = buildDataSyncJobGraph(); - List jobVertexList = jobGraph.getJobVertexList(); - List jobEdgeList = jobGraph.getJobEdgeList(); + List jobVertexList = jobGraph.getJobVertices(); + List jobEdgeList = jobGraph.getJobEdges(); Assert.assertEquals(jobVertexList.size(), 2); Assert.assertEquals(jobEdgeList.size(), 1); JobEdge jobEdge = jobEdgeList.get(0); - Assert.assertEquals(jobEdge.getPartition().getClass(), RoundRobinPartition.class); + Assert.assertEquals(jobEdge.getPartition().getClass(), ForwardPartition.class); JobVertex sinkVertex = jobVertexList.get(1); JobVertex sourceVertex = jobVertexList.get(0); @@ -50,8 +50,8 @@ public class JobGraphBuilderTest { @Test public void testKeyByJobGraph() { JobGraph jobGraph = buildKeyByJobGraph(); - List jobVertexList = jobGraph.getJobVertexList(); - List jobEdgeList = jobGraph.getJobEdgeList(); + List 
jobVertexList = jobGraph.getJobVertices(); + List jobEdgeList = jobGraph.getJobEdges(); Assert.assertEquals(jobVertexList.size(), 3); Assert.assertEquals(jobEdgeList.size(), 2); @@ -68,7 +68,7 @@ public class JobGraphBuilderTest { JobEdge source2KeyBy = jobEdgeList.get(1); Assert.assertEquals(keyBy2Sink.getPartition().getClass(), KeyPartition.class); - Assert.assertEquals(source2KeyBy.getPartition().getClass(), RoundRobinPartition.class); + Assert.assertEquals(source2KeyBy.getPartition().getClass(), ForwardPartition.class); } public JobGraph buildKeyByJobGraph() { @@ -88,8 +88,8 @@ public class JobGraphBuilderTest { JobGraph jobGraph = buildKeyByJobGraph(); jobGraph.generateDigraph(); String diGraph = jobGraph.getDigraph(); - System.out.println(diGraph); - Assert.assertTrue(diGraph.contains("1-SourceOperator -> 2-KeyByOperator")); - Assert.assertTrue(diGraph.contains("2-KeyByOperator -> 3-SinkOperator")); + LOG.info(diGraph); + Assert.assertTrue(diGraph.contains("\"1-SourceOperatorImpl\" -> \"2-KeyByOperator\"")); + Assert.assertTrue(diGraph.contains("\"2-KeyByOperator\" -> \"3-SinkOperator\"")); } } \ No newline at end of file diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphOptimizerTest.java b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphOptimizerTest.java new file mode 100644 index 000000000..a128d77a8 --- /dev/null +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphOptimizerTest.java @@ -0,0 +1,70 @@ +package io.ray.streaming.jobgraph; + +import static org.testng.Assert.assertEquals; + +import com.google.common.collect.Lists; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.python.PythonFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.Test; + +public class 
JobGraphOptimizerTest { + private static final Logger LOG = LoggerFactory.getLogger( JobGraphOptimizerTest.class ); + + @Test + public void testOptimize() { + StreamingContext context = StreamingContext.buildContext(); + DataStream source1 = DataStreamSource.fromCollection(context, + Lists.newArrayList(1 ,2 ,3)); + DataStream source2 = DataStreamSource.fromCollection(context, + Lists.newArrayList("1", "2", "3")); + DataStream source3 = DataStreamSource.fromCollection(context, + Lists.newArrayList("2", "3", "4")); + source1.filter(x -> x > 1) + .map(String::valueOf) + .union(source2) + .join(source3) + .sink(x -> System.out.println("Sink " + x)); + JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build(); + LOG.info("Digraph {}", jobGraph.generateDigraph()); + assertEquals(jobGraph.getJobVertices().size(), 8); + + JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph); + JobGraph optimizedJobGraph = graphOptimizer.optimize(); + optimizedJobGraph.printJobGraph(); + LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph()); + assertEquals(optimizedJobGraph.getJobVertices().size(), 5); + } + + @Test + public void testOptimizeHybridStream() { + StreamingContext context = StreamingContext.buildContext(); + DataStream source1 = DataStreamSource.fromCollection(context, + Lists.newArrayList(1 ,2 ,3)); + DataStream source2 = DataStreamSource.fromCollection(context, + Lists.newArrayList("1", "2", "3")); + source1.asPythonStream() + .map(pyFunc(1)) + .filter(pyFunc(2)) + .union(source2.asPythonStream().filter(pyFunc(3)).map(pyFunc(4))) + .asJavaStream() + .sink(x -> System.out.println("Sink " + x)); + JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build(); + LOG.info("Digraph {}", jobGraph.generateDigraph()); + assertEquals(jobGraph.getJobVertices().size(), 8); + + JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph); + JobGraph optimizedJobGraph = graphOptimizer.optimize(); + 
optimizedJobGraph.printJobGraph(); + LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph()); + assertEquals(optimizedJobGraph.getJobVertices().size(), 6); + } + + private PythonFunction pyFunc(int number) { + return new PythonFunction("module", "func" + number); + } + +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java index e9c53659d..500721d3d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java @@ -2,9 +2,9 @@ package io.ray.streaming.runtime.core.processor; import io.ray.streaming.operator.OneInputOperator; import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.operator.TwoInputOperator; -import io.ray.streaming.operator.impl.SourceOperator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java index 35cdc750c..020f39d16 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java @@ -1,7 +1,7 @@ package io.ray.streaming.runtime.core.processor; import io.ray.streaming.message.Record; -import io.ray.streaming.operator.impl.SourceOperator; +import io.ray.streaming.operator.SourceOperator; /** * The processor for the stream sources, containing a 
SourceOperator. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java index 287d526c8..e76963a47 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java @@ -30,7 +30,7 @@ public class GraphManagerImpl implements GraphManager { ExecutionGraph executionGraph = setupStructure(jobGraph); // set max parallelism - int maxParallelism = jobGraph.getJobVertexList().stream() + int maxParallelism = jobGraph.getJobVertices().stream() .map(JobVertex::getParallelism) .max(Integer::compareTo).get(); executionGraph.setMaxParallelism(maxParallelism); @@ -49,7 +49,7 @@ public class GraphManagerImpl implements GraphManager { // create vertex Map exeJobVertexMap = new LinkedHashMap<>(); long buildTime = executionGraph.getBuildTime(); - for (JobVertex jobVertex : jobGraph.getJobVertexList()) { + for (JobVertex jobVertex : jobGraph.getJobVertices()) { int jobVertexId = jobVertex.getVertexId(); exeJobVertexMap.put(jobVertexId, new ExecutionJobVertex( @@ -60,7 +60,7 @@ public class GraphManagerImpl implements GraphManager { } // connect vertex - jobGraph.getJobEdgeList().stream().forEach(jobEdge -> { + jobGraph.getJobEdges().forEach(jobEdge -> { ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId()); ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId()); @@ -70,8 +70,8 @@ public class GraphManagerImpl implements GraphManager { source.getOutputEdges().add(executionJobEdge); target.getInputEdges().add(executionJobEdge); - source.getExecutionVertices().stream().forEach(vertex -> { - target.getExecutionVertices().stream().forEach(outputVertex -> { + 
source.getExecutionVertices().forEach(vertex -> { + target.getExecutionVertices().forEach(outputVertex -> { ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge); vertex.getOutputEdges().add(executionEdge); outputVertex.getInputEdges().add(executionEdge); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java index b66e1a125..4f93bc0c2 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java @@ -7,6 +7,7 @@ import io.ray.streaming.api.partition.Partition; import io.ray.streaming.operator.Operator; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonOperator; +import io.ray.streaming.python.PythonOperator.ChainedPythonOperator; import io.ray.streaming.python.PythonPartition; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; @@ -77,6 +78,7 @@ public class GraphPbBuilder { executionVertexBuilder.setOperator( ByteString.copyFrom( serializeOperator(executionVertex.getStreamOperator()))); + executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator())); executionVertexBuilder.setWorkerActor( ByteString.copyFrom( ((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes())); @@ -104,17 +106,35 @@ public class GraphPbBuilder { private byte[] serializeOperator(Operator operator) { if (operator instanceof PythonOperator) { - PythonOperator pythonOperator = (PythonOperator) operator; - return serializer.serialize(Arrays.asList( - serializeFunction(pythonOperator.getFunction()), - pythonOperator.getModuleName(), - pythonOperator.getClassName() - )); + if 
(isPythonChainedOperator(operator)) { + return serializePythonChainedOperator((ChainedPythonOperator) operator); + } else { + PythonOperator pythonOperator = (PythonOperator) operator; + return serializer.serialize(Arrays.asList( + serializeFunction(pythonOperator.getFunction()), + pythonOperator.getModuleName(), + pythonOperator.getClassName() + )); + } } else { return new byte[0]; } } + private boolean isPythonChainedOperator(Operator operator) { + return operator instanceof ChainedPythonOperator; + } + + private byte[] serializePythonChainedOperator(ChainedPythonOperator operator) { + List serializedOperators = operator.getOperators().stream() + .map(this::serializeOperator).collect(Collectors.toList()); + return serializer.serialize(Arrays.asList( + serializedOperators, + operator.getConfigs() + )); + } + + private byte[] serializeFunction(Function function) { if (function instanceof PythonFunction) { PythonFunction pyFunc = (PythonFunction) function; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java index a23a781c3..64fa332e0 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java @@ -1,6 +1,6 @@ package io.ray.streaming.runtime.worker.tasks; -import io.ray.streaming.operator.impl.SourceOperator; +import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.runtime.core.processor.Processor; import io.ray.streaming.runtime.core.processor.SourceProcessor; import io.ray.streaming.runtime.worker.JobWorker; diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java 
b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java index 6577c81a7..36c3b71ad 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java @@ -7,6 +7,7 @@ import io.ray.streaming.api.stream.DataStreamSource; import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraphBuilder; +import io.ray.streaming.jobgraph.JobVertex; import io.ray.streaming.runtime.BaseUnitTest; import io.ray.streaming.runtime.config.StreamingConfig; import io.ray.streaming.runtime.config.master.ResourceConfig; @@ -40,10 +41,10 @@ public class ExecutionGraphTest extends BaseUnitTest { ExecutionGraph executionGraph = buildExecutionGraph(graphManager, jobGraph); List executionJobVertices = executionGraph.getExecutionJobVertexList(); - Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertexList().size()); + Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertices().size()); - int totalVertexNum = jobGraph.getJobVertexList().stream() - .mapToInt(vertex -> vertex.getParallelism()).sum(); + int totalVertexNum = jobGraph.getJobVertices().stream() + .mapToInt(JobVertex::getParallelism).sum(); Assert.assertEquals(executionGraph.getAllExecutionVertices().size(), totalVertexNum); Assert.assertEquals(executionGraph.getAllExecutionVertices().size(), executionGraph.getExecutionVertexIdGenerator().get()); @@ -66,7 +67,7 @@ public class ExecutionGraphTest extends BaseUnitTest { List downStreamVertices = downStream.getExecutionVertices(); upStreamVertices.forEach(vertex -> { Assert.assertEquals(vertex.getResource().get(ResourceType.CPU.name()), 2.0); - vertex.getOutputEdges().stream().forEach(upStreamOutPutEdge -> { + vertex.getOutputEdges().forEach(upStreamOutPutEdge -> { 
Assert.assertTrue(downStreamVertices.contains(upStreamOutPutEdge.getTargetExecutionVertex())); }); }); diff --git a/streaming/java/test.sh b/streaming/java/test.sh index 5641a271b..f0db1cf49 100755 --- a/streaming/java/test.sh +++ b/streaming/java/test.sh @@ -39,7 +39,7 @@ fi if [ $exit_code -ne 2 ] && [ $exit_code -ne 0 ] ; then if [ -d "/tmp/ray_streaming_java_test_output/" ] ; then echo "all test output" - for f in /tmp/ray_streaming_java_test_output/*; do + for f in /tmp/ray_streaming_java_test_output/*.{log,xml}; do if [ -f "$f" ]; then echo "Cat file $f" cat "$f" diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py index 0fc50ea4b..91193be22 100644 --- a/streaming/python/datastream.py +++ b/streaming/python/datastream.py @@ -94,6 +94,18 @@ class Stream(ABC): def get_language(self): pass + def forward(self): + """Set the partition function of this :class:`Stream` so that output + elements are forwarded to next operator locally.""" + self._gateway_client().call_method(self._j_stream, "forward") + return self + + def disable_chain(self): + """Disable chain for this stream so that it will be run in a separate + task.""" + self._gateway_client().call_method(self._j_stream, "disableChain") + return self + def _gateway_client(self): return self.get_streaming_context()._gateway_client diff --git a/streaming/python/operator.py b/streaming/python/operator.py index d92821de7..4952b2d00 100644 --- a/streaming/python/operator.py +++ b/streaming/python/operator.py @@ -1,12 +1,16 @@ import enum import importlib +import logging from abc import ABC, abstractmethod from ray import streaming from ray.streaming import function from ray.streaming import message +from ray.streaming.collector import Collector from ray.streaming.runtime import gateway_client +logger = logging.getLogger(__name__) + class OperatorType(enum.Enum): SOURCE = 0 # Sources are where your program reads its input from @@ -227,15 +231,93 @@ class UnionOperator(StreamOperator, 
OneInputOperator): self.collect(record) -_function_to_operator = { - function.SourceFunction: SourceOperator, - function.MapFunction: MapOperator, - function.FlatMapFunction: FlatMapOperator, - function.FilterFunction: FilterOperator, - function.KeyFunction: KeyByOperator, - function.ReduceFunction: ReduceOperator, - function.SinkFunction: SinkOperator, -} +class ChainedOperator(StreamOperator, ABC): + class ForwardCollector(Collector): + def __init__(self, succeeding_operator): + self.succeeding_operator = succeeding_operator + + def collect(self, record): + self.succeeding_operator.process_element(record) + + def __init__(self, operators, configs): + super().__init__(operators[0].func) + self.operators = operators + self.configs = configs + + def open(self, collectors, runtime_context): + # Don't call super.open() as we `open` every operator separately. + num_operators = len(self.operators) + succeeding_collectors = [ + ChainedOperator.ForwardCollector(operator) + for operator in self.operators[1:] + ] + for i in range(0, num_operators - 1): + forward_collectors = [succeeding_collectors[i]] + self.operators[i].open( + forward_collectors, + self.__create_runtime_context(runtime_context, i)) + self.operators[-1].open( + collectors, + self.__create_runtime_context(runtime_context, num_operators - 1)) + + def operator_type(self) -> OperatorType: + return self.operators[0].operator_type() + + def __create_runtime_context(self, runtime_context, index): + def get_config(): + return self.configs[index] + + runtime_context.get_config = get_config + return runtime_context + + @staticmethod + def new_chained_operator(operators, configs): + operator_type = operators[0].operator_type() + logger.info( + "Building ChainedOperator from operators {} and configs {}." 
+ .format(operators, configs)) + if operator_type == OperatorType.SOURCE: + return ChainedSourceOperator(operators, configs) + elif operator_type == OperatorType.ONE_INPUT: + return ChainedOneInputOperator(operators, configs) + elif operator_type == OperatorType.TWO_INPUT: + return ChainedTwoInputOperator(operators, configs) + else: + raise Exception("Current operator type is not supported") + + +class ChainedSourceOperator(ChainedOperator): + def __init__(self, operators, configs): + super().__init__(operators, configs) + + def run(self): + self.operators[0].run() + + +class ChainedOneInputOperator(ChainedOperator): + def __init__(self, operators, configs): + super().__init__(operators, configs) + + def process_element(self, record): + self.operators[0].process_element(record) + + +class ChainedTwoInputOperator(ChainedOperator): + def __init__(self, operators, configs): + super().__init__(operators, configs) + + def process_element(self, record1, record2): + self.operators[0].process_element(record1, record2) + + +def load_chained_operator(chained_operator_bytes: bytes): + """Load chained operator from serialized operators and configs""" + serialized_operators, configs = gateway_client.deserialize( + chained_operator_bytes) + operators = [ + load_operator(desc_bytes) for desc_bytes in serialized_operators + ] + return ChainedOperator.new_chained_operator(operators, configs) def load_operator(descriptor_operator_bytes: bytes): @@ -267,6 +349,17 @@ def load_operator(descriptor_operator_bytes: bytes): return cls() +_function_to_operator = { + function.SourceFunction: SourceOperator, + function.MapFunction: MapOperator, + function.FlatMapFunction: FlatMapOperator, + function.FilterFunction: FilterOperator, + function.KeyFunction: KeyByOperator, + function.ReduceFunction: ReduceOperator, + function.SinkFunction: SinkOperator, +} + + def create_operator_with_func(func: function.Function): """Create an operator according to a :class:`function.Function` diff --git 
a/streaming/python/partition.py b/streaming/python/partition.py index 198fbe3d7..fb30ba7cc 100644 --- a/streaming/python/partition.py +++ b/streaming/python/partition.py @@ -60,6 +60,17 @@ class RoundRobinPartition(Partition): return self.__partitions +class ForwardPartition(Partition): + """Default partition for operator if the operator can be chained with + succeeding operators.""" + + def __init__(self): + self.__partitions = [0] + + def partition(self, key_record, num_partition: int): + return self.__partitions + + class SimplePartition(Partition): """Wrap a python function as subclass of :class:`Partition`""" diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py index b5e738e29..6db9cf39e 100644 --- a/streaming/python/runtime/graph.py +++ b/streaming/python/runtime/graph.py @@ -1,4 +1,5 @@ import enum +import logging import ray import ray.streaming.generated.remote_call_pb2 as remote_call_pb @@ -6,6 +7,8 @@ import ray.streaming.operator as operator import ray.streaming.partition as partition from ray.streaming.generated.streaming_pb2 import Language +logger = logging.getLogger(__name__) + class NodeType(enum.Enum): """ @@ -43,7 +46,13 @@ class ExecutionVertex: self.parallelism = vertex_pb.parallelism if vertex_pb.language == Language.PYTHON: operator_bytes = vertex_pb.operator # python operator descriptor - self.stream_operator = operator.load_operator(operator_bytes) + if vertex_pb.chained: + logger.info("Load chained operator") + self.stream_operator = operator.load_chained_operator( + operator_bytes) + else: + logger.info("Load operator") + self.stream_operator = operator.load_operator(operator_bytes) self.worker_actor = ray.actor.ActorHandle. 
\ _deserialization_helper(vertex_pb.worker_actor) self.container_id = vertex_pb.container_id diff --git a/streaming/src/protobuf/remote_call.proto b/streaming/src/protobuf/remote_call.proto index a55ca942f..7e3a9fea8 100644 --- a/streaming/src/protobuf/remote_call.proto +++ b/streaming/src/protobuf/remote_call.proto @@ -30,12 +30,13 @@ message ExecutionVertexContext { int32 parallelism = 5; // serialized operator bytes operator = 6; - bytes worker_actor = 7; - string container_id = 8; - uint64 build_time = 9; - Language language = 10; - map config = 11; - map resource = 12; + bool chained = 7; + bytes worker_actor = 8; + string container_id = 9; + uint64 build_time = 10; + Language language = 11; + map config = 12; + map resource = 13; } // vertices