mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 16:49:48 +08:00
[Streaming] operator chain (#8910)
This commit is contained in:
+3
-1
@@ -6,6 +6,7 @@ import io.ray.streaming.api.stream.StreamSink;
|
||||
import io.ray.streaming.client.JobClient;
|
||||
import io.ray.streaming.jobgraph.JobGraph;
|
||||
import io.ray.streaming.jobgraph.JobGraphBuilder;
|
||||
import io.ray.streaming.jobgraph.JobGraphOptimizer;
|
||||
import io.ray.streaming.util.Config;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
@@ -56,7 +57,8 @@ public class StreamingContext implements Serializable {
|
||||
*/
|
||||
public void execute(String jobName) {
|
||||
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName);
|
||||
this.jobGraph = jobGraphBuilder.build();
|
||||
JobGraph originalJobGraph = jobGraphBuilder.build();
|
||||
this.jobGraph = new JobGraphOptimizer(originalJobGraph).optimize();
|
||||
jobGraph.printJobGraph();
|
||||
LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph());
|
||||
|
||||
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
package io.ray.streaming.api.partition.impl;
|
||||
|
||||
import io.ray.streaming.api.partition.Partition;
|
||||
|
||||
/**
|
||||
* Default partition for operator if the operator can be chained with succeeding operators.
|
||||
* Partition will be set to {@link RoundRobinPartition} if the operator can't be chiained with
|
||||
* succeeding operators.
|
||||
*
|
||||
* @param <T> Type of the input record.
|
||||
*/
|
||||
public class ForwardPartition<T> implements Partition<T> {
|
||||
private int[] partitions = new int[] {0};
|
||||
|
||||
@Override
|
||||
public int[] partition(T record, int numPartition) {
|
||||
return partitions;
|
||||
}
|
||||
}
|
||||
+2
-3
@@ -3,8 +3,7 @@ package io.ray.streaming.api.stream;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction;
|
||||
import io.ray.streaming.api.function.internal.CollectionSourceFunction;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.operator.impl.SourceOperator;
|
||||
import io.ray.streaming.operator.impl.SourceOperatorImpl;
|
||||
import java.util.Collection;
|
||||
|
||||
/**
|
||||
@@ -15,7 +14,7 @@ import java.util.Collection;
|
||||
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
|
||||
|
||||
private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
|
||||
super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
|
||||
super(streamingContext, new SourceOperatorImpl<>(sourceFunction));
|
||||
}
|
||||
|
||||
public static <T> DataStreamSource<T> fromSource(
|
||||
|
||||
+24
-22
@@ -2,6 +2,7 @@ package io.ray.streaming.api.stream;
|
||||
|
||||
import io.ray.streaming.api.function.impl.JoinFunction;
|
||||
import io.ray.streaming.api.function.impl.KeyFunction;
|
||||
import io.ray.streaming.operator.impl.JoinOperator;
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
@@ -9,40 +10,42 @@ import java.io.Serializable;
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <J> Type of the data in the joined stream.
|
||||
* @param <O> Type of the data in the joined stream.
|
||||
*/
|
||||
public class JoinStream<L, R, J> extends DataStream<L> {
|
||||
public class JoinStream<L, R, O> extends DataStream<L> {
|
||||
private final DataStream<R> rightStream;
|
||||
|
||||
public JoinStream(DataStream<L> leftStream, DataStream<R> rightStream) {
|
||||
super(leftStream, null);
|
||||
super(leftStream, new JoinOperator<>());
|
||||
this.rightStream = rightStream;
|
||||
}
|
||||
|
||||
public DataStream<R> getRightStream() {
|
||||
return rightStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply key-by to the left join stream.
|
||||
*/
|
||||
public <K> Where<L, R, J, K> where(KeyFunction<L, K> keyFunction) {
|
||||
public <K> Where<K> where(KeyFunction<L, K> keyFunction) {
|
||||
return new Where<>(this, keyFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Where clause of the join transformation.
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <J> Type of the data in the joined stream.
|
||||
* @param <K> Type of the join key.
|
||||
*/
|
||||
class Where<L, R, J, K> implements Serializable {
|
||||
|
||||
private JoinStream<L, R, J> joinStream;
|
||||
class Where<K> implements Serializable {
|
||||
private JoinStream<L, R, O> joinStream;
|
||||
private KeyFunction<L, K> leftKeyByFunction;
|
||||
|
||||
public Where(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction) {
|
||||
Where(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction) {
|
||||
this.joinStream = joinStream;
|
||||
this.leftKeyByFunction = leftKeyByFunction;
|
||||
}
|
||||
|
||||
public Equal<L, R, J, K> equalLo(KeyFunction<R, K> rightKeyFunction) {
|
||||
public Equal<K> equalTo(KeyFunction<R, K> rightKeyFunction) {
|
||||
return new Equal<>(joinStream, leftKeyByFunction, rightKeyFunction);
|
||||
}
|
||||
}
|
||||
@@ -50,26 +53,25 @@ public class JoinStream<L, R, J> extends DataStream<L> {
|
||||
/**
|
||||
* Equal clause of the join transformation.
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <J> Type of the data in the joined stream.
|
||||
* @param <K> Type of the join key.
|
||||
*/
|
||||
class Equal<L, R, J, K> implements Serializable {
|
||||
|
||||
private JoinStream<L, R, J> joinStream;
|
||||
class Equal<K> implements Serializable {
|
||||
private JoinStream<L, R, O> joinStream;
|
||||
private KeyFunction<L, K> leftKeyByFunction;
|
||||
private KeyFunction<R, K> rightKeyByFunction;
|
||||
|
||||
public Equal(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction,
|
||||
KeyFunction<R, K> rightKeyByFunction) {
|
||||
Equal(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction,
|
||||
KeyFunction<R, K> rightKeyByFunction) {
|
||||
this.joinStream = joinStream;
|
||||
this.leftKeyByFunction = leftKeyByFunction;
|
||||
this.rightKeyByFunction = rightKeyByFunction;
|
||||
}
|
||||
|
||||
public DataStream<J> with(JoinFunction<L, R, J> joinFunction) {
|
||||
return (DataStream<J>) joinStream;
|
||||
@SuppressWarnings("unchecked")
|
||||
public DataStream<O> with(JoinFunction<L, R, O> joinFunction) {
|
||||
JoinOperator joinOperator = (JoinOperator) joinStream.getOperator();
|
||||
joinOperator.setFunction(joinFunction);
|
||||
return (DataStream<O>) joinStream;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+35
-10
@@ -4,7 +4,8 @@ import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.partition.Partition;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.api.partition.impl.ForwardPartition;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
@@ -30,8 +31,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
private Stream originalStream;
|
||||
|
||||
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
|
||||
this(streamingContext, null, streamOperator,
|
||||
selectPartition(streamOperator));
|
||||
this(streamingContext, null, streamOperator, getForwardPartition(streamOperator));
|
||||
}
|
||||
|
||||
public Stream(StreamingContext streamingContext,
|
||||
@@ -42,7 +42,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
|
||||
public Stream(Stream inputStream, StreamOperator streamOperator) {
|
||||
this(inputStream.getStreamingContext(), inputStream, streamOperator,
|
||||
selectPartition(streamOperator));
|
||||
getForwardPartition(streamOperator));
|
||||
}
|
||||
|
||||
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
|
||||
@@ -50,9 +50,9 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
}
|
||||
|
||||
protected Stream(StreamingContext streamingContext,
|
||||
Stream inputStream,
|
||||
StreamOperator streamOperator,
|
||||
Partition<T> partition) {
|
||||
Stream inputStream,
|
||||
StreamOperator streamOperator,
|
||||
Partition<T> partition) {
|
||||
this.streamingContext = streamingContext;
|
||||
this.inputStream = inputStream;
|
||||
this.operator = streamOperator;
|
||||
@@ -73,15 +73,16 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
this.streamingContext = originalStream.getStreamingContext();
|
||||
this.inputStream = originalStream.getInputStream();
|
||||
this.operator = originalStream.getOperator();
|
||||
Preconditions.checkNotNull(operator);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static <T> Partition<T> selectPartition(Operator operator) {
|
||||
private static <T> Partition<T> getForwardPartition(Operator operator) {
|
||||
switch (operator.getLanguage()) {
|
||||
case PYTHON:
|
||||
return (Partition<T>) PythonPartition.RoundRobinPartition;
|
||||
return (Partition<T>) PythonPartition.ForwardPartition;
|
||||
case JAVA:
|
||||
return new RoundRobinPartition<>();
|
||||
return new ForwardPartition<>();
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Unsupported language " + operator.getLanguage());
|
||||
@@ -165,5 +166,29 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
return originalStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set chain strategy for this stream
|
||||
*/
|
||||
public S withChainStrategy(ChainStrategy chainStrategy) {
|
||||
Preconditions.checkArgument(!isProxyStream());
|
||||
operator.setChainStrategy(chainStrategy);
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable chain for this stream
|
||||
*/
|
||||
public S disableChain() {
|
||||
return withChainStrategy(ChainStrategy.NEVER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the partition function of this {@link Stream} so that output elements are forwarded to
|
||||
* next operator locally.
|
||||
*/
|
||||
public S forward() {
|
||||
return setPartition(getForwardPartition(operator));
|
||||
}
|
||||
|
||||
public abstract Language getLanguage();
|
||||
}
|
||||
|
||||
+2
@@ -16,6 +16,8 @@ public class UnionStream<T> extends DataStream<T> {
|
||||
private List<DataStream<T>> unionStreams;
|
||||
|
||||
public UnionStream(DataStream<T> input, List<DataStream<T>> streams) {
|
||||
// Union stream does not create a physical operation, so we don't have to set partition
|
||||
// function for it.
|
||||
super(input, new UnionOperator());
|
||||
this.unionStreams = new ArrayList<>();
|
||||
streams.forEach(this::addStream);
|
||||
|
||||
+57
-18
@@ -5,6 +5,8 @@ import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -17,15 +19,24 @@ public class JobGraph implements Serializable {
|
||||
|
||||
private final String jobName;
|
||||
private final Map<String, String> jobConfig;
|
||||
private List<JobVertex> jobVertexList;
|
||||
private List<JobEdge> jobEdgeList;
|
||||
private List<JobVertex> jobVertices;
|
||||
private List<JobEdge> jobEdges;
|
||||
private String digraph;
|
||||
|
||||
public JobGraph(String jobName, Map<String, String> jobConfig) {
|
||||
this.jobName = jobName;
|
||||
this.jobConfig = jobConfig;
|
||||
this.jobVertexList = new ArrayList<>();
|
||||
this.jobEdgeList = new ArrayList<>();
|
||||
this.jobVertices = new ArrayList<>();
|
||||
this.jobEdges = new ArrayList<>();
|
||||
}
|
||||
|
||||
public JobGraph(String jobName, Map<String, String> jobConfig,
|
||||
List<JobVertex> jobVertices, List<JobEdge> jobEdges) {
|
||||
this.jobName = jobName;
|
||||
this.jobConfig = jobConfig;
|
||||
this.jobVertices = jobVertices;
|
||||
this.jobEdges = jobEdges;
|
||||
generateDigraph();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,12 +47,12 @@ public class JobGraph implements Serializable {
|
||||
*/
|
||||
public String generateDigraph() {
|
||||
StringBuilder digraph = new StringBuilder();
|
||||
digraph.append("digraph ").append(jobName + " ").append(" {");
|
||||
digraph.append("digraph ").append(jobName).append(" ").append(" {");
|
||||
|
||||
for (JobEdge jobEdge : jobEdgeList) {
|
||||
for (JobEdge jobEdge : jobEdges) {
|
||||
String srcNode = null;
|
||||
String targetNode = null;
|
||||
for (JobVertex jobVertex : jobVertexList) {
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
if (jobEdge.getSrcVertexId() == jobVertex.getVertexId()) {
|
||||
srcNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName();
|
||||
} else if (jobEdge.getTargetVertexId() == jobVertex.getVertexId()) {
|
||||
@@ -49,7 +60,7 @@ public class JobGraph implements Serializable {
|
||||
}
|
||||
}
|
||||
digraph.append(System.getProperty("line.separator"));
|
||||
digraph.append(srcNode).append(" -> ").append(targetNode);
|
||||
digraph.append(String.format(" \"%s\" -> \"%s\"", srcNode, targetNode));
|
||||
}
|
||||
digraph.append(System.getProperty("line.separator")).append("}");
|
||||
|
||||
@@ -58,19 +69,47 @@ public class JobGraph implements Serializable {
|
||||
}
|
||||
|
||||
public void addVertex(JobVertex vertex) {
|
||||
this.jobVertexList.add(vertex);
|
||||
this.jobVertices.add(vertex);
|
||||
}
|
||||
|
||||
public void addEdge(JobEdge jobEdge) {
|
||||
this.jobEdgeList.add(jobEdge);
|
||||
this.jobEdges.add(jobEdge);
|
||||
}
|
||||
|
||||
public List<JobVertex> getJobVertexList() {
|
||||
return jobVertexList;
|
||||
public List<JobVertex> getJobVertices() {
|
||||
return jobVertices;
|
||||
}
|
||||
|
||||
public List<JobEdge> getJobEdgeList() {
|
||||
return jobEdgeList;
|
||||
public List<JobVertex> getSourceVertices() {
|
||||
return jobVertices.stream()
|
||||
.filter(v -> v.getVertexType() == VertexType.SOURCE)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<JobVertex> getSinkVertices() {
|
||||
return jobVertices.stream()
|
||||
.filter(v -> v.getVertexType() == VertexType.SINK)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public JobVertex getVertex(int vertexId) {
|
||||
return jobVertices.stream().filter(v -> v.getVertexId() == vertexId).findFirst().get();
|
||||
}
|
||||
|
||||
public List<JobEdge> getJobEdges() {
|
||||
return jobEdges;
|
||||
}
|
||||
|
||||
public Set<JobEdge> getVertexInputEdges(int vertexId) {
|
||||
return jobEdges.stream()
|
||||
.filter(jobEdge -> jobEdge.getTargetVertexId() == vertexId)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public Set<JobEdge> getVertexOutputEdges(int vertexId) {
|
||||
return jobEdges.stream()
|
||||
.filter(jobEdge -> jobEdge.getSrcVertexId() == vertexId)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public String getDigraph() {
|
||||
@@ -90,17 +129,17 @@ public class JobGraph implements Serializable {
|
||||
return;
|
||||
}
|
||||
LOG.info("Printing job graph:");
|
||||
for (JobVertex jobVertex : jobVertexList) {
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
LOG.info(jobVertex.toString());
|
||||
}
|
||||
for (JobEdge jobEdge : jobEdgeList) {
|
||||
for (JobEdge jobEdge : jobEdges) {
|
||||
LOG.info(jobEdge.toString());
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isCrossLanguageGraph() {
|
||||
Language language = jobVertexList.get(0).getLanguage();
|
||||
for (JobVertex jobVertex : jobVertexList) {
|
||||
Language language = jobVertices.get(0).getLanguage();
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
if (jobVertex.getLanguage() != language) {
|
||||
return true;
|
||||
}
|
||||
|
||||
+15
-5
@@ -2,6 +2,7 @@ package io.ray.streaming.jobgraph;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.JoinStream;
|
||||
import io.ray.streaming.api.stream.Stream;
|
||||
import io.ray.streaming.api.stream.StreamSink;
|
||||
import io.ray.streaming.api.stream.StreamSource;
|
||||
@@ -26,7 +27,7 @@ public class JobGraphBuilder {
|
||||
private List<StreamSink> streamSinkList;
|
||||
|
||||
public JobGraphBuilder(List<StreamSink> streamSinkList) {
|
||||
this(streamSinkList, "job-" + System.currentTimeMillis());
|
||||
this(streamSinkList, "job_" + System.currentTimeMillis());
|
||||
}
|
||||
|
||||
public JobGraphBuilder(List<StreamSink> streamSinkList, String jobName) {
|
||||
@@ -61,18 +62,20 @@ public class JobGraphBuilder {
|
||||
"Reference stream should be skipped.");
|
||||
int vertexId = stream.getId();
|
||||
int parallelism = stream.getParallelism();
|
||||
Map<String, String> config = stream.getConfig();
|
||||
JobVertex jobVertex;
|
||||
if (stream instanceof StreamSink) {
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator, config);
|
||||
Stream parentStream = stream.getInputStream();
|
||||
int inputVertexId = parentStream.getId();
|
||||
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
|
||||
this.jobGraph.addEdge(jobEdge);
|
||||
processStream(parentStream);
|
||||
} else if (stream instanceof StreamSource) {
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator);
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator, config);
|
||||
} else if (stream instanceof DataStream || stream instanceof PythonDataStream) {
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator);
|
||||
jobVertex = new JobVertex(
|
||||
vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator, config);
|
||||
Stream parentStream = stream.getInputStream();
|
||||
int inputVertexId = parentStream.getId();
|
||||
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
|
||||
@@ -92,10 +95,17 @@ public class JobGraphBuilder {
|
||||
this.jobGraph.addEdge(otherEdge);
|
||||
processStream(otherStream);
|
||||
}
|
||||
|
||||
// process join stream
|
||||
if (stream instanceof JoinStream) {
|
||||
DataStream rightStream = ((JoinStream) stream).getRightStream();
|
||||
this.jobGraph.addEdge(
|
||||
new JobEdge(rightStream.getId(), vertexId, rightStream.getPartition()));
|
||||
processStream(rightStream);
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported stream: " + stream);
|
||||
}
|
||||
jobVertex.setConfig(stream.getConfig());
|
||||
this.jobGraph.addVertex(jobVertex);
|
||||
}
|
||||
|
||||
|
||||
+187
@@ -0,0 +1,187 @@
|
||||
package io.ray.streaming.jobgraph;
|
||||
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.partition.Partition;
|
||||
import io.ray.streaming.api.partition.impl.ForwardPartition;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.chain.ChainedOperator;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
import io.ray.streaming.python.PythonOperator.ChainedPythonOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/**
|
||||
* Optimize job graph by chaining some operators so that some operators can be run in the
|
||||
* same thread.
|
||||
*/
|
||||
public class JobGraphOptimizer {
|
||||
private final JobGraph jobGraph;
|
||||
private Set<JobVertex> visited = new HashSet<>();
|
||||
// vertex id -> vertex
|
||||
private Map<Integer, JobVertex> vertexMap;
|
||||
private Map<JobVertex, Set<JobEdge>> outputEdgesMap;
|
||||
// tail vertex id -> mergedVertex
|
||||
private Map<Integer, Pair<JobVertex, List<JobVertex>>> mergedVertexMap;
|
||||
|
||||
public JobGraphOptimizer(JobGraph jobGraph) {
|
||||
this.jobGraph = jobGraph;
|
||||
vertexMap = jobGraph.getJobVertices().stream()
|
||||
.collect(Collectors.toMap(JobVertex::getVertexId, Function.identity()));
|
||||
outputEdgesMap = vertexMap.keySet().stream().collect(Collectors.toMap(
|
||||
id -> vertexMap.get(id), id -> new HashSet<>(jobGraph.getVertexOutputEdges(id))));
|
||||
mergedVertexMap = new HashMap<>();
|
||||
}
|
||||
|
||||
public JobGraph optimize() {
|
||||
// Deep-first traverse nodes from source to sink to merge vertices that can be chained
|
||||
// together.
|
||||
jobGraph.getSourceVertices().forEach(vertex -> {
|
||||
List<JobVertex> verticesToMerge = new ArrayList<>();
|
||||
verticesToMerge.add(vertex);
|
||||
mergeVerticesRecursively(vertex, verticesToMerge);
|
||||
});
|
||||
|
||||
List<JobVertex> vertices = mergedVertexMap.values().stream()
|
||||
.map(Pair::getLeft).collect(Collectors.toList());
|
||||
|
||||
return new JobGraph(jobGraph.getJobName(), jobGraph.getJobConfig(), vertices, createEdges());
|
||||
}
|
||||
|
||||
private void mergeVerticesRecursively(JobVertex vertex, List<JobVertex> verticesToMerge) {
|
||||
if (!visited.contains(vertex)) {
|
||||
visited.add(vertex);
|
||||
Set<JobEdge> outputEdges = outputEdgesMap.get(vertex);
|
||||
if (outputEdges.isEmpty()) {
|
||||
mergeAndAddVertex(verticesToMerge);
|
||||
} else {
|
||||
outputEdges.forEach(edge -> {
|
||||
JobVertex succeedingVertex = vertexMap.get(edge.getTargetVertexId());
|
||||
if (canBeChained(vertex, succeedingVertex, edge)) {
|
||||
verticesToMerge.add(succeedingVertex);
|
||||
mergeVerticesRecursively(succeedingVertex, verticesToMerge);
|
||||
} else {
|
||||
mergeAndAddVertex(verticesToMerge);
|
||||
List<JobVertex> newMergedVertices = new ArrayList<>();
|
||||
newMergedVertices.add(succeedingVertex);
|
||||
mergeVerticesRecursively(succeedingVertex, newMergedVertices);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void mergeAndAddVertex(List<JobVertex> verticesToMerge) {
|
||||
JobVertex mergedVertex;
|
||||
JobVertex headVertex = verticesToMerge.get(0);
|
||||
Language language = headVertex.getLanguage();
|
||||
if (verticesToMerge.size() == 1) {
|
||||
// no chain
|
||||
mergedVertex = headVertex;
|
||||
} else {
|
||||
List<StreamOperator> operators = verticesToMerge.stream()
|
||||
.map(v -> vertexMap.get(v.getVertexId()).getStreamOperator())
|
||||
.collect(Collectors.toList());
|
||||
List<Map<String, String>> configs = verticesToMerge.stream()
|
||||
.map(v -> vertexMap.get(v.getVertexId()).getConfig())
|
||||
.collect(Collectors.toList());
|
||||
StreamOperator operator;
|
||||
if (language == Language.JAVA) {
|
||||
operator = ChainedOperator.newChainedOperator(operators, configs);
|
||||
} else {
|
||||
List<PythonOperator> pythonOperators = operators.stream()
|
||||
.map(o -> (PythonOperator) o).collect(Collectors.toList());
|
||||
operator = new ChainedPythonOperator(pythonOperators, configs);
|
||||
}
|
||||
// chained operator config is placed into `ChainedOperator`.
|
||||
mergedVertex = new JobVertex(headVertex.getVertexId(), headVertex.getParallelism(),
|
||||
headVertex.getVertexType(), operator, new HashMap<>());
|
||||
}
|
||||
|
||||
mergedVertexMap.put(mergedVertex.getVertexId(), Pair.of(mergedVertex, verticesToMerge));
|
||||
}
|
||||
|
||||
private List<JobEdge> createEdges() {
|
||||
List<JobEdge> edges = new ArrayList<>();
|
||||
mergedVertexMap.forEach((id, pair) -> {
|
||||
JobVertex mergedVertex = pair.getLeft();
|
||||
List<JobVertex> mergedVertices = pair.getRight();
|
||||
JobVertex tailVertex = mergedVertices.get(mergedVertices.size() - 1);
|
||||
// input edge will be set up in input vertices
|
||||
if (outputEdgesMap.containsKey(tailVertex)) {
|
||||
outputEdgesMap.get(tailVertex).forEach(edge -> {
|
||||
Pair<JobVertex, List<JobVertex>> downstreamPair =
|
||||
mergedVertexMap.get(edge.getTargetVertexId());
|
||||
// change ForwardPartition to RoundRobinPartition.
|
||||
Partition partition = changePartition(edge.getPartition());
|
||||
JobEdge newEdge = new JobEdge(
|
||||
mergedVertex.getVertexId(),
|
||||
downstreamPair.getLeft().getVertexId(),
|
||||
partition);
|
||||
edges.add(newEdge);
|
||||
});
|
||||
|
||||
}
|
||||
});
|
||||
return edges;
|
||||
}
|
||||
|
||||
/**
|
||||
* Change ForwardPartition to RoundRobinPartition.
|
||||
*/
|
||||
private Partition changePartition(Partition partition) {
|
||||
if (partition instanceof PythonPartition) {
|
||||
PythonPartition pythonPartition = (PythonPartition) partition;
|
||||
if (!pythonPartition.isConstructedFromBinary() &&
|
||||
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS)) {
|
||||
return PythonPartition.RoundRobinPartition;
|
||||
} else {
|
||||
return partition;
|
||||
}
|
||||
} else {
|
||||
if (partition instanceof ForwardPartition) {
|
||||
return new RoundRobinPartition();
|
||||
} else {
|
||||
return partition;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean canBeChained(JobVertex precedingVertex,
|
||||
JobVertex succeedingVertex,
|
||||
JobEdge edge) {
|
||||
if (jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 ||
|
||||
jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) {
|
||||
return false;
|
||||
}
|
||||
if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) {
|
||||
return false;
|
||||
}
|
||||
if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|
||||
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|
||||
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) {
|
||||
return false;
|
||||
}
|
||||
if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) {
|
||||
return false;
|
||||
}
|
||||
Partition partition = edge.getPartition();
|
||||
if (!(partition instanceof PythonPartition)) {
|
||||
return partition instanceof ForwardPartition;
|
||||
} else {
|
||||
PythonPartition pythonPartition = (PythonPartition) partition;
|
||||
return !pythonPartition.isConstructedFromBinary() &&
|
||||
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import java.util.Map;
|
||||
* Job vertex is a cell node where logic is executed.
|
||||
*/
|
||||
public class JobVertex implements Serializable {
|
||||
|
||||
private int vertexId;
|
||||
private int parallelism;
|
||||
private VertexType vertexType;
|
||||
@@ -20,12 +21,14 @@ public class JobVertex implements Serializable {
|
||||
public JobVertex(int vertexId,
|
||||
int parallelism,
|
||||
VertexType vertexType,
|
||||
StreamOperator streamOperator) {
|
||||
StreamOperator streamOperator,
|
||||
Map<String, String> config) {
|
||||
this.vertexId = vertexId;
|
||||
this.parallelism = parallelism;
|
||||
this.vertexType = vertexType;
|
||||
this.streamOperator = streamOperator;
|
||||
this.language = streamOperator.getLanguage();
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
public int getVertexId() {
|
||||
@@ -67,4 +70,5 @@ public class JobVertex implements Serializable {
|
||||
.add("config", config)
|
||||
.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+20
@@ -0,0 +1,20 @@
|
||||
package io.ray.streaming.operator;
|
||||
|
||||
/**
|
||||
* Chain strategy for streaming operators. Chained operators are run in the same thread.
|
||||
*/
|
||||
public enum ChainStrategy {
|
||||
/**
|
||||
* The operator won't be chained with preceding operators, but maybe chained with succeeding
|
||||
* operators.
|
||||
*/
|
||||
HEAD,
|
||||
/**
|
||||
* Operators will be chained together when possible.
|
||||
*/
|
||||
ALWAYS,
|
||||
/**
|
||||
* The operator won't be chained with any operator.
|
||||
*/
|
||||
NEVER
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import java.util.List;
|
||||
|
||||
public interface Operator extends Serializable {
|
||||
|
||||
String getName();
|
||||
|
||||
void open(List<Collector> collectors, RuntimeContext runtimeContext);
|
||||
|
||||
void finish();
|
||||
@@ -20,4 +22,7 @@ public interface Operator extends Serializable {
|
||||
Language getLanguage();
|
||||
|
||||
OperatorType getOpType();
|
||||
|
||||
ChainStrategy getChainStrategy();
|
||||
|
||||
}
|
||||
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
package io.ray.streaming.operator;
|
||||
|
||||
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
|
||||
|
||||
public interface SourceOperator<T> extends Operator {
|
||||
|
||||
void run();
|
||||
|
||||
SourceContext<T> getSourceContext();
|
||||
|
||||
default OperatorType getOpType() {
|
||||
return OperatorType.SOURCE;
|
||||
}
|
||||
}
|
||||
+22
-3
@@ -12,13 +12,22 @@ import java.util.List;
|
||||
|
||||
public abstract class StreamOperator<F extends Function> implements Operator {
|
||||
protected final String name;
|
||||
protected final F function;
|
||||
protected final RichFunction richFunction;
|
||||
protected F function;
|
||||
protected RichFunction richFunction;
|
||||
protected List<Collector> collectorList;
|
||||
protected RuntimeContext runtimeContext;
|
||||
private ChainStrategy chainStrategy = ChainStrategy.ALWAYS;
|
||||
|
||||
public StreamOperator(F function) {
|
||||
protected StreamOperator() {
|
||||
this.name = getClass().getSimpleName();
|
||||
}
|
||||
|
||||
protected StreamOperator(F function) {
|
||||
this();
|
||||
setFunction(function);
|
||||
}
|
||||
|
||||
public void setFunction(F function) {
|
||||
this.function = function;
|
||||
this.richFunction = Functions.wrap(function);
|
||||
}
|
||||
@@ -62,7 +71,17 @@ public abstract class StreamOperator<F extends Function> implements Operator {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setChainStrategy(ChainStrategy chainStrategy) {
|
||||
this.chainStrategy = chainStrategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChainStrategy getChainStrategy() {
|
||||
return chainStrategy;
|
||||
}
|
||||
}
|
||||
|
||||
+169
@@ -0,0 +1,169 @@
|
||||
package io.ray.streaming.operator.chain;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.collector.Collector;
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.Function;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.TwoInputOperator;
|
||||
import java.lang.reflect.Proxy;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Abstract base class for chained operators.
|
||||
*/
|
||||
public abstract class ChainedOperator extends StreamOperator<Function> {
|
||||
protected final List<StreamOperator> operators;
|
||||
protected final Operator headOperator;
|
||||
protected final Operator tailOperator;
|
||||
private final List<Map<String, String>> configs;
|
||||
|
||||
public ChainedOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
Preconditions.checkArgument(operators.size() >= 2,
|
||||
"Need at lease two operators to be chained together");
|
||||
operators.stream().skip(1)
|
||||
.forEach(operator -> Preconditions.checkArgument(operator instanceof OneInputOperator));
|
||||
this.operators = operators;
|
||||
this.configs = configs;
|
||||
this.headOperator = operators.get(0);
|
||||
this.tailOperator = operators.get(operators.size() - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
|
||||
// Dont' call super.open() as we `open` every operator separately.
|
||||
List<ForwardCollector> succeedingCollectors = operators.stream().skip(1)
|
||||
.map(operator -> new ForwardCollector((OneInputOperator) operator))
|
||||
.collect(Collectors.toList());
|
||||
for (int i = 0; i < operators.size() - 1; i++) {
|
||||
StreamOperator operator = operators.get(i);
|
||||
List<ForwardCollector> forwardCollectors =
|
||||
Collections.singletonList(succeedingCollectors.get(i));
|
||||
operator.open(forwardCollectors, createRuntimeContext(runtimeContext, i));
|
||||
}
|
||||
// tail operator send data to downstream using provided collectors.
|
||||
tailOperator.open(collectorList, createRuntimeContext(runtimeContext, operators.size() - 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return headOperator.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return headOperator.getLanguage();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return operators.stream().map(Operator::getName)
|
||||
.collect(Collectors.joining(" -> ", "[", "]"));
|
||||
}
|
||||
|
||||
public List<StreamOperator> getOperators() {
|
||||
return operators;
|
||||
}
|
||||
|
||||
public Operator getHeadOperator() {
|
||||
return headOperator;
|
||||
}
|
||||
|
||||
public Operator getTailOperator() {
|
||||
return tailOperator;
|
||||
}
|
||||
|
||||
private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) {
|
||||
return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(),
|
||||
new Class[] {RuntimeContext.class},
|
||||
(proxy, method, methodArgs) -> {
|
||||
if (method.getName().equals("getConfig")) {
|
||||
return configs.get(index);
|
||||
} else {
|
||||
return method.invoke(runtimeContext, methodArgs);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public static ChainedOperator newChainedOperator(
|
||||
List<StreamOperator> operators,
|
||||
List<Map<String, String>> configs) {
|
||||
switch (operators.get(0).getOpType()) {
|
||||
case SOURCE:
|
||||
return new ChainedSourceOperator(operators, configs);
|
||||
case ONE_INPUT:
|
||||
return new ChainedOneInputOperator(operators, configs);
|
||||
case TWO_INPUT:
|
||||
return new ChainedTwoInputOperator(operators, configs);
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
"Unsupported operator type " + operators.get(0).getOpType());
|
||||
}
|
||||
}
|
||||
|
||||
static class ChainedSourceOperator<T> extends ChainedOperator
|
||||
implements SourceOperator<T> {
|
||||
private final SourceOperator<T> sourceOperator;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ChainedSourceOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
super(operators, configs);
|
||||
sourceOperator = (SourceOperator<T>) headOperator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
sourceOperator.run();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SourceContext<T> getSourceContext() {
|
||||
return sourceOperator.getSourceContext();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static class ChainedOneInputOperator<T> extends ChainedOperator
|
||||
implements OneInputOperator<T> {
|
||||
private final OneInputOperator<T> inputOperator;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ChainedOneInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
super(operators, configs);
|
||||
inputOperator = (OneInputOperator<T>) headOperator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processElement(Record<T> record) throws Exception {
|
||||
inputOperator.processElement(record);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static class ChainedTwoInputOperator<L, R> extends ChainedOperator
|
||||
implements TwoInputOperator<L, R> {
|
||||
private final TwoInputOperator<L, R> inputOperator;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ChainedTwoInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
super(operators, configs);
|
||||
inputOperator = (TwoInputOperator<L, R>) headOperator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processElement(Record<L> record1, Record<R> record2) {
|
||||
inputOperator.processElement(record1, record2);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
package io.ray.streaming.operator.chain;
|
||||
|
||||
import io.ray.streaming.api.collector.Collector;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
|
||||
class ForwardCollector implements Collector<Record> {
|
||||
private final OneInputOperator succeedingOperator;
|
||||
|
||||
ForwardCollector(OneInputOperator succeedingOperator) {
|
||||
this.succeedingOperator = succeedingOperator;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public void collect(Record record) {
|
||||
try {
|
||||
succeedingOperator.processElement(record);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
package io.ray.streaming.operator.impl;
|
||||
|
||||
import io.ray.streaming.api.function.impl.JoinFunction;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.TwoInputOperator;
|
||||
|
||||
/**
|
||||
* Join operator
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <K> Type of the data in the join key.
|
||||
* @param <O> Type of the data in the joined stream.
|
||||
*/
|
||||
public class JoinOperator<L, R, K, O> extends StreamOperator<JoinFunction<L, R, O>> implements
|
||||
TwoInputOperator<L, R> {
|
||||
public JoinOperator() {
|
||||
|
||||
}
|
||||
|
||||
public JoinOperator(JoinFunction<L, R, O> function) {
|
||||
super(function);
|
||||
setChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processElement(Record<L> record1, Record<R> record2) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return OperatorType.TWO_INPUT;
|
||||
}
|
||||
|
||||
}
|
||||
+3
@@ -5,6 +5,7 @@ import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.impl.ReduceFunction;
|
||||
import io.ray.streaming.message.KeyRecord;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.util.HashMap;
|
||||
@@ -18,6 +19,7 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
|
||||
|
||||
public ReduceOperator(ReduceFunction<T> reduceFunction) {
|
||||
super(reduceFunction);
|
||||
setChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -41,4 +43,5 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
|
||||
collect(record);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+14
-4
@@ -5,15 +5,19 @@ import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.util.List;
|
||||
|
||||
public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
public class SourceOperatorImpl<T> extends StreamOperator<SourceFunction<T>>
|
||||
implements SourceOperator {
|
||||
private SourceContextImpl sourceContext;
|
||||
|
||||
public SourceOperator(SourceFunction<T> function) {
|
||||
public SourceOperatorImpl(SourceFunction<T> function) {
|
||||
super(function);
|
||||
setChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -23,6 +27,7 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
this.function.run(this.sourceContext);
|
||||
@@ -31,13 +36,17 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SourceContext getSourceContext() {
|
||||
return sourceContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return OperatorType.SOURCE;
|
||||
}
|
||||
|
||||
class SourceContextImpl implements SourceContext<T> {
|
||||
|
||||
private List<Collector> collectors;
|
||||
|
||||
public SourceContextImpl(List<Collector> collectors) {
|
||||
@@ -47,9 +56,10 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
@Override
|
||||
public void collect(T t) throws Exception {
|
||||
for (Collector collector : collectors) {
|
||||
collector.collect(new Record(t));
|
||||
collector.collect(new Record<>(t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -100,6 +100,14 @@ public class PythonFunction implements Function {
|
||||
return functionInterface;
|
||||
}
|
||||
|
||||
public String toSimpleString() {
|
||||
if (function != null) {
|
||||
return "binary function";
|
||||
} else {
|
||||
return String.format("%s-%s.%s", functionInterface, moduleName, functionName);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringJoiner stringJoiner = new StringJoiner(", ",
|
||||
|
||||
+114
-24
@@ -1,11 +1,16 @@
|
||||
package io.ray.streaming.python;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.Function;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.StringJoiner;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Represents a {@link StreamOperator} that wraps python {@link PythonFunction}.
|
||||
@@ -27,30 +32,6 @@ public class PythonOperator extends StreamOperator {
|
||||
this.className = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List list, RuntimeContext runtimeContext) {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return Language.PYTHON;
|
||||
@@ -64,6 +45,48 @@ public class PythonOperator extends StreamOperator {
|
||||
return className;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List list, RuntimeContext runtimeContext) {
|
||||
throwUnsupportedException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
throwUnsupportedException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
throwUnsupportedException();
|
||||
}
|
||||
|
||||
void throwUnsupportedException() {
|
||||
StackTraceElement[] trace = Thread.currentThread().getStackTrace();
|
||||
Preconditions.checkState(trace.length >= 2);
|
||||
StackTraceElement traceElement = trace[2];
|
||||
String msg = String.format("Method %s.%s shouldn't be called.",
|
||||
traceElement.getClassName(), traceElement.getMethodName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append(PythonOperator.class.getSimpleName()).append("[");
|
||||
if (function != null) {
|
||||
builder.append(((PythonFunction)function).toSimpleString());
|
||||
} else {
|
||||
builder.append(moduleName).append(".").append(className);
|
||||
}
|
||||
return builder.append("]").toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringJoiner stringJoiner = new StringJoiner(", ",
|
||||
@@ -76,4 +99,71 @@ public class PythonOperator extends StreamOperator {
|
||||
}
|
||||
return stringJoiner.toString();
|
||||
}
|
||||
|
||||
public static class ChainedPythonOperator extends PythonOperator {
|
||||
private final List<PythonOperator> operators;
|
||||
private final PythonOperator headOperator;
|
||||
private final PythonOperator tailOperator;
|
||||
private final List<Map<String, String>> configs;
|
||||
|
||||
public ChainedPythonOperator(
|
||||
List<PythonOperator> operators, List<Map<String, String>> configs) {
|
||||
super(null);
|
||||
Preconditions.checkArgument(!operators.isEmpty());
|
||||
this.operators = operators;
|
||||
this.configs = configs;
|
||||
this.headOperator = operators.get(0);
|
||||
this.tailOperator = operators.get(operators.size() - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return headOperator.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return Language.PYTHON;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return operators.stream().map(Operator::getName)
|
||||
.collect(Collectors.joining(" -> ", "[", "]"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
throwUnsupportedException();
|
||||
return null; // impossible
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
throwUnsupportedException();
|
||||
return null; // impossible
|
||||
}
|
||||
|
||||
@Override
|
||||
public Function getFunction() {
|
||||
throwUnsupportedException();
|
||||
return null; // impossible
|
||||
}
|
||||
|
||||
public List<PythonOperator> getOperators() {
|
||||
return operators;
|
||||
}
|
||||
|
||||
public PythonOperator getHeadOperator() {
|
||||
return headOperator;
|
||||
}
|
||||
|
||||
public PythonOperator getTailOperator() {
|
||||
return tailOperator;
|
||||
}
|
||||
|
||||
public List<Map<String, String>> getConfigs() {
|
||||
return configs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+7
@@ -24,6 +24,9 @@ public class PythonPartition implements Partition<Object> {
|
||||
"ray.streaming.partition", "KeyPartition");
|
||||
public static final PythonPartition RoundRobinPartition = new PythonPartition(
|
||||
"ray.streaming.partition", "RoundRobinPartition");
|
||||
public static final String FORWARD_PARTITION_CLASS = "ForwardPartition";
|
||||
public static final PythonPartition ForwardPartition = new PythonPartition(
|
||||
"ray.streaming.partition", FORWARD_PARTITION_CLASS);
|
||||
|
||||
private byte[] partition;
|
||||
private String moduleName;
|
||||
@@ -66,6 +69,10 @@ public class PythonPartition implements Partition<Object> {
|
||||
return functionName;
|
||||
}
|
||||
|
||||
public boolean isConstructedFromBinary() {
|
||||
return partition != null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringJoiner stringJoiner = new StringJoiner(", ",
|
||||
|
||||
+4
-1
@@ -2,6 +2,7 @@ package io.ray.streaming.python.stream;
|
||||
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.KeyDataStream;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import io.ray.streaming.python.PythonFunction.FunctionInterface;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
@@ -37,7 +38,9 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
|
||||
*/
|
||||
public PythonDataStream reduce(PythonFunction func) {
|
||||
func.setFunctionInterface(FunctionInterface.REDUCE_FUNCTION);
|
||||
return new PythonDataStream(this, new PythonOperator(func));
|
||||
PythonDataStream stream = new PythonDataStream(this, new PythonOperator(func));
|
||||
stream.withChainStrategy(ChainStrategy.HEAD);
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+3
-3
@@ -2,10 +2,10 @@ package io.ray.streaming.python.stream;
|
||||
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.stream.StreamSource;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import io.ray.streaming.python.PythonFunction.FunctionInterface;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
|
||||
/**
|
||||
* Represents a source of the PythonStream.
|
||||
@@ -13,8 +13,8 @@ import io.ray.streaming.python.PythonPartition;
|
||||
public class PythonStreamSource extends PythonDataStream implements StreamSource {
|
||||
|
||||
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
|
||||
super(streamingContext, new PythonOperator(sourceFunction),
|
||||
PythonPartition.RoundRobinPartition);
|
||||
super(streamingContext, new PythonOperator(sourceFunction));
|
||||
withChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
public static PythonStreamSource from(StreamingContext streamingContext,
|
||||
|
||||
+2
@@ -14,6 +14,8 @@ public class PythonUnionStream extends PythonDataStream {
|
||||
private List<PythonDataStream> unionStreams;
|
||||
|
||||
public PythonUnionStream(PythonDataStream input, List<PythonDataStream> others) {
|
||||
// Union stream does not create a physical operation, so we don't have to set partition
|
||||
// function for it.
|
||||
super(input, new PythonOperator(
|
||||
"ray.streaming.operator", "UnionOperator"));
|
||||
this.unionStreams = new ArrayList<>();
|
||||
|
||||
+10
-10
@@ -2,8 +2,8 @@ package io.ray.streaming.jobgraph;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.partition.impl.ForwardPartition;
|
||||
import io.ray.streaming.api.partition.impl.KeyPartition;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.DataStreamSource;
|
||||
import io.ray.streaming.api.stream.StreamSink;
|
||||
@@ -20,14 +20,14 @@ public class JobGraphBuilderTest {
|
||||
@Test
|
||||
public void testDataSync() {
|
||||
JobGraph jobGraph = buildDataSyncJobGraph();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
|
||||
|
||||
Assert.assertEquals(jobVertexList.size(), 2);
|
||||
Assert.assertEquals(jobEdgeList.size(), 1);
|
||||
|
||||
JobEdge jobEdge = jobEdgeList.get(0);
|
||||
Assert.assertEquals(jobEdge.getPartition().getClass(), RoundRobinPartition.class);
|
||||
Assert.assertEquals(jobEdge.getPartition().getClass(), ForwardPartition.class);
|
||||
|
||||
JobVertex sinkVertex = jobVertexList.get(1);
|
||||
JobVertex sourceVertex = jobVertexList.get(0);
|
||||
@@ -50,8 +50,8 @@ public class JobGraphBuilderTest {
|
||||
@Test
|
||||
public void testKeyByJobGraph() {
|
||||
JobGraph jobGraph = buildKeyByJobGraph();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
|
||||
|
||||
Assert.assertEquals(jobVertexList.size(), 3);
|
||||
Assert.assertEquals(jobEdgeList.size(), 2);
|
||||
@@ -68,7 +68,7 @@ public class JobGraphBuilderTest {
|
||||
JobEdge source2KeyBy = jobEdgeList.get(1);
|
||||
|
||||
Assert.assertEquals(keyBy2Sink.getPartition().getClass(), KeyPartition.class);
|
||||
Assert.assertEquals(source2KeyBy.getPartition().getClass(), RoundRobinPartition.class);
|
||||
Assert.assertEquals(source2KeyBy.getPartition().getClass(), ForwardPartition.class);
|
||||
}
|
||||
|
||||
public JobGraph buildKeyByJobGraph() {
|
||||
@@ -88,8 +88,8 @@ public class JobGraphBuilderTest {
|
||||
JobGraph jobGraph = buildKeyByJobGraph();
|
||||
jobGraph.generateDigraph();
|
||||
String diGraph = jobGraph.getDigraph();
|
||||
System.out.println(diGraph);
|
||||
Assert.assertTrue(diGraph.contains("1-SourceOperator -> 2-KeyByOperator"));
|
||||
Assert.assertTrue(diGraph.contains("2-KeyByOperator -> 3-SinkOperator"));
|
||||
LOG.info(diGraph);
|
||||
Assert.assertTrue(diGraph.contains("\"1-SourceOperatorImpl\" -> \"2-KeyByOperator\""));
|
||||
Assert.assertTrue(diGraph.contains("\"2-KeyByOperator\" -> \"3-SinkOperator\""));
|
||||
}
|
||||
}
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
package io.ray.streaming.jobgraph;
|
||||
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.DataStreamSource;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class JobGraphOptimizerTest {
|
||||
private static final Logger LOG = LoggerFactory.getLogger( JobGraphOptimizerTest.class );
|
||||
|
||||
@Test
|
||||
public void testOptimize() {
|
||||
StreamingContext context = StreamingContext.buildContext();
|
||||
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList(1 ,2 ,3));
|
||||
DataStream<String> source2 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList("1", "2", "3"));
|
||||
DataStream<String> source3 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList("2", "3", "4"));
|
||||
source1.filter(x -> x > 1)
|
||||
.map(String::valueOf)
|
||||
.union(source2)
|
||||
.join(source3)
|
||||
.sink(x -> System.out.println("Sink " + x));
|
||||
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
|
||||
LOG.info("Digraph {}", jobGraph.generateDigraph());
|
||||
assertEquals(jobGraph.getJobVertices().size(), 8);
|
||||
|
||||
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
|
||||
JobGraph optimizedJobGraph = graphOptimizer.optimize();
|
||||
optimizedJobGraph.printJobGraph();
|
||||
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
|
||||
assertEquals(optimizedJobGraph.getJobVertices().size(), 5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptimizeHybridStream() {
|
||||
StreamingContext context = StreamingContext.buildContext();
|
||||
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList(1 ,2 ,3));
|
||||
DataStream<String> source2 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList("1", "2", "3"));
|
||||
source1.asPythonStream()
|
||||
.map(pyFunc(1))
|
||||
.filter(pyFunc(2))
|
||||
.union(source2.asPythonStream().filter(pyFunc(3)).map(pyFunc(4)))
|
||||
.asJavaStream()
|
||||
.sink(x -> System.out.println("Sink " + x));
|
||||
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
|
||||
LOG.info("Digraph {}", jobGraph.generateDigraph());
|
||||
assertEquals(jobGraph.getJobVertices().size(), 8);
|
||||
|
||||
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
|
||||
JobGraph optimizedJobGraph = graphOptimizer.optimize();
|
||||
optimizedJobGraph.printJobGraph();
|
||||
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
|
||||
assertEquals(optimizedJobGraph.getJobVertices().size(), 6);
|
||||
}
|
||||
|
||||
private PythonFunction pyFunc(int number) {
|
||||
return new PythonFunction("module", "func" + number);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user