mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[Streaming] operator chain (#8910)
This commit is contained in:
+3
-1
@@ -6,6 +6,7 @@ import io.ray.streaming.api.stream.StreamSink;
|
||||
import io.ray.streaming.client.JobClient;
|
||||
import io.ray.streaming.jobgraph.JobGraph;
|
||||
import io.ray.streaming.jobgraph.JobGraphBuilder;
|
||||
import io.ray.streaming.jobgraph.JobGraphOptimizer;
|
||||
import io.ray.streaming.util.Config;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
@@ -56,7 +57,8 @@ public class StreamingContext implements Serializable {
|
||||
*/
|
||||
public void execute(String jobName) {
|
||||
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName);
|
||||
this.jobGraph = jobGraphBuilder.build();
|
||||
JobGraph originalJobGraph = jobGraphBuilder.build();
|
||||
this.jobGraph = new JobGraphOptimizer(originalJobGraph).optimize();
|
||||
jobGraph.printJobGraph();
|
||||
LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph());
|
||||
|
||||
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
package io.ray.streaming.api.partition.impl;
|
||||
|
||||
import io.ray.streaming.api.partition.Partition;
|
||||
|
||||
/**
|
||||
* Default partition for operator if the operator can be chained with succeeding operators.
|
||||
* Partition will be set to {@link RoundRobinPartition} if the operator can't be chiained with
|
||||
* succeeding operators.
|
||||
*
|
||||
* @param <T> Type of the input record.
|
||||
*/
|
||||
public class ForwardPartition<T> implements Partition<T> {
|
||||
private int[] partitions = new int[] {0};
|
||||
|
||||
@Override
|
||||
public int[] partition(T record, int numPartition) {
|
||||
return partitions;
|
||||
}
|
||||
}
|
||||
+2
-3
@@ -3,8 +3,7 @@ package io.ray.streaming.api.stream;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction;
|
||||
import io.ray.streaming.api.function.internal.CollectionSourceFunction;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.operator.impl.SourceOperator;
|
||||
import io.ray.streaming.operator.impl.SourceOperatorImpl;
|
||||
import java.util.Collection;
|
||||
|
||||
/**
|
||||
@@ -15,7 +14,7 @@ import java.util.Collection;
|
||||
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
|
||||
|
||||
private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
|
||||
super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
|
||||
super(streamingContext, new SourceOperatorImpl<>(sourceFunction));
|
||||
}
|
||||
|
||||
public static <T> DataStreamSource<T> fromSource(
|
||||
|
||||
+24
-22
@@ -2,6 +2,7 @@ package io.ray.streaming.api.stream;
|
||||
|
||||
import io.ray.streaming.api.function.impl.JoinFunction;
|
||||
import io.ray.streaming.api.function.impl.KeyFunction;
|
||||
import io.ray.streaming.operator.impl.JoinOperator;
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
@@ -9,40 +10,42 @@ import java.io.Serializable;
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <J> Type of the data in the joined stream.
|
||||
* @param <O> Type of the data in the joined stream.
|
||||
*/
|
||||
public class JoinStream<L, R, J> extends DataStream<L> {
|
||||
public class JoinStream<L, R, O> extends DataStream<L> {
|
||||
private final DataStream<R> rightStream;
|
||||
|
||||
public JoinStream(DataStream<L> leftStream, DataStream<R> rightStream) {
|
||||
super(leftStream, null);
|
||||
super(leftStream, new JoinOperator<>());
|
||||
this.rightStream = rightStream;
|
||||
}
|
||||
|
||||
public DataStream<R> getRightStream() {
|
||||
return rightStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply key-by to the left join stream.
|
||||
*/
|
||||
public <K> Where<L, R, J, K> where(KeyFunction<L, K> keyFunction) {
|
||||
public <K> Where<K> where(KeyFunction<L, K> keyFunction) {
|
||||
return new Where<>(this, keyFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Where clause of the join transformation.
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <J> Type of the data in the joined stream.
|
||||
* @param <K> Type of the join key.
|
||||
*/
|
||||
class Where<L, R, J, K> implements Serializable {
|
||||
|
||||
private JoinStream<L, R, J> joinStream;
|
||||
class Where<K> implements Serializable {
|
||||
private JoinStream<L, R, O> joinStream;
|
||||
private KeyFunction<L, K> leftKeyByFunction;
|
||||
|
||||
public Where(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction) {
|
||||
Where(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction) {
|
||||
this.joinStream = joinStream;
|
||||
this.leftKeyByFunction = leftKeyByFunction;
|
||||
}
|
||||
|
||||
public Equal<L, R, J, K> equalLo(KeyFunction<R, K> rightKeyFunction) {
|
||||
public Equal<K> equalTo(KeyFunction<R, K> rightKeyFunction) {
|
||||
return new Equal<>(joinStream, leftKeyByFunction, rightKeyFunction);
|
||||
}
|
||||
}
|
||||
@@ -50,26 +53,25 @@ public class JoinStream<L, R, J> extends DataStream<L> {
|
||||
/**
|
||||
* Equal clause of the join transformation.
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <J> Type of the data in the joined stream.
|
||||
* @param <K> Type of the join key.
|
||||
*/
|
||||
class Equal<L, R, J, K> implements Serializable {
|
||||
|
||||
private JoinStream<L, R, J> joinStream;
|
||||
class Equal<K> implements Serializable {
|
||||
private JoinStream<L, R, O> joinStream;
|
||||
private KeyFunction<L, K> leftKeyByFunction;
|
||||
private KeyFunction<R, K> rightKeyByFunction;
|
||||
|
||||
public Equal(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction,
|
||||
KeyFunction<R, K> rightKeyByFunction) {
|
||||
Equal(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction,
|
||||
KeyFunction<R, K> rightKeyByFunction) {
|
||||
this.joinStream = joinStream;
|
||||
this.leftKeyByFunction = leftKeyByFunction;
|
||||
this.rightKeyByFunction = rightKeyByFunction;
|
||||
}
|
||||
|
||||
public DataStream<J> with(JoinFunction<L, R, J> joinFunction) {
|
||||
return (DataStream<J>) joinStream;
|
||||
@SuppressWarnings("unchecked")
|
||||
public DataStream<O> with(JoinFunction<L, R, O> joinFunction) {
|
||||
JoinOperator joinOperator = (JoinOperator) joinStream.getOperator();
|
||||
joinOperator.setFunction(joinFunction);
|
||||
return (DataStream<O>) joinStream;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+35
-10
@@ -4,7 +4,8 @@ import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.partition.Partition;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.api.partition.impl.ForwardPartition;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
@@ -30,8 +31,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
private Stream originalStream;
|
||||
|
||||
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
|
||||
this(streamingContext, null, streamOperator,
|
||||
selectPartition(streamOperator));
|
||||
this(streamingContext, null, streamOperator, getForwardPartition(streamOperator));
|
||||
}
|
||||
|
||||
public Stream(StreamingContext streamingContext,
|
||||
@@ -42,7 +42,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
|
||||
public Stream(Stream inputStream, StreamOperator streamOperator) {
|
||||
this(inputStream.getStreamingContext(), inputStream, streamOperator,
|
||||
selectPartition(streamOperator));
|
||||
getForwardPartition(streamOperator));
|
||||
}
|
||||
|
||||
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
|
||||
@@ -50,9 +50,9 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
}
|
||||
|
||||
protected Stream(StreamingContext streamingContext,
|
||||
Stream inputStream,
|
||||
StreamOperator streamOperator,
|
||||
Partition<T> partition) {
|
||||
Stream inputStream,
|
||||
StreamOperator streamOperator,
|
||||
Partition<T> partition) {
|
||||
this.streamingContext = streamingContext;
|
||||
this.inputStream = inputStream;
|
||||
this.operator = streamOperator;
|
||||
@@ -73,15 +73,16 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
this.streamingContext = originalStream.getStreamingContext();
|
||||
this.inputStream = originalStream.getInputStream();
|
||||
this.operator = originalStream.getOperator();
|
||||
Preconditions.checkNotNull(operator);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static <T> Partition<T> selectPartition(Operator operator) {
|
||||
private static <T> Partition<T> getForwardPartition(Operator operator) {
|
||||
switch (operator.getLanguage()) {
|
||||
case PYTHON:
|
||||
return (Partition<T>) PythonPartition.RoundRobinPartition;
|
||||
return (Partition<T>) PythonPartition.ForwardPartition;
|
||||
case JAVA:
|
||||
return new RoundRobinPartition<>();
|
||||
return new ForwardPartition<>();
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Unsupported language " + operator.getLanguage());
|
||||
@@ -165,5 +166,29 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
return originalStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set chain strategy for this stream
|
||||
*/
|
||||
public S withChainStrategy(ChainStrategy chainStrategy) {
|
||||
Preconditions.checkArgument(!isProxyStream());
|
||||
operator.setChainStrategy(chainStrategy);
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable chain for this stream
|
||||
*/
|
||||
public S disableChain() {
|
||||
return withChainStrategy(ChainStrategy.NEVER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the partition function of this {@link Stream} so that output elements are forwarded to
|
||||
* next operator locally.
|
||||
*/
|
||||
public S forward() {
|
||||
return setPartition(getForwardPartition(operator));
|
||||
}
|
||||
|
||||
public abstract Language getLanguage();
|
||||
}
|
||||
|
||||
+2
@@ -16,6 +16,8 @@ public class UnionStream<T> extends DataStream<T> {
|
||||
private List<DataStream<T>> unionStreams;
|
||||
|
||||
public UnionStream(DataStream<T> input, List<DataStream<T>> streams) {
|
||||
// Union stream does not create a physical operation, so we don't have to set partition
|
||||
// function for it.
|
||||
super(input, new UnionOperator());
|
||||
this.unionStreams = new ArrayList<>();
|
||||
streams.forEach(this::addStream);
|
||||
|
||||
+57
-18
@@ -5,6 +5,8 @@ import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -17,15 +19,24 @@ public class JobGraph implements Serializable {
|
||||
|
||||
private final String jobName;
|
||||
private final Map<String, String> jobConfig;
|
||||
private List<JobVertex> jobVertexList;
|
||||
private List<JobEdge> jobEdgeList;
|
||||
private List<JobVertex> jobVertices;
|
||||
private List<JobEdge> jobEdges;
|
||||
private String digraph;
|
||||
|
||||
public JobGraph(String jobName, Map<String, String> jobConfig) {
|
||||
this.jobName = jobName;
|
||||
this.jobConfig = jobConfig;
|
||||
this.jobVertexList = new ArrayList<>();
|
||||
this.jobEdgeList = new ArrayList<>();
|
||||
this.jobVertices = new ArrayList<>();
|
||||
this.jobEdges = new ArrayList<>();
|
||||
}
|
||||
|
||||
public JobGraph(String jobName, Map<String, String> jobConfig,
|
||||
List<JobVertex> jobVertices, List<JobEdge> jobEdges) {
|
||||
this.jobName = jobName;
|
||||
this.jobConfig = jobConfig;
|
||||
this.jobVertices = jobVertices;
|
||||
this.jobEdges = jobEdges;
|
||||
generateDigraph();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,12 +47,12 @@ public class JobGraph implements Serializable {
|
||||
*/
|
||||
public String generateDigraph() {
|
||||
StringBuilder digraph = new StringBuilder();
|
||||
digraph.append("digraph ").append(jobName + " ").append(" {");
|
||||
digraph.append("digraph ").append(jobName).append(" ").append(" {");
|
||||
|
||||
for (JobEdge jobEdge : jobEdgeList) {
|
||||
for (JobEdge jobEdge : jobEdges) {
|
||||
String srcNode = null;
|
||||
String targetNode = null;
|
||||
for (JobVertex jobVertex : jobVertexList) {
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
if (jobEdge.getSrcVertexId() == jobVertex.getVertexId()) {
|
||||
srcNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName();
|
||||
} else if (jobEdge.getTargetVertexId() == jobVertex.getVertexId()) {
|
||||
@@ -49,7 +60,7 @@ public class JobGraph implements Serializable {
|
||||
}
|
||||
}
|
||||
digraph.append(System.getProperty("line.separator"));
|
||||
digraph.append(srcNode).append(" -> ").append(targetNode);
|
||||
digraph.append(String.format(" \"%s\" -> \"%s\"", srcNode, targetNode));
|
||||
}
|
||||
digraph.append(System.getProperty("line.separator")).append("}");
|
||||
|
||||
@@ -58,19 +69,47 @@ public class JobGraph implements Serializable {
|
||||
}
|
||||
|
||||
public void addVertex(JobVertex vertex) {
|
||||
this.jobVertexList.add(vertex);
|
||||
this.jobVertices.add(vertex);
|
||||
}
|
||||
|
||||
public void addEdge(JobEdge jobEdge) {
|
||||
this.jobEdgeList.add(jobEdge);
|
||||
this.jobEdges.add(jobEdge);
|
||||
}
|
||||
|
||||
public List<JobVertex> getJobVertexList() {
|
||||
return jobVertexList;
|
||||
public List<JobVertex> getJobVertices() {
|
||||
return jobVertices;
|
||||
}
|
||||
|
||||
public List<JobEdge> getJobEdgeList() {
|
||||
return jobEdgeList;
|
||||
public List<JobVertex> getSourceVertices() {
|
||||
return jobVertices.stream()
|
||||
.filter(v -> v.getVertexType() == VertexType.SOURCE)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<JobVertex> getSinkVertices() {
|
||||
return jobVertices.stream()
|
||||
.filter(v -> v.getVertexType() == VertexType.SINK)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public JobVertex getVertex(int vertexId) {
|
||||
return jobVertices.stream().filter(v -> v.getVertexId() == vertexId).findFirst().get();
|
||||
}
|
||||
|
||||
public List<JobEdge> getJobEdges() {
|
||||
return jobEdges;
|
||||
}
|
||||
|
||||
public Set<JobEdge> getVertexInputEdges(int vertexId) {
|
||||
return jobEdges.stream()
|
||||
.filter(jobEdge -> jobEdge.getTargetVertexId() == vertexId)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public Set<JobEdge> getVertexOutputEdges(int vertexId) {
|
||||
return jobEdges.stream()
|
||||
.filter(jobEdge -> jobEdge.getSrcVertexId() == vertexId)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public String getDigraph() {
|
||||
@@ -90,17 +129,17 @@ public class JobGraph implements Serializable {
|
||||
return;
|
||||
}
|
||||
LOG.info("Printing job graph:");
|
||||
for (JobVertex jobVertex : jobVertexList) {
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
LOG.info(jobVertex.toString());
|
||||
}
|
||||
for (JobEdge jobEdge : jobEdgeList) {
|
||||
for (JobEdge jobEdge : jobEdges) {
|
||||
LOG.info(jobEdge.toString());
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isCrossLanguageGraph() {
|
||||
Language language = jobVertexList.get(0).getLanguage();
|
||||
for (JobVertex jobVertex : jobVertexList) {
|
||||
Language language = jobVertices.get(0).getLanguage();
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
if (jobVertex.getLanguage() != language) {
|
||||
return true;
|
||||
}
|
||||
|
||||
+15
-5
@@ -2,6 +2,7 @@ package io.ray.streaming.jobgraph;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.JoinStream;
|
||||
import io.ray.streaming.api.stream.Stream;
|
||||
import io.ray.streaming.api.stream.StreamSink;
|
||||
import io.ray.streaming.api.stream.StreamSource;
|
||||
@@ -26,7 +27,7 @@ public class JobGraphBuilder {
|
||||
private List<StreamSink> streamSinkList;
|
||||
|
||||
public JobGraphBuilder(List<StreamSink> streamSinkList) {
|
||||
this(streamSinkList, "job-" + System.currentTimeMillis());
|
||||
this(streamSinkList, "job_" + System.currentTimeMillis());
|
||||
}
|
||||
|
||||
public JobGraphBuilder(List<StreamSink> streamSinkList, String jobName) {
|
||||
@@ -61,18 +62,20 @@ public class JobGraphBuilder {
|
||||
"Reference stream should be skipped.");
|
||||
int vertexId = stream.getId();
|
||||
int parallelism = stream.getParallelism();
|
||||
Map<String, String> config = stream.getConfig();
|
||||
JobVertex jobVertex;
|
||||
if (stream instanceof StreamSink) {
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator, config);
|
||||
Stream parentStream = stream.getInputStream();
|
||||
int inputVertexId = parentStream.getId();
|
||||
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
|
||||
this.jobGraph.addEdge(jobEdge);
|
||||
processStream(parentStream);
|
||||
} else if (stream instanceof StreamSource) {
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator);
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator, config);
|
||||
} else if (stream instanceof DataStream || stream instanceof PythonDataStream) {
|
||||
jobVertex = new JobVertex(vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator);
|
||||
jobVertex = new JobVertex(
|
||||
vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator, config);
|
||||
Stream parentStream = stream.getInputStream();
|
||||
int inputVertexId = parentStream.getId();
|
||||
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
|
||||
@@ -92,10 +95,17 @@ public class JobGraphBuilder {
|
||||
this.jobGraph.addEdge(otherEdge);
|
||||
processStream(otherStream);
|
||||
}
|
||||
|
||||
// process join stream
|
||||
if (stream instanceof JoinStream) {
|
||||
DataStream rightStream = ((JoinStream) stream).getRightStream();
|
||||
this.jobGraph.addEdge(
|
||||
new JobEdge(rightStream.getId(), vertexId, rightStream.getPartition()));
|
||||
processStream(rightStream);
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported stream: " + stream);
|
||||
}
|
||||
jobVertex.setConfig(stream.getConfig());
|
||||
this.jobGraph.addVertex(jobVertex);
|
||||
}
|
||||
|
||||
|
||||
+187
@@ -0,0 +1,187 @@
|
||||
package io.ray.streaming.jobgraph;
|
||||
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.partition.Partition;
|
||||
import io.ray.streaming.api.partition.impl.ForwardPartition;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.chain.ChainedOperator;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
import io.ray.streaming.python.PythonOperator.ChainedPythonOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/**
|
||||
* Optimize job graph by chaining some operators so that some operators can be run in the
|
||||
* same thread.
|
||||
*/
|
||||
public class JobGraphOptimizer {
|
||||
private final JobGraph jobGraph;
|
||||
private Set<JobVertex> visited = new HashSet<>();
|
||||
// vertex id -> vertex
|
||||
private Map<Integer, JobVertex> vertexMap;
|
||||
private Map<JobVertex, Set<JobEdge>> outputEdgesMap;
|
||||
// tail vertex id -> mergedVertex
|
||||
private Map<Integer, Pair<JobVertex, List<JobVertex>>> mergedVertexMap;
|
||||
|
||||
public JobGraphOptimizer(JobGraph jobGraph) {
|
||||
this.jobGraph = jobGraph;
|
||||
vertexMap = jobGraph.getJobVertices().stream()
|
||||
.collect(Collectors.toMap(JobVertex::getVertexId, Function.identity()));
|
||||
outputEdgesMap = vertexMap.keySet().stream().collect(Collectors.toMap(
|
||||
id -> vertexMap.get(id), id -> new HashSet<>(jobGraph.getVertexOutputEdges(id))));
|
||||
mergedVertexMap = new HashMap<>();
|
||||
}
|
||||
|
||||
public JobGraph optimize() {
|
||||
// Deep-first traverse nodes from source to sink to merge vertices that can be chained
|
||||
// together.
|
||||
jobGraph.getSourceVertices().forEach(vertex -> {
|
||||
List<JobVertex> verticesToMerge = new ArrayList<>();
|
||||
verticesToMerge.add(vertex);
|
||||
mergeVerticesRecursively(vertex, verticesToMerge);
|
||||
});
|
||||
|
||||
List<JobVertex> vertices = mergedVertexMap.values().stream()
|
||||
.map(Pair::getLeft).collect(Collectors.toList());
|
||||
|
||||
return new JobGraph(jobGraph.getJobName(), jobGraph.getJobConfig(), vertices, createEdges());
|
||||
}
|
||||
|
||||
private void mergeVerticesRecursively(JobVertex vertex, List<JobVertex> verticesToMerge) {
|
||||
if (!visited.contains(vertex)) {
|
||||
visited.add(vertex);
|
||||
Set<JobEdge> outputEdges = outputEdgesMap.get(vertex);
|
||||
if (outputEdges.isEmpty()) {
|
||||
mergeAndAddVertex(verticesToMerge);
|
||||
} else {
|
||||
outputEdges.forEach(edge -> {
|
||||
JobVertex succeedingVertex = vertexMap.get(edge.getTargetVertexId());
|
||||
if (canBeChained(vertex, succeedingVertex, edge)) {
|
||||
verticesToMerge.add(succeedingVertex);
|
||||
mergeVerticesRecursively(succeedingVertex, verticesToMerge);
|
||||
} else {
|
||||
mergeAndAddVertex(verticesToMerge);
|
||||
List<JobVertex> newMergedVertices = new ArrayList<>();
|
||||
newMergedVertices.add(succeedingVertex);
|
||||
mergeVerticesRecursively(succeedingVertex, newMergedVertices);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void mergeAndAddVertex(List<JobVertex> verticesToMerge) {
|
||||
JobVertex mergedVertex;
|
||||
JobVertex headVertex = verticesToMerge.get(0);
|
||||
Language language = headVertex.getLanguage();
|
||||
if (verticesToMerge.size() == 1) {
|
||||
// no chain
|
||||
mergedVertex = headVertex;
|
||||
} else {
|
||||
List<StreamOperator> operators = verticesToMerge.stream()
|
||||
.map(v -> vertexMap.get(v.getVertexId()).getStreamOperator())
|
||||
.collect(Collectors.toList());
|
||||
List<Map<String, String>> configs = verticesToMerge.stream()
|
||||
.map(v -> vertexMap.get(v.getVertexId()).getConfig())
|
||||
.collect(Collectors.toList());
|
||||
StreamOperator operator;
|
||||
if (language == Language.JAVA) {
|
||||
operator = ChainedOperator.newChainedOperator(operators, configs);
|
||||
} else {
|
||||
List<PythonOperator> pythonOperators = operators.stream()
|
||||
.map(o -> (PythonOperator) o).collect(Collectors.toList());
|
||||
operator = new ChainedPythonOperator(pythonOperators, configs);
|
||||
}
|
||||
// chained operator config is placed into `ChainedOperator`.
|
||||
mergedVertex = new JobVertex(headVertex.getVertexId(), headVertex.getParallelism(),
|
||||
headVertex.getVertexType(), operator, new HashMap<>());
|
||||
}
|
||||
|
||||
mergedVertexMap.put(mergedVertex.getVertexId(), Pair.of(mergedVertex, verticesToMerge));
|
||||
}
|
||||
|
||||
private List<JobEdge> createEdges() {
|
||||
List<JobEdge> edges = new ArrayList<>();
|
||||
mergedVertexMap.forEach((id, pair) -> {
|
||||
JobVertex mergedVertex = pair.getLeft();
|
||||
List<JobVertex> mergedVertices = pair.getRight();
|
||||
JobVertex tailVertex = mergedVertices.get(mergedVertices.size() - 1);
|
||||
// input edge will be set up in input vertices
|
||||
if (outputEdgesMap.containsKey(tailVertex)) {
|
||||
outputEdgesMap.get(tailVertex).forEach(edge -> {
|
||||
Pair<JobVertex, List<JobVertex>> downstreamPair =
|
||||
mergedVertexMap.get(edge.getTargetVertexId());
|
||||
// change ForwardPartition to RoundRobinPartition.
|
||||
Partition partition = changePartition(edge.getPartition());
|
||||
JobEdge newEdge = new JobEdge(
|
||||
mergedVertex.getVertexId(),
|
||||
downstreamPair.getLeft().getVertexId(),
|
||||
partition);
|
||||
edges.add(newEdge);
|
||||
});
|
||||
|
||||
}
|
||||
});
|
||||
return edges;
|
||||
}
|
||||
|
||||
/**
|
||||
* Change ForwardPartition to RoundRobinPartition.
|
||||
*/
|
||||
private Partition changePartition(Partition partition) {
|
||||
if (partition instanceof PythonPartition) {
|
||||
PythonPartition pythonPartition = (PythonPartition) partition;
|
||||
if (!pythonPartition.isConstructedFromBinary() &&
|
||||
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS)) {
|
||||
return PythonPartition.RoundRobinPartition;
|
||||
} else {
|
||||
return partition;
|
||||
}
|
||||
} else {
|
||||
if (partition instanceof ForwardPartition) {
|
||||
return new RoundRobinPartition();
|
||||
} else {
|
||||
return partition;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean canBeChained(JobVertex precedingVertex,
|
||||
JobVertex succeedingVertex,
|
||||
JobEdge edge) {
|
||||
if (jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 ||
|
||||
jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) {
|
||||
return false;
|
||||
}
|
||||
if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) {
|
||||
return false;
|
||||
}
|
||||
if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|
||||
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|
||||
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) {
|
||||
return false;
|
||||
}
|
||||
if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) {
|
||||
return false;
|
||||
}
|
||||
Partition partition = edge.getPartition();
|
||||
if (!(partition instanceof PythonPartition)) {
|
||||
return partition instanceof ForwardPartition;
|
||||
} else {
|
||||
PythonPartition pythonPartition = (PythonPartition) partition;
|
||||
return !pythonPartition.isConstructedFromBinary() &&
|
||||
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import java.util.Map;
|
||||
* Job vertex is a cell node where logic is executed.
|
||||
*/
|
||||
public class JobVertex implements Serializable {
|
||||
|
||||
private int vertexId;
|
||||
private int parallelism;
|
||||
private VertexType vertexType;
|
||||
@@ -20,12 +21,14 @@ public class JobVertex implements Serializable {
|
||||
public JobVertex(int vertexId,
|
||||
int parallelism,
|
||||
VertexType vertexType,
|
||||
StreamOperator streamOperator) {
|
||||
StreamOperator streamOperator,
|
||||
Map<String, String> config) {
|
||||
this.vertexId = vertexId;
|
||||
this.parallelism = parallelism;
|
||||
this.vertexType = vertexType;
|
||||
this.streamOperator = streamOperator;
|
||||
this.language = streamOperator.getLanguage();
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
public int getVertexId() {
|
||||
@@ -67,4 +70,5 @@ public class JobVertex implements Serializable {
|
||||
.add("config", config)
|
||||
.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+20
@@ -0,0 +1,20 @@
|
||||
package io.ray.streaming.operator;
|
||||
|
||||
/**
|
||||
* Chain strategy for streaming operators. Chained operators are run in the same thread.
|
||||
*/
|
||||
public enum ChainStrategy {
|
||||
/**
|
||||
* The operator won't be chained with preceding operators, but maybe chained with succeeding
|
||||
* operators.
|
||||
*/
|
||||
HEAD,
|
||||
/**
|
||||
* Operators will be chained together when possible.
|
||||
*/
|
||||
ALWAYS,
|
||||
/**
|
||||
* The operator won't be chained with any operator.
|
||||
*/
|
||||
NEVER
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import java.util.List;
|
||||
|
||||
public interface Operator extends Serializable {
|
||||
|
||||
String getName();
|
||||
|
||||
void open(List<Collector> collectors, RuntimeContext runtimeContext);
|
||||
|
||||
void finish();
|
||||
@@ -20,4 +22,7 @@ public interface Operator extends Serializable {
|
||||
Language getLanguage();
|
||||
|
||||
OperatorType getOpType();
|
||||
|
||||
ChainStrategy getChainStrategy();
|
||||
|
||||
}
|
||||
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
package io.ray.streaming.operator;
|
||||
|
||||
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
|
||||
|
||||
public interface SourceOperator<T> extends Operator {
|
||||
|
||||
void run();
|
||||
|
||||
SourceContext<T> getSourceContext();
|
||||
|
||||
default OperatorType getOpType() {
|
||||
return OperatorType.SOURCE;
|
||||
}
|
||||
}
|
||||
+22
-3
@@ -12,13 +12,22 @@ import java.util.List;
|
||||
|
||||
public abstract class StreamOperator<F extends Function> implements Operator {
|
||||
protected final String name;
|
||||
protected final F function;
|
||||
protected final RichFunction richFunction;
|
||||
protected F function;
|
||||
protected RichFunction richFunction;
|
||||
protected List<Collector> collectorList;
|
||||
protected RuntimeContext runtimeContext;
|
||||
private ChainStrategy chainStrategy = ChainStrategy.ALWAYS;
|
||||
|
||||
public StreamOperator(F function) {
|
||||
protected StreamOperator() {
|
||||
this.name = getClass().getSimpleName();
|
||||
}
|
||||
|
||||
protected StreamOperator(F function) {
|
||||
this();
|
||||
setFunction(function);
|
||||
}
|
||||
|
||||
public void setFunction(F function) {
|
||||
this.function = function;
|
||||
this.richFunction = Functions.wrap(function);
|
||||
}
|
||||
@@ -62,7 +71,17 @@ public abstract class StreamOperator<F extends Function> implements Operator {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setChainStrategy(ChainStrategy chainStrategy) {
|
||||
this.chainStrategy = chainStrategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChainStrategy getChainStrategy() {
|
||||
return chainStrategy;
|
||||
}
|
||||
}
|
||||
|
||||
+169
@@ -0,0 +1,169 @@
|
||||
package io.ray.streaming.operator.chain;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.collector.Collector;
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.Function;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.TwoInputOperator;
|
||||
import java.lang.reflect.Proxy;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Abstract base class for chained operators.
|
||||
*/
|
||||
public abstract class ChainedOperator extends StreamOperator<Function> {
|
||||
protected final List<StreamOperator> operators;
|
||||
protected final Operator headOperator;
|
||||
protected final Operator tailOperator;
|
||||
private final List<Map<String, String>> configs;
|
||||
|
||||
public ChainedOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
Preconditions.checkArgument(operators.size() >= 2,
|
||||
"Need at lease two operators to be chained together");
|
||||
operators.stream().skip(1)
|
||||
.forEach(operator -> Preconditions.checkArgument(operator instanceof OneInputOperator));
|
||||
this.operators = operators;
|
||||
this.configs = configs;
|
||||
this.headOperator = operators.get(0);
|
||||
this.tailOperator = operators.get(operators.size() - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
|
||||
// Dont' call super.open() as we `open` every operator separately.
|
||||
List<ForwardCollector> succeedingCollectors = operators.stream().skip(1)
|
||||
.map(operator -> new ForwardCollector((OneInputOperator) operator))
|
||||
.collect(Collectors.toList());
|
||||
for (int i = 0; i < operators.size() - 1; i++) {
|
||||
StreamOperator operator = operators.get(i);
|
||||
List<ForwardCollector> forwardCollectors =
|
||||
Collections.singletonList(succeedingCollectors.get(i));
|
||||
operator.open(forwardCollectors, createRuntimeContext(runtimeContext, i));
|
||||
}
|
||||
// tail operator send data to downstream using provided collectors.
|
||||
tailOperator.open(collectorList, createRuntimeContext(runtimeContext, operators.size() - 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return headOperator.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return headOperator.getLanguage();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return operators.stream().map(Operator::getName)
|
||||
.collect(Collectors.joining(" -> ", "[", "]"));
|
||||
}
|
||||
|
||||
public List<StreamOperator> getOperators() {
|
||||
return operators;
|
||||
}
|
||||
|
||||
public Operator getHeadOperator() {
|
||||
return headOperator;
|
||||
}
|
||||
|
||||
public Operator getTailOperator() {
|
||||
return tailOperator;
|
||||
}
|
||||
|
||||
private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) {
|
||||
return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(),
|
||||
new Class[] {RuntimeContext.class},
|
||||
(proxy, method, methodArgs) -> {
|
||||
if (method.getName().equals("getConfig")) {
|
||||
return configs.get(index);
|
||||
} else {
|
||||
return method.invoke(runtimeContext, methodArgs);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public static ChainedOperator newChainedOperator(
|
||||
List<StreamOperator> operators,
|
||||
List<Map<String, String>> configs) {
|
||||
switch (operators.get(0).getOpType()) {
|
||||
case SOURCE:
|
||||
return new ChainedSourceOperator(operators, configs);
|
||||
case ONE_INPUT:
|
||||
return new ChainedOneInputOperator(operators, configs);
|
||||
case TWO_INPUT:
|
||||
return new ChainedTwoInputOperator(operators, configs);
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
"Unsupported operator type " + operators.get(0).getOpType());
|
||||
}
|
||||
}
|
||||
|
||||
static class ChainedSourceOperator<T> extends ChainedOperator
|
||||
implements SourceOperator<T> {
|
||||
private final SourceOperator<T> sourceOperator;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ChainedSourceOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
super(operators, configs);
|
||||
sourceOperator = (SourceOperator<T>) headOperator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
sourceOperator.run();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SourceContext<T> getSourceContext() {
|
||||
return sourceOperator.getSourceContext();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static class ChainedOneInputOperator<T> extends ChainedOperator
|
||||
implements OneInputOperator<T> {
|
||||
private final OneInputOperator<T> inputOperator;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ChainedOneInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
super(operators, configs);
|
||||
inputOperator = (OneInputOperator<T>) headOperator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processElement(Record<T> record) throws Exception {
|
||||
inputOperator.processElement(record);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static class ChainedTwoInputOperator<L, R> extends ChainedOperator
|
||||
implements TwoInputOperator<L, R> {
|
||||
private final TwoInputOperator<L, R> inputOperator;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ChainedTwoInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
|
||||
super(operators, configs);
|
||||
inputOperator = (TwoInputOperator<L, R>) headOperator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processElement(Record<L> record1, Record<R> record2) {
|
||||
inputOperator.processElement(record1, record2);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
package io.ray.streaming.operator.chain;
|
||||
|
||||
import io.ray.streaming.api.collector.Collector;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
|
||||
class ForwardCollector implements Collector<Record> {
|
||||
private final OneInputOperator succeedingOperator;
|
||||
|
||||
ForwardCollector(OneInputOperator succeedingOperator) {
|
||||
this.succeedingOperator = succeedingOperator;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public void collect(Record record) {
|
||||
try {
|
||||
succeedingOperator.processElement(record);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
package io.ray.streaming.operator.impl;
|
||||
|
||||
import io.ray.streaming.api.function.impl.JoinFunction;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.TwoInputOperator;
|
||||
|
||||
/**
|
||||
* Join operator
|
||||
*
|
||||
* @param <L> Type of the data in the left stream.
|
||||
* @param <R> Type of the data in the right stream.
|
||||
* @param <K> Type of the data in the join key.
|
||||
* @param <O> Type of the data in the joined stream.
|
||||
*/
|
||||
public class JoinOperator<L, R, K, O> extends StreamOperator<JoinFunction<L, R, O>> implements
|
||||
TwoInputOperator<L, R> {
|
||||
public JoinOperator() {
|
||||
|
||||
}
|
||||
|
||||
public JoinOperator(JoinFunction<L, R, O> function) {
|
||||
super(function);
|
||||
setChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void processElement(Record<L> record1, Record<R> record2) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return OperatorType.TWO_INPUT;
|
||||
}
|
||||
|
||||
}
|
||||
+3
@@ -5,6 +5,7 @@ import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.impl.ReduceFunction;
|
||||
import io.ray.streaming.message.KeyRecord;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.util.HashMap;
|
||||
@@ -18,6 +19,7 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
|
||||
|
||||
public ReduceOperator(ReduceFunction<T> reduceFunction) {
|
||||
super(reduceFunction);
|
||||
setChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -41,4 +43,5 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
|
||||
collect(record);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+14
-4
@@ -5,15 +5,19 @@ import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction;
|
||||
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.util.List;
|
||||
|
||||
public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
public class SourceOperatorImpl<T> extends StreamOperator<SourceFunction<T>>
|
||||
implements SourceOperator {
|
||||
private SourceContextImpl sourceContext;
|
||||
|
||||
public SourceOperator(SourceFunction<T> function) {
|
||||
public SourceOperatorImpl(SourceFunction<T> function) {
|
||||
super(function);
|
||||
setChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -23,6 +27,7 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
this.function.run(this.sourceContext);
|
||||
@@ -31,13 +36,17 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SourceContext getSourceContext() {
|
||||
return sourceContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return OperatorType.SOURCE;
|
||||
}
|
||||
|
||||
class SourceContextImpl implements SourceContext<T> {
|
||||
|
||||
private List<Collector> collectors;
|
||||
|
||||
public SourceContextImpl(List<Collector> collectors) {
|
||||
@@ -47,9 +56,10 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
|
||||
@Override
|
||||
public void collect(T t) throws Exception {
|
||||
for (Collector collector : collectors) {
|
||||
collector.collect(new Record(t));
|
||||
collector.collect(new Record<>(t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -100,6 +100,14 @@ public class PythonFunction implements Function {
|
||||
return functionInterface;
|
||||
}
|
||||
|
||||
public String toSimpleString() {
|
||||
if (function != null) {
|
||||
return "binary function";
|
||||
} else {
|
||||
return String.format("%s-%s.%s", functionInterface, moduleName, functionName);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringJoiner stringJoiner = new StringJoiner(", ",
|
||||
|
||||
+114
-24
@@ -1,11 +1,16 @@
|
||||
package io.ray.streaming.python;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.Function;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.StringJoiner;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Represents a {@link StreamOperator} that wraps python {@link PythonFunction}.
|
||||
@@ -27,30 +32,6 @@ public class PythonOperator extends StreamOperator {
|
||||
this.className = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List list, RuntimeContext runtimeContext) {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return Language.PYTHON;
|
||||
@@ -64,6 +45,48 @@ public class PythonOperator extends StreamOperator {
|
||||
return className;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List list, RuntimeContext runtimeContext) {
|
||||
throwUnsupportedException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
throwUnsupportedException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
throwUnsupportedException();
|
||||
}
|
||||
|
||||
void throwUnsupportedException() {
|
||||
StackTraceElement[] trace = Thread.currentThread().getStackTrace();
|
||||
Preconditions.checkState(trace.length >= 2);
|
||||
StackTraceElement traceElement = trace[2];
|
||||
String msg = String.format("Method %s.%s shouldn't be called.",
|
||||
traceElement.getClassName(), traceElement.getMethodName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
|
||||
throw new UnsupportedOperationException(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append(PythonOperator.class.getSimpleName()).append("[");
|
||||
if (function != null) {
|
||||
builder.append(((PythonFunction)function).toSimpleString());
|
||||
} else {
|
||||
builder.append(moduleName).append(".").append(className);
|
||||
}
|
||||
return builder.append("]").toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringJoiner stringJoiner = new StringJoiner(", ",
|
||||
@@ -76,4 +99,71 @@ public class PythonOperator extends StreamOperator {
|
||||
}
|
||||
return stringJoiner.toString();
|
||||
}
|
||||
|
||||
public static class ChainedPythonOperator extends PythonOperator {
|
||||
private final List<PythonOperator> operators;
|
||||
private final PythonOperator headOperator;
|
||||
private final PythonOperator tailOperator;
|
||||
private final List<Map<String, String>> configs;
|
||||
|
||||
public ChainedPythonOperator(
|
||||
List<PythonOperator> operators, List<Map<String, String>> configs) {
|
||||
super(null);
|
||||
Preconditions.checkArgument(!operators.isEmpty());
|
||||
this.operators = operators;
|
||||
this.configs = configs;
|
||||
this.headOperator = operators.get(0);
|
||||
this.tailOperator = operators.get(operators.size() - 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OperatorType getOpType() {
|
||||
return headOperator.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return Language.PYTHON;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return operators.stream().map(Operator::getName)
|
||||
.collect(Collectors.joining(" -> ", "[", "]"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
throwUnsupportedException();
|
||||
return null; // impossible
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
throwUnsupportedException();
|
||||
return null; // impossible
|
||||
}
|
||||
|
||||
@Override
|
||||
public Function getFunction() {
|
||||
throwUnsupportedException();
|
||||
return null; // impossible
|
||||
}
|
||||
|
||||
public List<PythonOperator> getOperators() {
|
||||
return operators;
|
||||
}
|
||||
|
||||
public PythonOperator getHeadOperator() {
|
||||
return headOperator;
|
||||
}
|
||||
|
||||
public PythonOperator getTailOperator() {
|
||||
return tailOperator;
|
||||
}
|
||||
|
||||
public List<Map<String, String>> getConfigs() {
|
||||
return configs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+7
@@ -24,6 +24,9 @@ public class PythonPartition implements Partition<Object> {
|
||||
"ray.streaming.partition", "KeyPartition");
|
||||
public static final PythonPartition RoundRobinPartition = new PythonPartition(
|
||||
"ray.streaming.partition", "RoundRobinPartition");
|
||||
public static final String FORWARD_PARTITION_CLASS = "ForwardPartition";
|
||||
public static final PythonPartition ForwardPartition = new PythonPartition(
|
||||
"ray.streaming.partition", FORWARD_PARTITION_CLASS);
|
||||
|
||||
private byte[] partition;
|
||||
private String moduleName;
|
||||
@@ -66,6 +69,10 @@ public class PythonPartition implements Partition<Object> {
|
||||
return functionName;
|
||||
}
|
||||
|
||||
public boolean isConstructedFromBinary() {
|
||||
return partition != null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringJoiner stringJoiner = new StringJoiner(", ",
|
||||
|
||||
+4
-1
@@ -2,6 +2,7 @@ package io.ray.streaming.python.stream;
|
||||
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.KeyDataStream;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import io.ray.streaming.python.PythonFunction.FunctionInterface;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
@@ -37,7 +38,9 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
|
||||
*/
|
||||
public PythonDataStream reduce(PythonFunction func) {
|
||||
func.setFunctionInterface(FunctionInterface.REDUCE_FUNCTION);
|
||||
return new PythonDataStream(this, new PythonOperator(func));
|
||||
PythonDataStream stream = new PythonDataStream(this, new PythonOperator(func));
|
||||
stream.withChainStrategy(ChainStrategy.HEAD);
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+3
-3
@@ -2,10 +2,10 @@ package io.ray.streaming.python.stream;
|
||||
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.stream.StreamSource;
|
||||
import io.ray.streaming.operator.ChainStrategy;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import io.ray.streaming.python.PythonFunction.FunctionInterface;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
|
||||
/**
|
||||
* Represents a source of the PythonStream.
|
||||
@@ -13,8 +13,8 @@ import io.ray.streaming.python.PythonPartition;
|
||||
public class PythonStreamSource extends PythonDataStream implements StreamSource {
|
||||
|
||||
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
|
||||
super(streamingContext, new PythonOperator(sourceFunction),
|
||||
PythonPartition.RoundRobinPartition);
|
||||
super(streamingContext, new PythonOperator(sourceFunction));
|
||||
withChainStrategy(ChainStrategy.HEAD);
|
||||
}
|
||||
|
||||
public static PythonStreamSource from(StreamingContext streamingContext,
|
||||
|
||||
+2
@@ -14,6 +14,8 @@ public class PythonUnionStream extends PythonDataStream {
|
||||
private List<PythonDataStream> unionStreams;
|
||||
|
||||
public PythonUnionStream(PythonDataStream input, List<PythonDataStream> others) {
|
||||
// Union stream does not create a physical operation, so we don't have to set partition
|
||||
// function for it.
|
||||
super(input, new PythonOperator(
|
||||
"ray.streaming.operator", "UnionOperator"));
|
||||
this.unionStreams = new ArrayList<>();
|
||||
|
||||
+10
-10
@@ -2,8 +2,8 @@ package io.ray.streaming.jobgraph;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.partition.impl.ForwardPartition;
|
||||
import io.ray.streaming.api.partition.impl.KeyPartition;
|
||||
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.DataStreamSource;
|
||||
import io.ray.streaming.api.stream.StreamSink;
|
||||
@@ -20,14 +20,14 @@ public class JobGraphBuilderTest {
|
||||
@Test
|
||||
public void testDataSync() {
|
||||
JobGraph jobGraph = buildDataSyncJobGraph();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
|
||||
|
||||
Assert.assertEquals(jobVertexList.size(), 2);
|
||||
Assert.assertEquals(jobEdgeList.size(), 1);
|
||||
|
||||
JobEdge jobEdge = jobEdgeList.get(0);
|
||||
Assert.assertEquals(jobEdge.getPartition().getClass(), RoundRobinPartition.class);
|
||||
Assert.assertEquals(jobEdge.getPartition().getClass(), ForwardPartition.class);
|
||||
|
||||
JobVertex sinkVertex = jobVertexList.get(1);
|
||||
JobVertex sourceVertex = jobVertexList.get(0);
|
||||
@@ -50,8 +50,8 @@ public class JobGraphBuilderTest {
|
||||
@Test
|
||||
public void testKeyByJobGraph() {
|
||||
JobGraph jobGraph = buildKeyByJobGraph();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
|
||||
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
|
||||
|
||||
Assert.assertEquals(jobVertexList.size(), 3);
|
||||
Assert.assertEquals(jobEdgeList.size(), 2);
|
||||
@@ -68,7 +68,7 @@ public class JobGraphBuilderTest {
|
||||
JobEdge source2KeyBy = jobEdgeList.get(1);
|
||||
|
||||
Assert.assertEquals(keyBy2Sink.getPartition().getClass(), KeyPartition.class);
|
||||
Assert.assertEquals(source2KeyBy.getPartition().getClass(), RoundRobinPartition.class);
|
||||
Assert.assertEquals(source2KeyBy.getPartition().getClass(), ForwardPartition.class);
|
||||
}
|
||||
|
||||
public JobGraph buildKeyByJobGraph() {
|
||||
@@ -88,8 +88,8 @@ public class JobGraphBuilderTest {
|
||||
JobGraph jobGraph = buildKeyByJobGraph();
|
||||
jobGraph.generateDigraph();
|
||||
String diGraph = jobGraph.getDigraph();
|
||||
System.out.println(diGraph);
|
||||
Assert.assertTrue(diGraph.contains("1-SourceOperator -> 2-KeyByOperator"));
|
||||
Assert.assertTrue(diGraph.contains("2-KeyByOperator -> 3-SinkOperator"));
|
||||
LOG.info(diGraph);
|
||||
Assert.assertTrue(diGraph.contains("\"1-SourceOperatorImpl\" -> \"2-KeyByOperator\""));
|
||||
Assert.assertTrue(diGraph.contains("\"2-KeyByOperator\" -> \"3-SinkOperator\""));
|
||||
}
|
||||
}
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
package io.ray.streaming.jobgraph;
|
||||
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import io.ray.streaming.api.context.StreamingContext;
|
||||
import io.ray.streaming.api.stream.DataStream;
|
||||
import io.ray.streaming.api.stream.DataStreamSource;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class JobGraphOptimizerTest {
|
||||
private static final Logger LOG = LoggerFactory.getLogger( JobGraphOptimizerTest.class );
|
||||
|
||||
@Test
|
||||
public void testOptimize() {
|
||||
StreamingContext context = StreamingContext.buildContext();
|
||||
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList(1 ,2 ,3));
|
||||
DataStream<String> source2 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList("1", "2", "3"));
|
||||
DataStream<String> source3 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList("2", "3", "4"));
|
||||
source1.filter(x -> x > 1)
|
||||
.map(String::valueOf)
|
||||
.union(source2)
|
||||
.join(source3)
|
||||
.sink(x -> System.out.println("Sink " + x));
|
||||
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
|
||||
LOG.info("Digraph {}", jobGraph.generateDigraph());
|
||||
assertEquals(jobGraph.getJobVertices().size(), 8);
|
||||
|
||||
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
|
||||
JobGraph optimizedJobGraph = graphOptimizer.optimize();
|
||||
optimizedJobGraph.printJobGraph();
|
||||
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
|
||||
assertEquals(optimizedJobGraph.getJobVertices().size(), 5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOptimizeHybridStream() {
|
||||
StreamingContext context = StreamingContext.buildContext();
|
||||
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList(1 ,2 ,3));
|
||||
DataStream<String> source2 = DataStreamSource.fromCollection(context,
|
||||
Lists.newArrayList("1", "2", "3"));
|
||||
source1.asPythonStream()
|
||||
.map(pyFunc(1))
|
||||
.filter(pyFunc(2))
|
||||
.union(source2.asPythonStream().filter(pyFunc(3)).map(pyFunc(4)))
|
||||
.asJavaStream()
|
||||
.sink(x -> System.out.println("Sink " + x));
|
||||
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
|
||||
LOG.info("Digraph {}", jobGraph.generateDigraph());
|
||||
assertEquals(jobGraph.getJobVertices().size(), 8);
|
||||
|
||||
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
|
||||
JobGraph optimizedJobGraph = graphOptimizer.optimize();
|
||||
optimizedJobGraph.printJobGraph();
|
||||
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
|
||||
assertEquals(optimizedJobGraph.getJobVertices().size(), 6);
|
||||
}
|
||||
|
||||
private PythonFunction pyFunc(int number) {
|
||||
return new PythonFunction("module", "func" + number);
|
||||
}
|
||||
|
||||
}
|
||||
+1
-1
@@ -2,9 +2,9 @@ package io.ray.streaming.runtime.core.processor;
|
||||
|
||||
import io.ray.streaming.operator.OneInputOperator;
|
||||
import io.ray.streaming.operator.OperatorType;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.operator.TwoInputOperator;
|
||||
import io.ray.streaming.operator.impl.SourceOperator;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
||||
+1
-1
@@ -1,7 +1,7 @@
|
||||
package io.ray.streaming.runtime.core.processor;
|
||||
|
||||
import io.ray.streaming.message.Record;
|
||||
import io.ray.streaming.operator.impl.SourceOperator;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
|
||||
/**
|
||||
* The processor for the stream sources, containing a SourceOperator.
|
||||
|
||||
+5
-5
@@ -30,7 +30,7 @@ public class GraphManagerImpl implements GraphManager {
|
||||
ExecutionGraph executionGraph = setupStructure(jobGraph);
|
||||
|
||||
// set max parallelism
|
||||
int maxParallelism = jobGraph.getJobVertexList().stream()
|
||||
int maxParallelism = jobGraph.getJobVertices().stream()
|
||||
.map(JobVertex::getParallelism)
|
||||
.max(Integer::compareTo).get();
|
||||
executionGraph.setMaxParallelism(maxParallelism);
|
||||
@@ -49,7 +49,7 @@ public class GraphManagerImpl implements GraphManager {
|
||||
// create vertex
|
||||
Map<Integer, ExecutionJobVertex> exeJobVertexMap = new LinkedHashMap<>();
|
||||
long buildTime = executionGraph.getBuildTime();
|
||||
for (JobVertex jobVertex : jobGraph.getJobVertexList()) {
|
||||
for (JobVertex jobVertex : jobGraph.getJobVertices()) {
|
||||
int jobVertexId = jobVertex.getVertexId();
|
||||
exeJobVertexMap.put(jobVertexId,
|
||||
new ExecutionJobVertex(
|
||||
@@ -60,7 +60,7 @@ public class GraphManagerImpl implements GraphManager {
|
||||
}
|
||||
|
||||
// connect vertex
|
||||
jobGraph.getJobEdgeList().stream().forEach(jobEdge -> {
|
||||
jobGraph.getJobEdges().forEach(jobEdge -> {
|
||||
ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId());
|
||||
ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId());
|
||||
|
||||
@@ -70,8 +70,8 @@ public class GraphManagerImpl implements GraphManager {
|
||||
source.getOutputEdges().add(executionJobEdge);
|
||||
target.getInputEdges().add(executionJobEdge);
|
||||
|
||||
source.getExecutionVertices().stream().forEach(vertex -> {
|
||||
target.getExecutionVertices().stream().forEach(outputVertex -> {
|
||||
source.getExecutionVertices().forEach(vertex -> {
|
||||
target.getExecutionVertices().forEach(outputVertex -> {
|
||||
ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge);
|
||||
vertex.getOutputEdges().add(executionEdge);
|
||||
outputVertex.getInputEdges().add(executionEdge);
|
||||
|
||||
+26
-6
@@ -7,6 +7,7 @@ import io.ray.streaming.api.partition.Partition;
|
||||
import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.python.PythonFunction;
|
||||
import io.ray.streaming.python.PythonOperator;
|
||||
import io.ray.streaming.python.PythonOperator.ChainedPythonOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge;
|
||||
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
|
||||
@@ -77,6 +78,7 @@ public class GraphPbBuilder {
|
||||
executionVertexBuilder.setOperator(
|
||||
ByteString.copyFrom(
|
||||
serializeOperator(executionVertex.getStreamOperator())));
|
||||
executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator()));
|
||||
executionVertexBuilder.setWorkerActor(
|
||||
ByteString.copyFrom(
|
||||
((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes()));
|
||||
@@ -104,17 +106,35 @@ public class GraphPbBuilder {
|
||||
|
||||
private byte[] serializeOperator(Operator operator) {
|
||||
if (operator instanceof PythonOperator) {
|
||||
PythonOperator pythonOperator = (PythonOperator) operator;
|
||||
return serializer.serialize(Arrays.asList(
|
||||
serializeFunction(pythonOperator.getFunction()),
|
||||
pythonOperator.getModuleName(),
|
||||
pythonOperator.getClassName()
|
||||
));
|
||||
if (isPythonChainedOperator(operator)) {
|
||||
return serializePythonChainedOperator((ChainedPythonOperator) operator);
|
||||
} else {
|
||||
PythonOperator pythonOperator = (PythonOperator) operator;
|
||||
return serializer.serialize(Arrays.asList(
|
||||
serializeFunction(pythonOperator.getFunction()),
|
||||
pythonOperator.getModuleName(),
|
||||
pythonOperator.getClassName()
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return new byte[0];
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isPythonChainedOperator(Operator operator) {
|
||||
return operator instanceof ChainedPythonOperator;
|
||||
}
|
||||
|
||||
private byte[] serializePythonChainedOperator(ChainedPythonOperator operator) {
|
||||
List<byte[]> serializedOperators = operator.getOperators().stream()
|
||||
.map(this::serializeOperator).collect(Collectors.toList());
|
||||
return serializer.serialize(Arrays.asList(
|
||||
serializedOperators,
|
||||
operator.getConfigs()
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
private byte[] serializeFunction(Function function) {
|
||||
if (function instanceof PythonFunction) {
|
||||
PythonFunction pyFunc = (PythonFunction) function;
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
package io.ray.streaming.runtime.worker.tasks;
|
||||
|
||||
import io.ray.streaming.operator.impl.SourceOperator;
|
||||
import io.ray.streaming.operator.SourceOperator;
|
||||
import io.ray.streaming.runtime.core.processor.Processor;
|
||||
import io.ray.streaming.runtime.core.processor.SourceProcessor;
|
||||
import io.ray.streaming.runtime.worker.JobWorker;
|
||||
|
||||
+5
-4
@@ -7,6 +7,7 @@ import io.ray.streaming.api.stream.DataStreamSource;
|
||||
import io.ray.streaming.api.stream.StreamSink;
|
||||
import io.ray.streaming.jobgraph.JobGraph;
|
||||
import io.ray.streaming.jobgraph.JobGraphBuilder;
|
||||
import io.ray.streaming.jobgraph.JobVertex;
|
||||
import io.ray.streaming.runtime.BaseUnitTest;
|
||||
import io.ray.streaming.runtime.config.StreamingConfig;
|
||||
import io.ray.streaming.runtime.config.master.ResourceConfig;
|
||||
@@ -40,10 +41,10 @@ public class ExecutionGraphTest extends BaseUnitTest {
|
||||
ExecutionGraph executionGraph = buildExecutionGraph(graphManager, jobGraph);
|
||||
List<ExecutionJobVertex> executionJobVertices = executionGraph.getExecutionJobVertexList();
|
||||
|
||||
Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertexList().size());
|
||||
Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertices().size());
|
||||
|
||||
int totalVertexNum = jobGraph.getJobVertexList().stream()
|
||||
.mapToInt(vertex -> vertex.getParallelism()).sum();
|
||||
int totalVertexNum = jobGraph.getJobVertices().stream()
|
||||
.mapToInt(JobVertex::getParallelism).sum();
|
||||
Assert.assertEquals(executionGraph.getAllExecutionVertices().size(), totalVertexNum);
|
||||
Assert.assertEquals(executionGraph.getAllExecutionVertices().size(),
|
||||
executionGraph.getExecutionVertexIdGenerator().get());
|
||||
@@ -66,7 +67,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
|
||||
List<ExecutionVertex> downStreamVertices = downStream.getExecutionVertices();
|
||||
upStreamVertices.forEach(vertex -> {
|
||||
Assert.assertEquals(vertex.getResource().get(ResourceType.CPU.name()), 2.0);
|
||||
vertex.getOutputEdges().stream().forEach(upStreamOutPutEdge -> {
|
||||
vertex.getOutputEdges().forEach(upStreamOutPutEdge -> {
|
||||
Assert.assertTrue(downStreamVertices.contains(upStreamOutPutEdge.getTargetExecutionVertex()));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -39,7 +39,7 @@ fi
|
||||
if [ $exit_code -ne 2 ] && [ $exit_code -ne 0 ] ; then
|
||||
if [ -d "/tmp/ray_streaming_java_test_output/" ] ; then
|
||||
echo "all test output"
|
||||
for f in /tmp/ray_streaming_java_test_output/*; do
|
||||
for f in /tmp/ray_streaming_java_test_output/*.{log,xml}; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "Cat file $f"
|
||||
cat "$f"
|
||||
|
||||
@@ -94,6 +94,18 @@ class Stream(ABC):
|
||||
def get_language(self):
|
||||
pass
|
||||
|
||||
def forward(self):
|
||||
"""Set the partition function of this {@link Stream} so that output
|
||||
elements are forwarded to next operator locally."""
|
||||
self._gateway_client().call_method(self._j_stream, "forward")
|
||||
return self
|
||||
|
||||
def disable_chain(self):
|
||||
"""Disable chain for this stream so that it will be run in a separate
|
||||
task."""
|
||||
self._gateway_client().call_method(self._j_stream, "disableChain")
|
||||
return self
|
||||
|
||||
def _gateway_client(self):
|
||||
return self.get_streaming_context()._gateway_client
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import enum
|
||||
import importlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import streaming
|
||||
from ray.streaming import function
|
||||
from ray.streaming import message
|
||||
from ray.streaming.collector import Collector
|
||||
from ray.streaming.runtime import gateway_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperatorType(enum.Enum):
|
||||
SOURCE = 0 # Sources are where your program reads its input from
|
||||
@@ -227,15 +231,93 @@ class UnionOperator(StreamOperator, OneInputOperator):
|
||||
self.collect(record)
|
||||
|
||||
|
||||
_function_to_operator = {
|
||||
function.SourceFunction: SourceOperator,
|
||||
function.MapFunction: MapOperator,
|
||||
function.FlatMapFunction: FlatMapOperator,
|
||||
function.FilterFunction: FilterOperator,
|
||||
function.KeyFunction: KeyByOperator,
|
||||
function.ReduceFunction: ReduceOperator,
|
||||
function.SinkFunction: SinkOperator,
|
||||
}
|
||||
class ChainedOperator(StreamOperator, ABC):
|
||||
class ForwardCollector(Collector):
|
||||
def __init__(self, succeeding_operator):
|
||||
self.succeeding_operator = succeeding_operator
|
||||
|
||||
def collect(self, record):
|
||||
self.succeeding_operator.process_element(record)
|
||||
|
||||
def __init__(self, operators, configs):
|
||||
super().__init__(operators[0].func)
|
||||
self.operators = operators
|
||||
self.configs = configs
|
||||
|
||||
def open(self, collectors, runtime_context):
|
||||
# Dont' call super.open() as we `open` every operator separately.
|
||||
num_operators = len(self.operators)
|
||||
succeeding_collectors = [
|
||||
ChainedOperator.ForwardCollector(operator)
|
||||
for operator in self.operators[1:]
|
||||
]
|
||||
for i in range(0, num_operators - 1):
|
||||
forward_collectors = [succeeding_collectors[i]]
|
||||
self.operators[i].open(
|
||||
forward_collectors,
|
||||
self.__create_runtime_context(runtime_context, i))
|
||||
self.operators[-1].open(
|
||||
collectors,
|
||||
self.__create_runtime_context(runtime_context, num_operators - 1))
|
||||
|
||||
def operator_type(self) -> OperatorType:
|
||||
return self.operators[0].operator_type()
|
||||
|
||||
def __create_runtime_context(self, runtime_context, index):
|
||||
def get_config():
|
||||
return self.configs[index]
|
||||
|
||||
runtime_context.get_config = get_config
|
||||
return runtime_context
|
||||
|
||||
@staticmethod
|
||||
def new_chained_operator(operators, configs):
|
||||
operator_type = operators[0].operator_type()
|
||||
logger.info(
|
||||
"Building ChainedOperator from operators {} and configs {}."
|
||||
.format(operators, configs))
|
||||
if operator_type == OperatorType.SOURCE:
|
||||
return ChainedSourceOperator(operators, configs)
|
||||
elif operator_type == OperatorType.ONE_INPUT:
|
||||
return ChainedOneInputOperator(operators, configs)
|
||||
elif operator_type == OperatorType.TWO_INPUT:
|
||||
return ChainedTwoInputOperator(operators, configs)
|
||||
else:
|
||||
raise Exception("Current operator type is not supported")
|
||||
|
||||
|
||||
class ChainedSourceOperator(ChainedOperator):
|
||||
def __init__(self, operators, configs):
|
||||
super().__init__(operators, configs)
|
||||
|
||||
def run(self):
|
||||
self.operators[0].run()
|
||||
|
||||
|
||||
class ChainedOneInputOperator(ChainedOperator):
|
||||
def __init__(self, operators, configs):
|
||||
super().__init__(operators, configs)
|
||||
|
||||
def process_element(self, record):
|
||||
self.operators[0].process_element(record)
|
||||
|
||||
|
||||
class ChainedTwoInputOperator(ChainedOperator):
|
||||
def __init__(self, operators, configs):
|
||||
super().__init__(operators, configs)
|
||||
|
||||
def process_element(self, record1, record2):
|
||||
self.operators[0].process_element(record1, record2)
|
||||
|
||||
|
||||
def load_chained_operator(chained_operator_bytes: bytes):
|
||||
"""Load chained operator from serialized operators and configs"""
|
||||
serialized_operators, configs = gateway_client.deserialize(
|
||||
chained_operator_bytes)
|
||||
operators = [
|
||||
load_operator(desc_bytes) for desc_bytes in serialized_operators
|
||||
]
|
||||
return ChainedOperator.new_chained_operator(operators, configs)
|
||||
|
||||
|
||||
def load_operator(descriptor_operator_bytes: bytes):
|
||||
@@ -267,6 +349,17 @@ def load_operator(descriptor_operator_bytes: bytes):
|
||||
return cls()
|
||||
|
||||
|
||||
_function_to_operator = {
|
||||
function.SourceFunction: SourceOperator,
|
||||
function.MapFunction: MapOperator,
|
||||
function.FlatMapFunction: FlatMapOperator,
|
||||
function.FilterFunction: FilterOperator,
|
||||
function.KeyFunction: KeyByOperator,
|
||||
function.ReduceFunction: ReduceOperator,
|
||||
function.SinkFunction: SinkOperator,
|
||||
}
|
||||
|
||||
|
||||
def create_operator_with_func(func: function.Function):
|
||||
"""Create an operator according to a :class:`function.Function`
|
||||
|
||||
|
||||
@@ -60,6 +60,17 @@ class RoundRobinPartition(Partition):
|
||||
return self.__partitions
|
||||
|
||||
|
||||
class ForwardPartition(Partition):
|
||||
"""Default partition for operator if the operator can be chained with
|
||||
succeeding operators."""
|
||||
|
||||
def __init__(self):
|
||||
self.__partitions = [0]
|
||||
|
||||
def partition(self, key_record, num_partition: int):
|
||||
return self.__partitions
|
||||
|
||||
|
||||
class SimplePartition(Partition):
|
||||
"""Wrap a python function as subclass of :class:`Partition`"""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import enum
|
||||
import logging
|
||||
|
||||
import ray
|
||||
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
|
||||
@@ -6,6 +7,8 @@ import ray.streaming.operator as operator
|
||||
import ray.streaming.partition as partition
|
||||
from ray.streaming.generated.streaming_pb2 import Language
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeType(enum.Enum):
|
||||
"""
|
||||
@@ -43,7 +46,13 @@ class ExecutionVertex:
|
||||
self.parallelism = vertex_pb.parallelism
|
||||
if vertex_pb.language == Language.PYTHON:
|
||||
operator_bytes = vertex_pb.operator # python operator descriptor
|
||||
self.stream_operator = operator.load_operator(operator_bytes)
|
||||
if vertex_pb.chained:
|
||||
logger.info("Load chained operator")
|
||||
self.stream_operator = operator.load_chained_operator(
|
||||
operator_bytes)
|
||||
else:
|
||||
logger.info("Load operator")
|
||||
self.stream_operator = operator.load_operator(operator_bytes)
|
||||
self.worker_actor = ray.actor.ActorHandle. \
|
||||
_deserialization_helper(vertex_pb.worker_actor)
|
||||
self.container_id = vertex_pb.container_id
|
||||
|
||||
@@ -30,12 +30,13 @@ message ExecutionVertexContext {
|
||||
int32 parallelism = 5;
|
||||
// serialized operator
|
||||
bytes operator = 6;
|
||||
bytes worker_actor = 7;
|
||||
string container_id = 8;
|
||||
uint64 build_time = 9;
|
||||
Language language = 10;
|
||||
map<string, string> config = 11;
|
||||
map<string, double> resource = 12;
|
||||
bool chained = 7;
|
||||
bytes worker_actor = 8;
|
||||
string container_id = 9;
|
||||
uint64 build_time = 10;
|
||||
Language language = 11;
|
||||
map<string, string> config = 12;
|
||||
map<string, double> resource = 13;
|
||||
}
|
||||
|
||||
// vertices
|
||||
|
||||
Reference in New Issue
Block a user