[Streaming] operator chain (#8910)

This commit is contained in:
chaokunyang
2020-06-18 15:11:07 +08:00
committed by GitHub
parent 003cec87b4
commit 5edddf6eac
39 changed files with 1058 additions and 140 deletions
@@ -6,6 +6,7 @@ import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.client.JobClient;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.jobgraph.JobGraphOptimizer;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
@@ -56,7 +57,8 @@ public class StreamingContext implements Serializable {
*/
public void execute(String jobName) {
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName);
this.jobGraph = jobGraphBuilder.build();
JobGraph originalJobGraph = jobGraphBuilder.build();
this.jobGraph = new JobGraphOptimizer(originalJobGraph).optimize();
jobGraph.printJobGraph();
LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph());
@@ -0,0 +1,19 @@
package io.ray.streaming.api.partition.impl;
import io.ray.streaming.api.partition.Partition;
/**
* Default partition for operator if the operator can be chained with succeeding operators.
* Partition will be set to {@link RoundRobinPartition} if the operator can't be chiained with
* succeeding operators.
*
* @param <T> Type of the input record.
*/
public class ForwardPartition<T> implements Partition<T> {
private int[] partitions = new int[] {0};
@Override
public int[] partition(T record, int numPartition) {
return partitions;
}
}
@@ -3,8 +3,7 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.SourceFunction;
import io.ray.streaming.api.function.internal.CollectionSourceFunction;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.impl.SourceOperator;
import io.ray.streaming.operator.impl.SourceOperatorImpl;
import java.util.Collection;
/**
@@ -15,7 +14,7 @@ import java.util.Collection;
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
super(streamingContext, new SourceOperatorImpl<>(sourceFunction));
}
public static <T> DataStreamSource<T> fromSource(
@@ -2,6 +2,7 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.function.impl.JoinFunction;
import io.ray.streaming.api.function.impl.KeyFunction;
import io.ray.streaming.operator.impl.JoinOperator;
import java.io.Serializable;
/**
@@ -9,40 +10,42 @@ import java.io.Serializable;
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <J> Type of the data in the joined stream.
* @param <O> Type of the data in the joined stream.
*/
public class JoinStream<L, R, J> extends DataStream<L> {
public class JoinStream<L, R, O> extends DataStream<L> {
private final DataStream<R> rightStream;
public JoinStream(DataStream<L> leftStream, DataStream<R> rightStream) {
super(leftStream, null);
super(leftStream, new JoinOperator<>());
this.rightStream = rightStream;
}
public DataStream<R> getRightStream() {
return rightStream;
}
/**
* Apply key-by to the left join stream.
*/
public <K> Where<L, R, J, K> where(KeyFunction<L, K> keyFunction) {
public <K> Where<K> where(KeyFunction<L, K> keyFunction) {
return new Where<>(this, keyFunction);
}
/**
* Where clause of the join transformation.
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <J> Type of the data in the joined stream.
* @param <K> Type of the join key.
*/
class Where<L, R, J, K> implements Serializable {
private JoinStream<L, R, J> joinStream;
class Where<K> implements Serializable {
private JoinStream<L, R, O> joinStream;
private KeyFunction<L, K> leftKeyByFunction;
public Where(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction) {
Where(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction) {
this.joinStream = joinStream;
this.leftKeyByFunction = leftKeyByFunction;
}
public Equal<L, R, J, K> equalLo(KeyFunction<R, K> rightKeyFunction) {
public Equal<K> equalTo(KeyFunction<R, K> rightKeyFunction) {
return new Equal<>(joinStream, leftKeyByFunction, rightKeyFunction);
}
}
@@ -50,26 +53,25 @@ public class JoinStream<L, R, J> extends DataStream<L> {
/**
* Equal clause of the join transformation.
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <J> Type of the data in the joined stream.
* @param <K> Type of the join key.
*/
class Equal<L, R, J, K> implements Serializable {
private JoinStream<L, R, J> joinStream;
class Equal<K> implements Serializable {
private JoinStream<L, R, O> joinStream;
private KeyFunction<L, K> leftKeyByFunction;
private KeyFunction<R, K> rightKeyByFunction;
public Equal(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction,
KeyFunction<R, K> rightKeyByFunction) {
Equal(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction,
KeyFunction<R, K> rightKeyByFunction) {
this.joinStream = joinStream;
this.leftKeyByFunction = leftKeyByFunction;
this.rightKeyByFunction = rightKeyByFunction;
}
public DataStream<J> with(JoinFunction<L, R, J> joinFunction) {
return (DataStream<J>) joinStream;
@SuppressWarnings("unchecked")
public DataStream<O> with(JoinFunction<L, R, O> joinFunction) {
JoinOperator joinOperator = (JoinOperator) joinStream.getOperator();
joinOperator.setFunction(joinFunction);
return (DataStream<O>) joinStream;
}
}
@@ -4,7 +4,8 @@ import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonPartition;
@@ -30,8 +31,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
private Stream originalStream;
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
this(streamingContext, null, streamOperator,
selectPartition(streamOperator));
this(streamingContext, null, streamOperator, getForwardPartition(streamOperator));
}
public Stream(StreamingContext streamingContext,
@@ -42,7 +42,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
public Stream(Stream inputStream, StreamOperator streamOperator) {
this(inputStream.getStreamingContext(), inputStream, streamOperator,
selectPartition(streamOperator));
getForwardPartition(streamOperator));
}
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
@@ -50,9 +50,9 @@ public abstract class Stream<S extends Stream<S, T>, T>
}
protected Stream(StreamingContext streamingContext,
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
this.streamingContext = streamingContext;
this.inputStream = inputStream;
this.operator = streamOperator;
@@ -73,15 +73,16 @@ public abstract class Stream<S extends Stream<S, T>, T>
this.streamingContext = originalStream.getStreamingContext();
this.inputStream = originalStream.getInputStream();
this.operator = originalStream.getOperator();
Preconditions.checkNotNull(operator);
}
@SuppressWarnings("unchecked")
private static <T> Partition<T> selectPartition(Operator operator) {
private static <T> Partition<T> getForwardPartition(Operator operator) {
switch (operator.getLanguage()) {
case PYTHON:
return (Partition<T>) PythonPartition.RoundRobinPartition;
return (Partition<T>) PythonPartition.ForwardPartition;
case JAVA:
return new RoundRobinPartition<>();
return new ForwardPartition<>();
default:
throw new UnsupportedOperationException(
"Unsupported language " + operator.getLanguage());
@@ -165,5 +166,29 @@ public abstract class Stream<S extends Stream<S, T>, T>
return originalStream;
}
/**
* Set chain strategy for this stream
*/
public S withChainStrategy(ChainStrategy chainStrategy) {
Preconditions.checkArgument(!isProxyStream());
operator.setChainStrategy(chainStrategy);
return self();
}
/**
* Disable chain for this stream
*/
public S disableChain() {
return withChainStrategy(ChainStrategy.NEVER);
}
/**
* Set the partition function of this {@link Stream} so that output elements are forwarded to
* next operator locally.
*/
public S forward() {
return setPartition(getForwardPartition(operator));
}
public abstract Language getLanguage();
}
@@ -16,6 +16,8 @@ public class UnionStream<T> extends DataStream<T> {
private List<DataStream<T>> unionStreams;
public UnionStream(DataStream<T> input, List<DataStream<T>> streams) {
// Union stream does not create a physical operation, so we don't have to set partition
// function for it.
super(input, new UnionOperator());
this.unionStreams = new ArrayList<>();
streams.forEach(this::addStream);
@@ -5,6 +5,8 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -17,15 +19,24 @@ public class JobGraph implements Serializable {
private final String jobName;
private final Map<String, String> jobConfig;
private List<JobVertex> jobVertexList;
private List<JobEdge> jobEdgeList;
private List<JobVertex> jobVertices;
private List<JobEdge> jobEdges;
private String digraph;
public JobGraph(String jobName, Map<String, String> jobConfig) {
this.jobName = jobName;
this.jobConfig = jobConfig;
this.jobVertexList = new ArrayList<>();
this.jobEdgeList = new ArrayList<>();
this.jobVertices = new ArrayList<>();
this.jobEdges = new ArrayList<>();
}
public JobGraph(String jobName, Map<String, String> jobConfig,
List<JobVertex> jobVertices, List<JobEdge> jobEdges) {
this.jobName = jobName;
this.jobConfig = jobConfig;
this.jobVertices = jobVertices;
this.jobEdges = jobEdges;
generateDigraph();
}
/**
@@ -36,12 +47,12 @@ public class JobGraph implements Serializable {
*/
public String generateDigraph() {
StringBuilder digraph = new StringBuilder();
digraph.append("digraph ").append(jobName + " ").append(" {");
digraph.append("digraph ").append(jobName).append(" ").append(" {");
for (JobEdge jobEdge : jobEdgeList) {
for (JobEdge jobEdge : jobEdges) {
String srcNode = null;
String targetNode = null;
for (JobVertex jobVertex : jobVertexList) {
for (JobVertex jobVertex : jobVertices) {
if (jobEdge.getSrcVertexId() == jobVertex.getVertexId()) {
srcNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName();
} else if (jobEdge.getTargetVertexId() == jobVertex.getVertexId()) {
@@ -49,7 +60,7 @@ public class JobGraph implements Serializable {
}
}
digraph.append(System.getProperty("line.separator"));
digraph.append(srcNode).append(" -> ").append(targetNode);
digraph.append(String.format(" \"%s\" -> \"%s\"", srcNode, targetNode));
}
digraph.append(System.getProperty("line.separator")).append("}");
@@ -58,19 +69,47 @@ public class JobGraph implements Serializable {
}
public void addVertex(JobVertex vertex) {
this.jobVertexList.add(vertex);
this.jobVertices.add(vertex);
}
public void addEdge(JobEdge jobEdge) {
this.jobEdgeList.add(jobEdge);
this.jobEdges.add(jobEdge);
}
public List<JobVertex> getJobVertexList() {
return jobVertexList;
public List<JobVertex> getJobVertices() {
return jobVertices;
}
public List<JobEdge> getJobEdgeList() {
return jobEdgeList;
public List<JobVertex> getSourceVertices() {
return jobVertices.stream()
.filter(v -> v.getVertexType() == VertexType.SOURCE)
.collect(Collectors.toList());
}
public List<JobVertex> getSinkVertices() {
return jobVertices.stream()
.filter(v -> v.getVertexType() == VertexType.SINK)
.collect(Collectors.toList());
}
public JobVertex getVertex(int vertexId) {
return jobVertices.stream().filter(v -> v.getVertexId() == vertexId).findFirst().get();
}
public List<JobEdge> getJobEdges() {
return jobEdges;
}
public Set<JobEdge> getVertexInputEdges(int vertexId) {
return jobEdges.stream()
.filter(jobEdge -> jobEdge.getTargetVertexId() == vertexId)
.collect(Collectors.toSet());
}
public Set<JobEdge> getVertexOutputEdges(int vertexId) {
return jobEdges.stream()
.filter(jobEdge -> jobEdge.getSrcVertexId() == vertexId)
.collect(Collectors.toSet());
}
public String getDigraph() {
@@ -90,17 +129,17 @@ public class JobGraph implements Serializable {
return;
}
LOG.info("Printing job graph:");
for (JobVertex jobVertex : jobVertexList) {
for (JobVertex jobVertex : jobVertices) {
LOG.info(jobVertex.toString());
}
for (JobEdge jobEdge : jobEdgeList) {
for (JobEdge jobEdge : jobEdges) {
LOG.info(jobEdge.toString());
}
}
public boolean isCrossLanguageGraph() {
Language language = jobVertexList.get(0).getLanguage();
for (JobVertex jobVertex : jobVertexList) {
Language language = jobVertices.get(0).getLanguage();
for (JobVertex jobVertex : jobVertices) {
if (jobVertex.getLanguage() != language) {
return true;
}
@@ -2,6 +2,7 @@ package io.ray.streaming.jobgraph;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.JoinStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.api.stream.StreamSource;
@@ -26,7 +27,7 @@ public class JobGraphBuilder {
private List<StreamSink> streamSinkList;
public JobGraphBuilder(List<StreamSink> streamSinkList) {
this(streamSinkList, "job-" + System.currentTimeMillis());
this(streamSinkList, "job_" + System.currentTimeMillis());
}
public JobGraphBuilder(List<StreamSink> streamSinkList, String jobName) {
@@ -61,18 +62,20 @@ public class JobGraphBuilder {
"Reference stream should be skipped.");
int vertexId = stream.getId();
int parallelism = stream.getParallelism();
Map<String, String> config = stream.getConfig();
JobVertex jobVertex;
if (stream instanceof StreamSink) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator, config);
Stream parentStream = stream.getInputStream();
int inputVertexId = parentStream.getId();
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
this.jobGraph.addEdge(jobEdge);
processStream(parentStream);
} else if (stream instanceof StreamSource) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator);
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator, config);
} else if (stream instanceof DataStream || stream instanceof PythonDataStream) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator);
jobVertex = new JobVertex(
vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator, config);
Stream parentStream = stream.getInputStream();
int inputVertexId = parentStream.getId();
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
@@ -92,10 +95,17 @@ public class JobGraphBuilder {
this.jobGraph.addEdge(otherEdge);
processStream(otherStream);
}
// process join stream
if (stream instanceof JoinStream) {
DataStream rightStream = ((JoinStream) stream).getRightStream();
this.jobGraph.addEdge(
new JobEdge(rightStream.getId(), vertexId, rightStream.getPartition()));
processStream(rightStream);
}
} else {
throw new UnsupportedOperationException("Unsupported stream: " + stream);
}
jobVertex.setConfig(stream.getConfig());
this.jobGraph.addVertex(jobVertex);
}
@@ -0,0 +1,187 @@
package io.ray.streaming.jobgraph;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.chain.ChainedOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonOperator.ChainedPythonOperator;
import io.ray.streaming.python.PythonPartition;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
/**
* Optimize job graph by chaining some operators so that some operators can be run in the
* same thread.
*/
public class JobGraphOptimizer {
private final JobGraph jobGraph;
private Set<JobVertex> visited = new HashSet<>();
// vertex id -> vertex
private Map<Integer, JobVertex> vertexMap;
private Map<JobVertex, Set<JobEdge>> outputEdgesMap;
// tail vertex id -> mergedVertex
private Map<Integer, Pair<JobVertex, List<JobVertex>>> mergedVertexMap;
public JobGraphOptimizer(JobGraph jobGraph) {
this.jobGraph = jobGraph;
vertexMap = jobGraph.getJobVertices().stream()
.collect(Collectors.toMap(JobVertex::getVertexId, Function.identity()));
outputEdgesMap = vertexMap.keySet().stream().collect(Collectors.toMap(
id -> vertexMap.get(id), id -> new HashSet<>(jobGraph.getVertexOutputEdges(id))));
mergedVertexMap = new HashMap<>();
}
public JobGraph optimize() {
// Deep-first traverse nodes from source to sink to merge vertices that can be chained
// together.
jobGraph.getSourceVertices().forEach(vertex -> {
List<JobVertex> verticesToMerge = new ArrayList<>();
verticesToMerge.add(vertex);
mergeVerticesRecursively(vertex, verticesToMerge);
});
List<JobVertex> vertices = mergedVertexMap.values().stream()
.map(Pair::getLeft).collect(Collectors.toList());
return new JobGraph(jobGraph.getJobName(), jobGraph.getJobConfig(), vertices, createEdges());
}
private void mergeVerticesRecursively(JobVertex vertex, List<JobVertex> verticesToMerge) {
if (!visited.contains(vertex)) {
visited.add(vertex);
Set<JobEdge> outputEdges = outputEdgesMap.get(vertex);
if (outputEdges.isEmpty()) {
mergeAndAddVertex(verticesToMerge);
} else {
outputEdges.forEach(edge -> {
JobVertex succeedingVertex = vertexMap.get(edge.getTargetVertexId());
if (canBeChained(vertex, succeedingVertex, edge)) {
verticesToMerge.add(succeedingVertex);
mergeVerticesRecursively(succeedingVertex, verticesToMerge);
} else {
mergeAndAddVertex(verticesToMerge);
List<JobVertex> newMergedVertices = new ArrayList<>();
newMergedVertices.add(succeedingVertex);
mergeVerticesRecursively(succeedingVertex, newMergedVertices);
}
});
}
}
}
private void mergeAndAddVertex(List<JobVertex> verticesToMerge) {
JobVertex mergedVertex;
JobVertex headVertex = verticesToMerge.get(0);
Language language = headVertex.getLanguage();
if (verticesToMerge.size() == 1) {
// no chain
mergedVertex = headVertex;
} else {
List<StreamOperator> operators = verticesToMerge.stream()
.map(v -> vertexMap.get(v.getVertexId()).getStreamOperator())
.collect(Collectors.toList());
List<Map<String, String>> configs = verticesToMerge.stream()
.map(v -> vertexMap.get(v.getVertexId()).getConfig())
.collect(Collectors.toList());
StreamOperator operator;
if (language == Language.JAVA) {
operator = ChainedOperator.newChainedOperator(operators, configs);
} else {
List<PythonOperator> pythonOperators = operators.stream()
.map(o -> (PythonOperator) o).collect(Collectors.toList());
operator = new ChainedPythonOperator(pythonOperators, configs);
}
// chained operator config is placed into `ChainedOperator`.
mergedVertex = new JobVertex(headVertex.getVertexId(), headVertex.getParallelism(),
headVertex.getVertexType(), operator, new HashMap<>());
}
mergedVertexMap.put(mergedVertex.getVertexId(), Pair.of(mergedVertex, verticesToMerge));
}
private List<JobEdge> createEdges() {
List<JobEdge> edges = new ArrayList<>();
mergedVertexMap.forEach((id, pair) -> {
JobVertex mergedVertex = pair.getLeft();
List<JobVertex> mergedVertices = pair.getRight();
JobVertex tailVertex = mergedVertices.get(mergedVertices.size() - 1);
// input edge will be set up in input vertices
if (outputEdgesMap.containsKey(tailVertex)) {
outputEdgesMap.get(tailVertex).forEach(edge -> {
Pair<JobVertex, List<JobVertex>> downstreamPair =
mergedVertexMap.get(edge.getTargetVertexId());
// change ForwardPartition to RoundRobinPartition.
Partition partition = changePartition(edge.getPartition());
JobEdge newEdge = new JobEdge(
mergedVertex.getVertexId(),
downstreamPair.getLeft().getVertexId(),
partition);
edges.add(newEdge);
});
}
});
return edges;
}
/**
* Change ForwardPartition to RoundRobinPartition.
*/
private Partition changePartition(Partition partition) {
if (partition instanceof PythonPartition) {
PythonPartition pythonPartition = (PythonPartition) partition;
if (!pythonPartition.isConstructedFromBinary() &&
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS)) {
return PythonPartition.RoundRobinPartition;
} else {
return partition;
}
} else {
if (partition instanceof ForwardPartition) {
return new RoundRobinPartition();
} else {
return partition;
}
}
}
private boolean canBeChained(JobVertex precedingVertex,
JobVertex succeedingVertex,
JobEdge edge) {
if (jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 ||
jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) {
return false;
}
if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) {
return false;
}
if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) {
return false;
}
if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) {
return false;
}
Partition partition = edge.getPartition();
if (!(partition instanceof PythonPartition)) {
return partition instanceof ForwardPartition;
} else {
PythonPartition pythonPartition = (PythonPartition) partition;
return !pythonPartition.isConstructedFromBinary() &&
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS);
}
}
}
@@ -10,6 +10,7 @@ import java.util.Map;
* Job vertex is a cell node where logic is executed.
*/
public class JobVertex implements Serializable {
private int vertexId;
private int parallelism;
private VertexType vertexType;
@@ -20,12 +21,14 @@ public class JobVertex implements Serializable {
public JobVertex(int vertexId,
int parallelism,
VertexType vertexType,
StreamOperator streamOperator) {
StreamOperator streamOperator,
Map<String, String> config) {
this.vertexId = vertexId;
this.parallelism = parallelism;
this.vertexType = vertexType;
this.streamOperator = streamOperator;
this.language = streamOperator.getLanguage();
this.config = config;
}
public int getVertexId() {
@@ -67,4 +70,5 @@ public class JobVertex implements Serializable {
.add("config", config)
.toString();
}
}
@@ -0,0 +1,20 @@
package io.ray.streaming.operator;
/**
* Chain strategy for streaming operators. Chained operators are run in the same thread.
*/
public enum ChainStrategy {
/**
* The operator won't be chained with preceding operators, but maybe chained with succeeding
* operators.
*/
HEAD,
/**
* Operators will be chained together when possible.
*/
ALWAYS,
/**
* The operator won't be chained with any operator.
*/
NEVER
}
@@ -9,6 +9,8 @@ import java.util.List;
public interface Operator extends Serializable {
String getName();
void open(List<Collector> collectors, RuntimeContext runtimeContext);
void finish();
@@ -20,4 +22,7 @@ public interface Operator extends Serializable {
Language getLanguage();
OperatorType getOpType();
ChainStrategy getChainStrategy();
}
@@ -0,0 +1,14 @@
package io.ray.streaming.operator;
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
public interface SourceOperator<T> extends Operator {
void run();
SourceContext<T> getSourceContext();
default OperatorType getOpType() {
return OperatorType.SOURCE;
}
}
@@ -12,13 +12,22 @@ import java.util.List;
public abstract class StreamOperator<F extends Function> implements Operator {
protected final String name;
protected final F function;
protected final RichFunction richFunction;
protected F function;
protected RichFunction richFunction;
protected List<Collector> collectorList;
protected RuntimeContext runtimeContext;
private ChainStrategy chainStrategy = ChainStrategy.ALWAYS;
public StreamOperator(F function) {
protected StreamOperator() {
this.name = getClass().getSimpleName();
}
protected StreamOperator(F function) {
this();
setFunction(function);
}
public void setFunction(F function) {
this.function = function;
this.richFunction = Functions.wrap(function);
}
@@ -62,7 +71,17 @@ public abstract class StreamOperator<F extends Function> implements Operator {
}
}
@Override
public String getName() {
return name;
}
public void setChainStrategy(ChainStrategy chainStrategy) {
this.chainStrategy = chainStrategy;
}
@Override
public ChainStrategy getChainStrategy() {
return chainStrategy;
}
}
@@ -0,0 +1,169 @@
package io.ray.streaming.operator.chain;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.OneInputOperator;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.TwoInputOperator;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Abstract base class for chained operators.
*/
public abstract class ChainedOperator extends StreamOperator<Function> {
protected final List<StreamOperator> operators;
protected final Operator headOperator;
protected final Operator tailOperator;
private final List<Map<String, String>> configs;
public ChainedOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
Preconditions.checkArgument(operators.size() >= 2,
"Need at lease two operators to be chained together");
operators.stream().skip(1)
.forEach(operator -> Preconditions.checkArgument(operator instanceof OneInputOperator));
this.operators = operators;
this.configs = configs;
this.headOperator = operators.get(0);
this.tailOperator = operators.get(operators.size() - 1);
}
@Override
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
// Dont' call super.open() as we `open` every operator separately.
List<ForwardCollector> succeedingCollectors = operators.stream().skip(1)
.map(operator -> new ForwardCollector((OneInputOperator) operator))
.collect(Collectors.toList());
for (int i = 0; i < operators.size() - 1; i++) {
StreamOperator operator = operators.get(i);
List<ForwardCollector> forwardCollectors =
Collections.singletonList(succeedingCollectors.get(i));
operator.open(forwardCollectors, createRuntimeContext(runtimeContext, i));
}
// tail operator send data to downstream using provided collectors.
tailOperator.open(collectorList, createRuntimeContext(runtimeContext, operators.size() - 1));
}
@Override
public OperatorType getOpType() {
return headOperator.getOpType();
}
@Override
public Language getLanguage() {
return headOperator.getLanguage();
}
@Override
public String getName() {
return operators.stream().map(Operator::getName)
.collect(Collectors.joining(" -> ", "[", "]"));
}
public List<StreamOperator> getOperators() {
return operators;
}
public Operator getHeadOperator() {
return headOperator;
}
public Operator getTailOperator() {
return tailOperator;
}
private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) {
return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(),
new Class[] {RuntimeContext.class},
(proxy, method, methodArgs) -> {
if (method.getName().equals("getConfig")) {
return configs.get(index);
} else {
return method.invoke(runtimeContext, methodArgs);
}
});
}
public static ChainedOperator newChainedOperator(
List<StreamOperator> operators,
List<Map<String, String>> configs) {
switch (operators.get(0).getOpType()) {
case SOURCE:
return new ChainedSourceOperator(operators, configs);
case ONE_INPUT:
return new ChainedOneInputOperator(operators, configs);
case TWO_INPUT:
return new ChainedTwoInputOperator(operators, configs);
default:
throw new IllegalArgumentException(
"Unsupported operator type " + operators.get(0).getOpType());
}
}
static class ChainedSourceOperator<T> extends ChainedOperator
implements SourceOperator<T> {
private final SourceOperator<T> sourceOperator;
@SuppressWarnings("unchecked")
ChainedSourceOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
super(operators, configs);
sourceOperator = (SourceOperator<T>) headOperator;
}
@Override
public void run() {
sourceOperator.run();
}
@Override
public SourceContext<T> getSourceContext() {
return sourceOperator.getSourceContext();
}
}
static class ChainedOneInputOperator<T> extends ChainedOperator
implements OneInputOperator<T> {
private final OneInputOperator<T> inputOperator;
@SuppressWarnings("unchecked")
ChainedOneInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
super(operators, configs);
inputOperator = (OneInputOperator<T>) headOperator;
}
@Override
public void processElement(Record<T> record) throws Exception {
inputOperator.processElement(record);
}
}
static class ChainedTwoInputOperator<L, R> extends ChainedOperator
implements TwoInputOperator<L, R> {
private final TwoInputOperator<L, R> inputOperator;
@SuppressWarnings("unchecked")
ChainedTwoInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
super(operators, configs);
inputOperator = (TwoInputOperator<L, R>) headOperator;
}
@Override
public void processElement(Record<L> record1, Record<R> record2) {
inputOperator.processElement(record1, record2);
}
}
}
@@ -0,0 +1,23 @@
package io.ray.streaming.operator.chain;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.OneInputOperator;
class ForwardCollector implements Collector<Record> {
private final OneInputOperator succeedingOperator;
ForwardCollector(OneInputOperator succeedingOperator) {
this.succeedingOperator = succeedingOperator;
}
@SuppressWarnings("unchecked")
@Override
public void collect(Record record) {
try {
succeedingOperator.processElement(record);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
@@ -0,0 +1,39 @@
package io.ray.streaming.operator.impl;
import io.ray.streaming.api.function.impl.JoinFunction;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.TwoInputOperator;
/**
* Join operator
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <K> Type of the data in the join key.
* @param <O> Type of the data in the joined stream.
*/
public class JoinOperator<L, R, K, O> extends StreamOperator<JoinFunction<L, R, O>> implements
TwoInputOperator<L, R> {
public JoinOperator() {
}
public JoinOperator(JoinFunction<L, R, O> function) {
super(function);
setChainStrategy(ChainStrategy.HEAD);
}
@Override
public void processElement(Record<L> record1, Record<R> record2) {
}
@Override
public OperatorType getOpType() {
return OperatorType.TWO_INPUT;
}
}
@@ -5,6 +5,7 @@ import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.OneInputOperator;
import io.ray.streaming.operator.StreamOperator;
import java.util.HashMap;
@@ -18,6 +19,7 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
public ReduceOperator(ReduceFunction<T> reduceFunction) {
super(reduceFunction);
setChainStrategy(ChainStrategy.HEAD);
}
@Override
@@ -41,4 +43,5 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
collect(record);
}
}
}
@@ -5,15 +5,19 @@ import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.impl.SourceFunction;
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.operator.StreamOperator;
import java.util.List;
public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
public class SourceOperatorImpl<T> extends StreamOperator<SourceFunction<T>>
implements SourceOperator {
private SourceContextImpl sourceContext;
public SourceOperator(SourceFunction<T> function) {
public SourceOperatorImpl(SourceFunction<T> function) {
super(function);
setChainStrategy(ChainStrategy.HEAD);
}
@Override
@@ -23,6 +27,7 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex());
}
@Override
public void run() {
try {
this.function.run(this.sourceContext);
@@ -31,13 +36,17 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
}
}
@Override
public SourceContext getSourceContext() {
return sourceContext;
}
@Override
public OperatorType getOpType() {
return OperatorType.SOURCE;
}
class SourceContextImpl implements SourceContext<T> {
private List<Collector> collectors;
public SourceContextImpl(List<Collector> collectors) {
@@ -47,9 +56,10 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
@Override
public void collect(T t) throws Exception {
for (Collector collector : collectors) {
collector.collect(new Record(t));
collector.collect(new Record<>(t));
}
}
}
}
@@ -100,6 +100,14 @@ public class PythonFunction implements Function {
return functionInterface;
}
public String toSimpleString() {
if (function != null) {
return "binary function";
} else {
return String.format("%s-%s.%s", functionInterface, moduleName, functionName);
}
}
@Override
public String toString() {
StringJoiner stringJoiner = new StringJoiner(", ",
@@ -1,11 +1,16 @@
package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.StreamOperator;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.stream.Collectors;
/**
* Represents a {@link StreamOperator} that wraps python {@link PythonFunction}.
@@ -27,30 +32,6 @@ public class PythonOperator extends StreamOperator {
this.className = null;
}
@Override
public void open(List list, RuntimeContext runtimeContext) {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public void finish() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public void close() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public OperatorType getOpType() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public Language getLanguage() {
return Language.PYTHON;
@@ -64,6 +45,48 @@ public class PythonOperator extends StreamOperator {
return className;
}
@Override
public void open(List list, RuntimeContext runtimeContext) {
throwUnsupportedException();
}
@Override
public void finish() {
throwUnsupportedException();
}
@Override
public void close() {
throwUnsupportedException();
}
void throwUnsupportedException() {
StackTraceElement[] trace = Thread.currentThread().getStackTrace();
Preconditions.checkState(trace.length >= 2);
StackTraceElement traceElement = trace[2];
String msg = String.format("Method %s.%s shouldn't be called.",
traceElement.getClassName(), traceElement.getMethodName());
throw new UnsupportedOperationException(msg);
}
@Override
public OperatorType getOpType() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public String getName() {
StringBuilder builder = new StringBuilder();
builder.append(PythonOperator.class.getSimpleName()).append("[");
if (function != null) {
builder.append(((PythonFunction)function).toSimpleString());
} else {
builder.append(moduleName).append(".").append(className);
}
return builder.append("]").toString();
}
@Override
public String toString() {
StringJoiner stringJoiner = new StringJoiner(", ",
@@ -76,4 +99,71 @@ public class PythonOperator extends StreamOperator {
}
return stringJoiner.toString();
}
public static class ChainedPythonOperator extends PythonOperator {
private final List<PythonOperator> operators;
private final PythonOperator headOperator;
private final PythonOperator tailOperator;
private final List<Map<String, String>> configs;
public ChainedPythonOperator(
List<PythonOperator> operators, List<Map<String, String>> configs) {
super(null);
Preconditions.checkArgument(!operators.isEmpty());
this.operators = operators;
this.configs = configs;
this.headOperator = operators.get(0);
this.tailOperator = operators.get(operators.size() - 1);
}
@Override
public OperatorType getOpType() {
return headOperator.getOpType();
}
@Override
public Language getLanguage() {
return Language.PYTHON;
}
@Override
public String getName() {
return operators.stream().map(Operator::getName)
.collect(Collectors.joining(" -> ", "[", "]"));
}
@Override
public String getModuleName() {
throwUnsupportedException();
return null; // impossible
}
@Override
public String getClassName() {
throwUnsupportedException();
return null; // impossible
}
@Override
public Function getFunction() {
throwUnsupportedException();
return null; // impossible
}
public List<PythonOperator> getOperators() {
return operators;
}
public PythonOperator getHeadOperator() {
return headOperator;
}
public PythonOperator getTailOperator() {
return tailOperator;
}
public List<Map<String, String>> getConfigs() {
return configs;
}
}
}
@@ -24,6 +24,9 @@ public class PythonPartition implements Partition<Object> {
"ray.streaming.partition", "KeyPartition");
public static final PythonPartition RoundRobinPartition = new PythonPartition(
"ray.streaming.partition", "RoundRobinPartition");
public static final String FORWARD_PARTITION_CLASS = "ForwardPartition";
public static final PythonPartition ForwardPartition = new PythonPartition(
"ray.streaming.partition", FORWARD_PARTITION_CLASS);
private byte[] partition;
private String moduleName;
@@ -66,6 +69,10 @@ public class PythonPartition implements Partition<Object> {
return functionName;
}
public boolean isConstructedFromBinary() {
return partition != null;
}
@Override
public String toString() {
StringJoiner stringJoiner = new StringJoiner(", ",
@@ -2,6 +2,7 @@ package io.ray.streaming.python.stream;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.KeyDataStream;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;
@@ -37,7 +38,9 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
*/
public PythonDataStream reduce(PythonFunction func) {
func.setFunctionInterface(FunctionInterface.REDUCE_FUNCTION);
return new PythonDataStream(this, new PythonOperator(func));
PythonDataStream stream = new PythonDataStream(this, new PythonOperator(func));
stream.withChainStrategy(ChainStrategy.HEAD);
return stream;
}
/**
@@ -2,10 +2,10 @@ package io.ray.streaming.python.stream;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.StreamSource;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
/**
* Represents a source of the PythonStream.
@@ -13,8 +13,8 @@ import io.ray.streaming.python.PythonPartition;
public class PythonStreamSource extends PythonDataStream implements StreamSource {
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
super(streamingContext, new PythonOperator(sourceFunction),
PythonPartition.RoundRobinPartition);
super(streamingContext, new PythonOperator(sourceFunction));
withChainStrategy(ChainStrategy.HEAD);
}
public static PythonStreamSource from(StreamingContext streamingContext,
@@ -14,6 +14,8 @@ public class PythonUnionStream extends PythonDataStream {
private List<PythonDataStream> unionStreams;
public PythonUnionStream(PythonDataStream input, List<PythonDataStream> others) {
// Union stream does not create a physical operation, so we don't have to set partition
// function for it.
super(input, new PythonOperator(
"ray.streaming.operator", "UnionOperator"));
this.unionStreams = new ArrayList<>();
@@ -2,8 +2,8 @@ package io.ray.streaming.jobgraph;
import com.google.common.collect.Lists;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.api.stream.StreamSink;
@@ -20,14 +20,14 @@ public class JobGraphBuilderTest {
@Test
public void testDataSync() {
JobGraph jobGraph = buildDataSyncJobGraph();
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
Assert.assertEquals(jobVertexList.size(), 2);
Assert.assertEquals(jobEdgeList.size(), 1);
JobEdge jobEdge = jobEdgeList.get(0);
Assert.assertEquals(jobEdge.getPartition().getClass(), RoundRobinPartition.class);
Assert.assertEquals(jobEdge.getPartition().getClass(), ForwardPartition.class);
JobVertex sinkVertex = jobVertexList.get(1);
JobVertex sourceVertex = jobVertexList.get(0);
@@ -50,8 +50,8 @@ public class JobGraphBuilderTest {
@Test
public void testKeyByJobGraph() {
JobGraph jobGraph = buildKeyByJobGraph();
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
Assert.assertEquals(jobVertexList.size(), 3);
Assert.assertEquals(jobEdgeList.size(), 2);
@@ -68,7 +68,7 @@ public class JobGraphBuilderTest {
JobEdge source2KeyBy = jobEdgeList.get(1);
Assert.assertEquals(keyBy2Sink.getPartition().getClass(), KeyPartition.class);
Assert.assertEquals(source2KeyBy.getPartition().getClass(), RoundRobinPartition.class);
Assert.assertEquals(source2KeyBy.getPartition().getClass(), ForwardPartition.class);
}
public JobGraph buildKeyByJobGraph() {
@@ -88,8 +88,8 @@ public class JobGraphBuilderTest {
JobGraph jobGraph = buildKeyByJobGraph();
jobGraph.generateDigraph();
String diGraph = jobGraph.getDigraph();
System.out.println(diGraph);
Assert.assertTrue(diGraph.contains("1-SourceOperator -> 2-KeyByOperator"));
Assert.assertTrue(diGraph.contains("2-KeyByOperator -> 3-SinkOperator"));
LOG.info(diGraph);
Assert.assertTrue(diGraph.contains("\"1-SourceOperatorImpl\" -> \"2-KeyByOperator\""));
Assert.assertTrue(diGraph.contains("\"2-KeyByOperator\" -> \"3-SinkOperator\""));
}
}
@@ -0,0 +1,70 @@
package io.ray.streaming.jobgraph;
import static org.testng.Assert.assertEquals;
import com.google.common.collect.Lists;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.python.PythonFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;
public class JobGraphOptimizerTest {
private static final Logger LOG = LoggerFactory.getLogger( JobGraphOptimizerTest.class );
@Test
public void testOptimize() {
StreamingContext context = StreamingContext.buildContext();
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
Lists.newArrayList(1 ,2 ,3));
DataStream<String> source2 = DataStreamSource.fromCollection(context,
Lists.newArrayList("1", "2", "3"));
DataStream<String> source3 = DataStreamSource.fromCollection(context,
Lists.newArrayList("2", "3", "4"));
source1.filter(x -> x > 1)
.map(String::valueOf)
.union(source2)
.join(source3)
.sink(x -> System.out.println("Sink " + x));
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
LOG.info("Digraph {}", jobGraph.generateDigraph());
assertEquals(jobGraph.getJobVertices().size(), 8);
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
JobGraph optimizedJobGraph = graphOptimizer.optimize();
optimizedJobGraph.printJobGraph();
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
assertEquals(optimizedJobGraph.getJobVertices().size(), 5);
}
@Test
public void testOptimizeHybridStream() {
StreamingContext context = StreamingContext.buildContext();
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
Lists.newArrayList(1 ,2 ,3));
DataStream<String> source2 = DataStreamSource.fromCollection(context,
Lists.newArrayList("1", "2", "3"));
source1.asPythonStream()
.map(pyFunc(1))
.filter(pyFunc(2))
.union(source2.asPythonStream().filter(pyFunc(3)).map(pyFunc(4)))
.asJavaStream()
.sink(x -> System.out.println("Sink " + x));
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
LOG.info("Digraph {}", jobGraph.generateDigraph());
assertEquals(jobGraph.getJobVertices().size(), 8);
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
JobGraph optimizedJobGraph = graphOptimizer.optimize();
optimizedJobGraph.printJobGraph();
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
assertEquals(optimizedJobGraph.getJobVertices().size(), 6);
}
private PythonFunction pyFunc(int number) {
return new PythonFunction("module", "func" + number);
}
}