[Streaming] operator chain (#8910)

This commit is contained in:
chaokunyang
2020-06-18 15:11:07 +08:00
committed by GitHub
parent 003cec87b4
commit 5edddf6eac
39 changed files with 1058 additions and 140 deletions
@@ -6,6 +6,7 @@ import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.client.JobClient;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.jobgraph.JobGraphOptimizer;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
@@ -56,7 +57,8 @@ public class StreamingContext implements Serializable {
*/
public void execute(String jobName) {
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName);
this.jobGraph = jobGraphBuilder.build();
JobGraph originalJobGraph = jobGraphBuilder.build();
this.jobGraph = new JobGraphOptimizer(originalJobGraph).optimize();
jobGraph.printJobGraph();
LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph());
@@ -0,0 +1,19 @@
package io.ray.streaming.api.partition.impl;
import io.ray.streaming.api.partition.Partition;
/**
* Default partition for operator if the operator can be chained with succeeding operators.
* Partition will be set to {@link RoundRobinPartition} if the operator can't be chiained with
* succeeding operators.
*
* @param <T> Type of the input record.
*/
public class ForwardPartition<T> implements Partition<T> {
private int[] partitions = new int[] {0};
@Override
public int[] partition(T record, int numPartition) {
return partitions;
}
}
@@ -3,8 +3,7 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.SourceFunction;
import io.ray.streaming.api.function.internal.CollectionSourceFunction;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.impl.SourceOperator;
import io.ray.streaming.operator.impl.SourceOperatorImpl;
import java.util.Collection;
/**
@@ -15,7 +14,7 @@ import java.util.Collection;
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
super(streamingContext, new SourceOperatorImpl<>(sourceFunction));
}
public static <T> DataStreamSource<T> fromSource(
@@ -2,6 +2,7 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.function.impl.JoinFunction;
import io.ray.streaming.api.function.impl.KeyFunction;
import io.ray.streaming.operator.impl.JoinOperator;
import java.io.Serializable;
/**
@@ -9,40 +10,42 @@ import java.io.Serializable;
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <J> Type of the data in the joined stream.
* @param <O> Type of the data in the joined stream.
*/
public class JoinStream<L, R, J> extends DataStream<L> {
public class JoinStream<L, R, O> extends DataStream<L> {
private final DataStream<R> rightStream;
public JoinStream(DataStream<L> leftStream, DataStream<R> rightStream) {
super(leftStream, null);
super(leftStream, new JoinOperator<>());
this.rightStream = rightStream;
}
public DataStream<R> getRightStream() {
return rightStream;
}
/**
* Apply key-by to the left join stream.
*/
public <K> Where<L, R, J, K> where(KeyFunction<L, K> keyFunction) {
public <K> Where<K> where(KeyFunction<L, K> keyFunction) {
return new Where<>(this, keyFunction);
}
/**
* Where clause of the join transformation.
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <J> Type of the data in the joined stream.
* @param <K> Type of the join key.
*/
class Where<L, R, J, K> implements Serializable {
private JoinStream<L, R, J> joinStream;
class Where<K> implements Serializable {
private JoinStream<L, R, O> joinStream;
private KeyFunction<L, K> leftKeyByFunction;
public Where(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction) {
Where(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction) {
this.joinStream = joinStream;
this.leftKeyByFunction = leftKeyByFunction;
}
public Equal<L, R, J, K> equalLo(KeyFunction<R, K> rightKeyFunction) {
public Equal<K> equalTo(KeyFunction<R, K> rightKeyFunction) {
return new Equal<>(joinStream, leftKeyByFunction, rightKeyFunction);
}
}
@@ -50,26 +53,25 @@ public class JoinStream<L, R, J> extends DataStream<L> {
/**
* Equal clause of the join transformation.
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <J> Type of the data in the joined stream.
* @param <K> Type of the join key.
*/
class Equal<L, R, J, K> implements Serializable {
private JoinStream<L, R, J> joinStream;
class Equal<K> implements Serializable {
private JoinStream<L, R, O> joinStream;
private KeyFunction<L, K> leftKeyByFunction;
private KeyFunction<R, K> rightKeyByFunction;
public Equal(JoinStream<L, R, J> joinStream, KeyFunction<L, K> leftKeyByFunction,
KeyFunction<R, K> rightKeyByFunction) {
Equal(JoinStream<L, R, O> joinStream, KeyFunction<L, K> leftKeyByFunction,
KeyFunction<R, K> rightKeyByFunction) {
this.joinStream = joinStream;
this.leftKeyByFunction = leftKeyByFunction;
this.rightKeyByFunction = rightKeyByFunction;
}
public DataStream<J> with(JoinFunction<L, R, J> joinFunction) {
return (DataStream<J>) joinStream;
@SuppressWarnings("unchecked")
public DataStream<O> with(JoinFunction<L, R, O> joinFunction) {
JoinOperator joinOperator = (JoinOperator) joinStream.getOperator();
joinOperator.setFunction(joinFunction);
return (DataStream<O>) joinStream;
}
}
@@ -4,7 +4,8 @@ import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonPartition;
@@ -30,8 +31,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
private Stream originalStream;
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
this(streamingContext, null, streamOperator,
selectPartition(streamOperator));
this(streamingContext, null, streamOperator, getForwardPartition(streamOperator));
}
public Stream(StreamingContext streamingContext,
@@ -42,7 +42,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
public Stream(Stream inputStream, StreamOperator streamOperator) {
this(inputStream.getStreamingContext(), inputStream, streamOperator,
selectPartition(streamOperator));
getForwardPartition(streamOperator));
}
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
@@ -50,9 +50,9 @@ public abstract class Stream<S extends Stream<S, T>, T>
}
protected Stream(StreamingContext streamingContext,
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
this.streamingContext = streamingContext;
this.inputStream = inputStream;
this.operator = streamOperator;
@@ -73,15 +73,16 @@ public abstract class Stream<S extends Stream<S, T>, T>
this.streamingContext = originalStream.getStreamingContext();
this.inputStream = originalStream.getInputStream();
this.operator = originalStream.getOperator();
Preconditions.checkNotNull(operator);
}
@SuppressWarnings("unchecked")
private static <T> Partition<T> selectPartition(Operator operator) {
private static <T> Partition<T> getForwardPartition(Operator operator) {
switch (operator.getLanguage()) {
case PYTHON:
return (Partition<T>) PythonPartition.RoundRobinPartition;
return (Partition<T>) PythonPartition.ForwardPartition;
case JAVA:
return new RoundRobinPartition<>();
return new ForwardPartition<>();
default:
throw new UnsupportedOperationException(
"Unsupported language " + operator.getLanguage());
@@ -165,5 +166,29 @@ public abstract class Stream<S extends Stream<S, T>, T>
return originalStream;
}
/**
* Set chain strategy for this stream
*/
public S withChainStrategy(ChainStrategy chainStrategy) {
Preconditions.checkArgument(!isProxyStream());
operator.setChainStrategy(chainStrategy);
return self();
}
/**
* Disable chain for this stream
*/
public S disableChain() {
return withChainStrategy(ChainStrategy.NEVER);
}
/**
* Set the partition function of this {@link Stream} so that output elements are forwarded to
* next operator locally.
*/
public S forward() {
return setPartition(getForwardPartition(operator));
}
public abstract Language getLanguage();
}
@@ -16,6 +16,8 @@ public class UnionStream<T> extends DataStream<T> {
private List<DataStream<T>> unionStreams;
public UnionStream(DataStream<T> input, List<DataStream<T>> streams) {
// Union stream does not create a physical operation, so we don't have to set partition
// function for it.
super(input, new UnionOperator());
this.unionStreams = new ArrayList<>();
streams.forEach(this::addStream);
@@ -5,6 +5,8 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -17,15 +19,24 @@ public class JobGraph implements Serializable {
private final String jobName;
private final Map<String, String> jobConfig;
private List<JobVertex> jobVertexList;
private List<JobEdge> jobEdgeList;
private List<JobVertex> jobVertices;
private List<JobEdge> jobEdges;
private String digraph;
public JobGraph(String jobName, Map<String, String> jobConfig) {
this.jobName = jobName;
this.jobConfig = jobConfig;
this.jobVertexList = new ArrayList<>();
this.jobEdgeList = new ArrayList<>();
this.jobVertices = new ArrayList<>();
this.jobEdges = new ArrayList<>();
}
public JobGraph(String jobName, Map<String, String> jobConfig,
List<JobVertex> jobVertices, List<JobEdge> jobEdges) {
this.jobName = jobName;
this.jobConfig = jobConfig;
this.jobVertices = jobVertices;
this.jobEdges = jobEdges;
generateDigraph();
}
/**
@@ -36,12 +47,12 @@ public class JobGraph implements Serializable {
*/
public String generateDigraph() {
StringBuilder digraph = new StringBuilder();
digraph.append("digraph ").append(jobName + " ").append(" {");
digraph.append("digraph ").append(jobName).append(" ").append(" {");
for (JobEdge jobEdge : jobEdgeList) {
for (JobEdge jobEdge : jobEdges) {
String srcNode = null;
String targetNode = null;
for (JobVertex jobVertex : jobVertexList) {
for (JobVertex jobVertex : jobVertices) {
if (jobEdge.getSrcVertexId() == jobVertex.getVertexId()) {
srcNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName();
} else if (jobEdge.getTargetVertexId() == jobVertex.getVertexId()) {
@@ -49,7 +60,7 @@ public class JobGraph implements Serializable {
}
}
digraph.append(System.getProperty("line.separator"));
digraph.append(srcNode).append(" -> ").append(targetNode);
digraph.append(String.format(" \"%s\" -> \"%s\"", srcNode, targetNode));
}
digraph.append(System.getProperty("line.separator")).append("}");
@@ -58,19 +69,47 @@ public class JobGraph implements Serializable {
}
public void addVertex(JobVertex vertex) {
this.jobVertexList.add(vertex);
this.jobVertices.add(vertex);
}
public void addEdge(JobEdge jobEdge) {
this.jobEdgeList.add(jobEdge);
this.jobEdges.add(jobEdge);
}
public List<JobVertex> getJobVertexList() {
return jobVertexList;
public List<JobVertex> getJobVertices() {
return jobVertices;
}
public List<JobEdge> getJobEdgeList() {
return jobEdgeList;
public List<JobVertex> getSourceVertices() {
return jobVertices.stream()
.filter(v -> v.getVertexType() == VertexType.SOURCE)
.collect(Collectors.toList());
}
public List<JobVertex> getSinkVertices() {
return jobVertices.stream()
.filter(v -> v.getVertexType() == VertexType.SINK)
.collect(Collectors.toList());
}
public JobVertex getVertex(int vertexId) {
return jobVertices.stream().filter(v -> v.getVertexId() == vertexId).findFirst().get();
}
public List<JobEdge> getJobEdges() {
return jobEdges;
}
public Set<JobEdge> getVertexInputEdges(int vertexId) {
return jobEdges.stream()
.filter(jobEdge -> jobEdge.getTargetVertexId() == vertexId)
.collect(Collectors.toSet());
}
public Set<JobEdge> getVertexOutputEdges(int vertexId) {
return jobEdges.stream()
.filter(jobEdge -> jobEdge.getSrcVertexId() == vertexId)
.collect(Collectors.toSet());
}
public String getDigraph() {
@@ -90,17 +129,17 @@ public class JobGraph implements Serializable {
return;
}
LOG.info("Printing job graph:");
for (JobVertex jobVertex : jobVertexList) {
for (JobVertex jobVertex : jobVertices) {
LOG.info(jobVertex.toString());
}
for (JobEdge jobEdge : jobEdgeList) {
for (JobEdge jobEdge : jobEdges) {
LOG.info(jobEdge.toString());
}
}
public boolean isCrossLanguageGraph() {
Language language = jobVertexList.get(0).getLanguage();
for (JobVertex jobVertex : jobVertexList) {
Language language = jobVertices.get(0).getLanguage();
for (JobVertex jobVertex : jobVertices) {
if (jobVertex.getLanguage() != language) {
return true;
}
@@ -2,6 +2,7 @@ package io.ray.streaming.jobgraph;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.JoinStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.api.stream.StreamSource;
@@ -26,7 +27,7 @@ public class JobGraphBuilder {
private List<StreamSink> streamSinkList;
public JobGraphBuilder(List<StreamSink> streamSinkList) {
this(streamSinkList, "job-" + System.currentTimeMillis());
this(streamSinkList, "job_" + System.currentTimeMillis());
}
public JobGraphBuilder(List<StreamSink> streamSinkList, String jobName) {
@@ -61,18 +62,20 @@ public class JobGraphBuilder {
"Reference stream should be skipped.");
int vertexId = stream.getId();
int parallelism = stream.getParallelism();
Map<String, String> config = stream.getConfig();
JobVertex jobVertex;
if (stream instanceof StreamSink) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator, config);
Stream parentStream = stream.getInputStream();
int inputVertexId = parentStream.getId();
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
this.jobGraph.addEdge(jobEdge);
processStream(parentStream);
} else if (stream instanceof StreamSource) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator);
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator, config);
} else if (stream instanceof DataStream || stream instanceof PythonDataStream) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator);
jobVertex = new JobVertex(
vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator, config);
Stream parentStream = stream.getInputStream();
int inputVertexId = parentStream.getId();
JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition());
@@ -92,10 +95,17 @@ public class JobGraphBuilder {
this.jobGraph.addEdge(otherEdge);
processStream(otherStream);
}
// process join stream
if (stream instanceof JoinStream) {
DataStream rightStream = ((JoinStream) stream).getRightStream();
this.jobGraph.addEdge(
new JobEdge(rightStream.getId(), vertexId, rightStream.getPartition()));
processStream(rightStream);
}
} else {
throw new UnsupportedOperationException("Unsupported stream: " + stream);
}
jobVertex.setConfig(stream.getConfig());
this.jobGraph.addVertex(jobVertex);
}
@@ -0,0 +1,187 @@
package io.ray.streaming.jobgraph;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.chain.ChainedOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonOperator.ChainedPythonOperator;
import io.ray.streaming.python.PythonPartition;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
/**
* Optimize job graph by chaining some operators so that some operators can be run in the
* same thread.
*/
public class JobGraphOptimizer {
private final JobGraph jobGraph;
private Set<JobVertex> visited = new HashSet<>();
// vertex id -> vertex
private Map<Integer, JobVertex> vertexMap;
private Map<JobVertex, Set<JobEdge>> outputEdgesMap;
// tail vertex id -> mergedVertex
private Map<Integer, Pair<JobVertex, List<JobVertex>>> mergedVertexMap;
public JobGraphOptimizer(JobGraph jobGraph) {
this.jobGraph = jobGraph;
vertexMap = jobGraph.getJobVertices().stream()
.collect(Collectors.toMap(JobVertex::getVertexId, Function.identity()));
outputEdgesMap = vertexMap.keySet().stream().collect(Collectors.toMap(
id -> vertexMap.get(id), id -> new HashSet<>(jobGraph.getVertexOutputEdges(id))));
mergedVertexMap = new HashMap<>();
}
public JobGraph optimize() {
// Deep-first traverse nodes from source to sink to merge vertices that can be chained
// together.
jobGraph.getSourceVertices().forEach(vertex -> {
List<JobVertex> verticesToMerge = new ArrayList<>();
verticesToMerge.add(vertex);
mergeVerticesRecursively(vertex, verticesToMerge);
});
List<JobVertex> vertices = mergedVertexMap.values().stream()
.map(Pair::getLeft).collect(Collectors.toList());
return new JobGraph(jobGraph.getJobName(), jobGraph.getJobConfig(), vertices, createEdges());
}
private void mergeVerticesRecursively(JobVertex vertex, List<JobVertex> verticesToMerge) {
if (!visited.contains(vertex)) {
visited.add(vertex);
Set<JobEdge> outputEdges = outputEdgesMap.get(vertex);
if (outputEdges.isEmpty()) {
mergeAndAddVertex(verticesToMerge);
} else {
outputEdges.forEach(edge -> {
JobVertex succeedingVertex = vertexMap.get(edge.getTargetVertexId());
if (canBeChained(vertex, succeedingVertex, edge)) {
verticesToMerge.add(succeedingVertex);
mergeVerticesRecursively(succeedingVertex, verticesToMerge);
} else {
mergeAndAddVertex(verticesToMerge);
List<JobVertex> newMergedVertices = new ArrayList<>();
newMergedVertices.add(succeedingVertex);
mergeVerticesRecursively(succeedingVertex, newMergedVertices);
}
});
}
}
}
private void mergeAndAddVertex(List<JobVertex> verticesToMerge) {
JobVertex mergedVertex;
JobVertex headVertex = verticesToMerge.get(0);
Language language = headVertex.getLanguage();
if (verticesToMerge.size() == 1) {
// no chain
mergedVertex = headVertex;
} else {
List<StreamOperator> operators = verticesToMerge.stream()
.map(v -> vertexMap.get(v.getVertexId()).getStreamOperator())
.collect(Collectors.toList());
List<Map<String, String>> configs = verticesToMerge.stream()
.map(v -> vertexMap.get(v.getVertexId()).getConfig())
.collect(Collectors.toList());
StreamOperator operator;
if (language == Language.JAVA) {
operator = ChainedOperator.newChainedOperator(operators, configs);
} else {
List<PythonOperator> pythonOperators = operators.stream()
.map(o -> (PythonOperator) o).collect(Collectors.toList());
operator = new ChainedPythonOperator(pythonOperators, configs);
}
// chained operator config is placed into `ChainedOperator`.
mergedVertex = new JobVertex(headVertex.getVertexId(), headVertex.getParallelism(),
headVertex.getVertexType(), operator, new HashMap<>());
}
mergedVertexMap.put(mergedVertex.getVertexId(), Pair.of(mergedVertex, verticesToMerge));
}
private List<JobEdge> createEdges() {
List<JobEdge> edges = new ArrayList<>();
mergedVertexMap.forEach((id, pair) -> {
JobVertex mergedVertex = pair.getLeft();
List<JobVertex> mergedVertices = pair.getRight();
JobVertex tailVertex = mergedVertices.get(mergedVertices.size() - 1);
// input edge will be set up in input vertices
if (outputEdgesMap.containsKey(tailVertex)) {
outputEdgesMap.get(tailVertex).forEach(edge -> {
Pair<JobVertex, List<JobVertex>> downstreamPair =
mergedVertexMap.get(edge.getTargetVertexId());
// change ForwardPartition to RoundRobinPartition.
Partition partition = changePartition(edge.getPartition());
JobEdge newEdge = new JobEdge(
mergedVertex.getVertexId(),
downstreamPair.getLeft().getVertexId(),
partition);
edges.add(newEdge);
});
}
});
return edges;
}
/**
* Change ForwardPartition to RoundRobinPartition.
*/
private Partition changePartition(Partition partition) {
if (partition instanceof PythonPartition) {
PythonPartition pythonPartition = (PythonPartition) partition;
if (!pythonPartition.isConstructedFromBinary() &&
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS)) {
return PythonPartition.RoundRobinPartition;
} else {
return partition;
}
} else {
if (partition instanceof ForwardPartition) {
return new RoundRobinPartition();
} else {
return partition;
}
}
}
private boolean canBeChained(JobVertex precedingVertex,
JobVertex succeedingVertex,
JobEdge edge) {
if (jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 ||
jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) {
return false;
}
if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) {
return false;
}
if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER
|| succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) {
return false;
}
if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) {
return false;
}
Partition partition = edge.getPartition();
if (!(partition instanceof PythonPartition)) {
return partition instanceof ForwardPartition;
} else {
PythonPartition pythonPartition = (PythonPartition) partition;
return !pythonPartition.isConstructedFromBinary() &&
pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS);
}
}
}
@@ -10,6 +10,7 @@ import java.util.Map;
* Job vertex is a cell node where logic is executed.
*/
public class JobVertex implements Serializable {
private int vertexId;
private int parallelism;
private VertexType vertexType;
@@ -20,12 +21,14 @@ public class JobVertex implements Serializable {
public JobVertex(int vertexId,
int parallelism,
VertexType vertexType,
StreamOperator streamOperator) {
StreamOperator streamOperator,
Map<String, String> config) {
this.vertexId = vertexId;
this.parallelism = parallelism;
this.vertexType = vertexType;
this.streamOperator = streamOperator;
this.language = streamOperator.getLanguage();
this.config = config;
}
public int getVertexId() {
@@ -67,4 +70,5 @@ public class JobVertex implements Serializable {
.add("config", config)
.toString();
}
}
@@ -0,0 +1,20 @@
package io.ray.streaming.operator;
/**
* Chain strategy for streaming operators. Chained operators are run in the same thread.
*/
public enum ChainStrategy {
/**
* The operator won't be chained with preceding operators, but maybe chained with succeeding
* operators.
*/
HEAD,
/**
* Operators will be chained together when possible.
*/
ALWAYS,
/**
* The operator won't be chained with any operator.
*/
NEVER
}
@@ -9,6 +9,8 @@ import java.util.List;
public interface Operator extends Serializable {
String getName();
void open(List<Collector> collectors, RuntimeContext runtimeContext);
void finish();
@@ -20,4 +22,7 @@ public interface Operator extends Serializable {
Language getLanguage();
OperatorType getOpType();
ChainStrategy getChainStrategy();
}
@@ -0,0 +1,14 @@
package io.ray.streaming.operator;
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
public interface SourceOperator<T> extends Operator {
void run();
SourceContext<T> getSourceContext();
default OperatorType getOpType() {
return OperatorType.SOURCE;
}
}
@@ -12,13 +12,22 @@ import java.util.List;
public abstract class StreamOperator<F extends Function> implements Operator {
protected final String name;
protected final F function;
protected final RichFunction richFunction;
protected F function;
protected RichFunction richFunction;
protected List<Collector> collectorList;
protected RuntimeContext runtimeContext;
private ChainStrategy chainStrategy = ChainStrategy.ALWAYS;
public StreamOperator(F function) {
protected StreamOperator() {
this.name = getClass().getSimpleName();
}
protected StreamOperator(F function) {
this();
setFunction(function);
}
public void setFunction(F function) {
this.function = function;
this.richFunction = Functions.wrap(function);
}
@@ -62,7 +71,17 @@ public abstract class StreamOperator<F extends Function> implements Operator {
}
}
@Override
public String getName() {
return name;
}
public void setChainStrategy(ChainStrategy chainStrategy) {
this.chainStrategy = chainStrategy;
}
@Override
public ChainStrategy getChainStrategy() {
return chainStrategy;
}
}
@@ -0,0 +1,169 @@
package io.ray.streaming.operator.chain;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.OneInputOperator;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.TwoInputOperator;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Abstract base class for chained operators.
*/
public abstract class ChainedOperator extends StreamOperator<Function> {
protected final List<StreamOperator> operators;
protected final Operator headOperator;
protected final Operator tailOperator;
private final List<Map<String, String>> configs;
public ChainedOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
Preconditions.checkArgument(operators.size() >= 2,
"Need at lease two operators to be chained together");
operators.stream().skip(1)
.forEach(operator -> Preconditions.checkArgument(operator instanceof OneInputOperator));
this.operators = operators;
this.configs = configs;
this.headOperator = operators.get(0);
this.tailOperator = operators.get(operators.size() - 1);
}
@Override
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
// Dont' call super.open() as we `open` every operator separately.
List<ForwardCollector> succeedingCollectors = operators.stream().skip(1)
.map(operator -> new ForwardCollector((OneInputOperator) operator))
.collect(Collectors.toList());
for (int i = 0; i < operators.size() - 1; i++) {
StreamOperator operator = operators.get(i);
List<ForwardCollector> forwardCollectors =
Collections.singletonList(succeedingCollectors.get(i));
operator.open(forwardCollectors, createRuntimeContext(runtimeContext, i));
}
// tail operator send data to downstream using provided collectors.
tailOperator.open(collectorList, createRuntimeContext(runtimeContext, operators.size() - 1));
}
@Override
public OperatorType getOpType() {
return headOperator.getOpType();
}
@Override
public Language getLanguage() {
return headOperator.getLanguage();
}
@Override
public String getName() {
return operators.stream().map(Operator::getName)
.collect(Collectors.joining(" -> ", "[", "]"));
}
public List<StreamOperator> getOperators() {
return operators;
}
public Operator getHeadOperator() {
return headOperator;
}
public Operator getTailOperator() {
return tailOperator;
}
private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) {
return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(),
new Class[] {RuntimeContext.class},
(proxy, method, methodArgs) -> {
if (method.getName().equals("getConfig")) {
return configs.get(index);
} else {
return method.invoke(runtimeContext, methodArgs);
}
});
}
public static ChainedOperator newChainedOperator(
List<StreamOperator> operators,
List<Map<String, String>> configs) {
switch (operators.get(0).getOpType()) {
case SOURCE:
return new ChainedSourceOperator(operators, configs);
case ONE_INPUT:
return new ChainedOneInputOperator(operators, configs);
case TWO_INPUT:
return new ChainedTwoInputOperator(operators, configs);
default:
throw new IllegalArgumentException(
"Unsupported operator type " + operators.get(0).getOpType());
}
}
static class ChainedSourceOperator<T> extends ChainedOperator
implements SourceOperator<T> {
private final SourceOperator<T> sourceOperator;
@SuppressWarnings("unchecked")
ChainedSourceOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
super(operators, configs);
sourceOperator = (SourceOperator<T>) headOperator;
}
@Override
public void run() {
sourceOperator.run();
}
@Override
public SourceContext<T> getSourceContext() {
return sourceOperator.getSourceContext();
}
}
static class ChainedOneInputOperator<T> extends ChainedOperator
implements OneInputOperator<T> {
private final OneInputOperator<T> inputOperator;
@SuppressWarnings("unchecked")
ChainedOneInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
super(operators, configs);
inputOperator = (OneInputOperator<T>) headOperator;
}
@Override
public void processElement(Record<T> record) throws Exception {
inputOperator.processElement(record);
}
}
static class ChainedTwoInputOperator<L, R> extends ChainedOperator
implements TwoInputOperator<L, R> {
private final TwoInputOperator<L, R> inputOperator;
@SuppressWarnings("unchecked")
ChainedTwoInputOperator(List<StreamOperator> operators, List<Map<String, String>> configs) {
super(operators, configs);
inputOperator = (TwoInputOperator<L, R>) headOperator;
}
@Override
public void processElement(Record<L> record1, Record<R> record2) {
inputOperator.processElement(record1, record2);
}
}
}
@@ -0,0 +1,23 @@
package io.ray.streaming.operator.chain;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.OneInputOperator;
class ForwardCollector implements Collector<Record> {
private final OneInputOperator succeedingOperator;
ForwardCollector(OneInputOperator succeedingOperator) {
this.succeedingOperator = succeedingOperator;
}
@SuppressWarnings("unchecked")
@Override
public void collect(Record record) {
try {
succeedingOperator.processElement(record);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
@@ -0,0 +1,39 @@
package io.ray.streaming.operator.impl;
import io.ray.streaming.api.function.impl.JoinFunction;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.TwoInputOperator;
/**
* Join operator
*
* @param <L> Type of the data in the left stream.
* @param <R> Type of the data in the right stream.
* @param <K> Type of the data in the join key.
* @param <O> Type of the data in the joined stream.
*/
public class JoinOperator<L, R, K, O> extends StreamOperator<JoinFunction<L, R, O>> implements
TwoInputOperator<L, R> {
public JoinOperator() {
}
public JoinOperator(JoinFunction<L, R, O> function) {
super(function);
setChainStrategy(ChainStrategy.HEAD);
}
@Override
public void processElement(Record<L> record1, Record<R> record2) {
}
@Override
public OperatorType getOpType() {
return OperatorType.TWO_INPUT;
}
}
@@ -5,6 +5,7 @@ import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.OneInputOperator;
import io.ray.streaming.operator.StreamOperator;
import java.util.HashMap;
@@ -18,6 +19,7 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
public ReduceOperator(ReduceFunction<T> reduceFunction) {
super(reduceFunction);
setChainStrategy(ChainStrategy.HEAD);
}
@Override
@@ -41,4 +43,5 @@ public class ReduceOperator<K, T> extends StreamOperator<ReduceFunction<T>> impl
collect(record);
}
}
}
@@ -5,15 +5,19 @@ import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.impl.SourceFunction;
import io.ray.streaming.api.function.impl.SourceFunction.SourceContext;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.operator.StreamOperator;
import java.util.List;
public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
public class SourceOperatorImpl<T> extends StreamOperator<SourceFunction<T>>
implements SourceOperator {
private SourceContextImpl sourceContext;
public SourceOperator(SourceFunction<T> function) {
public SourceOperatorImpl(SourceFunction<T> function) {
super(function);
setChainStrategy(ChainStrategy.HEAD);
}
@Override
@@ -23,6 +27,7 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex());
}
@Override
public void run() {
try {
this.function.run(this.sourceContext);
@@ -31,13 +36,17 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
}
}
@Override
public SourceContext getSourceContext() {
return sourceContext;
}
@Override
public OperatorType getOpType() {
return OperatorType.SOURCE;
}
class SourceContextImpl implements SourceContext<T> {
private List<Collector> collectors;
public SourceContextImpl(List<Collector> collectors) {
@@ -47,9 +56,10 @@ public class SourceOperator<T> extends StreamOperator<SourceFunction<T>> {
@Override
public void collect(T t) throws Exception {
for (Collector collector : collectors) {
collector.collect(new Record(t));
collector.collect(new Record<>(t));
}
}
}
}
@@ -100,6 +100,14 @@ public class PythonFunction implements Function {
return functionInterface;
}
public String toSimpleString() {
if (function != null) {
return "binary function";
} else {
return String.format("%s-%s.%s", functionInterface, moduleName, functionName);
}
}
@Override
public String toString() {
StringJoiner stringJoiner = new StringJoiner(", ",
@@ -1,11 +1,16 @@
package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.StreamOperator;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.stream.Collectors;
/**
* Represents a {@link StreamOperator} that wraps python {@link PythonFunction}.
@@ -27,30 +32,6 @@ public class PythonOperator extends StreamOperator {
this.className = null;
}
@Override
public void open(List list, RuntimeContext runtimeContext) {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public void finish() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public void close() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public OperatorType getOpType() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public Language getLanguage() {
return Language.PYTHON;
@@ -64,6 +45,48 @@ public class PythonOperator extends StreamOperator {
return className;
}
@Override
public void open(List list, RuntimeContext runtimeContext) {
throwUnsupportedException();
}
@Override
public void finish() {
throwUnsupportedException();
}
@Override
public void close() {
throwUnsupportedException();
}
void throwUnsupportedException() {
StackTraceElement[] trace = Thread.currentThread().getStackTrace();
Preconditions.checkState(trace.length >= 2);
StackTraceElement traceElement = trace[2];
String msg = String.format("Method %s.%s shouldn't be called.",
traceElement.getClassName(), traceElement.getMethodName());
throw new UnsupportedOperationException(msg);
}
@Override
public OperatorType getOpType() {
String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName());
throw new UnsupportedOperationException(msg);
}
@Override
public String getName() {
StringBuilder builder = new StringBuilder();
builder.append(PythonOperator.class.getSimpleName()).append("[");
if (function != null) {
builder.append(((PythonFunction)function).toSimpleString());
} else {
builder.append(moduleName).append(".").append(className);
}
return builder.append("]").toString();
}
@Override
public String toString() {
StringJoiner stringJoiner = new StringJoiner(", ",
@@ -76,4 +99,71 @@ public class PythonOperator extends StreamOperator {
}
return stringJoiner.toString();
}
public static class ChainedPythonOperator extends PythonOperator {
private final List<PythonOperator> operators;
private final PythonOperator headOperator;
private final PythonOperator tailOperator;
private final List<Map<String, String>> configs;
public ChainedPythonOperator(
List<PythonOperator> operators, List<Map<String, String>> configs) {
super(null);
Preconditions.checkArgument(!operators.isEmpty());
this.operators = operators;
this.configs = configs;
this.headOperator = operators.get(0);
this.tailOperator = operators.get(operators.size() - 1);
}
@Override
public OperatorType getOpType() {
return headOperator.getOpType();
}
@Override
public Language getLanguage() {
return Language.PYTHON;
}
@Override
public String getName() {
return operators.stream().map(Operator::getName)
.collect(Collectors.joining(" -> ", "[", "]"));
}
@Override
public String getModuleName() {
throwUnsupportedException();
return null; // impossible
}
@Override
public String getClassName() {
throwUnsupportedException();
return null; // impossible
}
@Override
public Function getFunction() {
throwUnsupportedException();
return null; // impossible
}
public List<PythonOperator> getOperators() {
return operators;
}
public PythonOperator getHeadOperator() {
return headOperator;
}
public PythonOperator getTailOperator() {
return tailOperator;
}
public List<Map<String, String>> getConfigs() {
return configs;
}
}
}
@@ -24,6 +24,9 @@ public class PythonPartition implements Partition<Object> {
"ray.streaming.partition", "KeyPartition");
public static final PythonPartition RoundRobinPartition = new PythonPartition(
"ray.streaming.partition", "RoundRobinPartition");
public static final String FORWARD_PARTITION_CLASS = "ForwardPartition";
public static final PythonPartition ForwardPartition = new PythonPartition(
"ray.streaming.partition", FORWARD_PARTITION_CLASS);
private byte[] partition;
private String moduleName;
@@ -66,6 +69,10 @@ public class PythonPartition implements Partition<Object> {
return functionName;
}
public boolean isConstructedFromBinary() {
return partition != null;
}
@Override
public String toString() {
StringJoiner stringJoiner = new StringJoiner(", ",
@@ -2,6 +2,7 @@ package io.ray.streaming.python.stream;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.KeyDataStream;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;
@@ -37,7 +38,9 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
*/
public PythonDataStream reduce(PythonFunction func) {
func.setFunctionInterface(FunctionInterface.REDUCE_FUNCTION);
return new PythonDataStream(this, new PythonOperator(func));
PythonDataStream stream = new PythonDataStream(this, new PythonOperator(func));
stream.withChainStrategy(ChainStrategy.HEAD);
return stream;
}
/**
@@ -2,10 +2,10 @@ package io.ray.streaming.python.stream;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.StreamSource;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
/**
* Represents a source of the PythonStream.
@@ -13,8 +13,8 @@ import io.ray.streaming.python.PythonPartition;
public class PythonStreamSource extends PythonDataStream implements StreamSource {
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
super(streamingContext, new PythonOperator(sourceFunction),
PythonPartition.RoundRobinPartition);
super(streamingContext, new PythonOperator(sourceFunction));
withChainStrategy(ChainStrategy.HEAD);
}
public static PythonStreamSource from(StreamingContext streamingContext,
@@ -14,6 +14,8 @@ public class PythonUnionStream extends PythonDataStream {
private List<PythonDataStream> unionStreams;
public PythonUnionStream(PythonDataStream input, List<PythonDataStream> others) {
// Union stream does not create a physical operation, so we don't have to set partition
// function for it.
super(input, new PythonOperator(
"ray.streaming.operator", "UnionOperator"));
this.unionStreams = new ArrayList<>();
@@ -2,8 +2,8 @@ package io.ray.streaming.jobgraph;
import com.google.common.collect.Lists;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.api.stream.StreamSink;
@@ -20,14 +20,14 @@ public class JobGraphBuilderTest {
@Test
public void testDataSync() {
JobGraph jobGraph = buildDataSyncJobGraph();
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
Assert.assertEquals(jobVertexList.size(), 2);
Assert.assertEquals(jobEdgeList.size(), 1);
JobEdge jobEdge = jobEdgeList.get(0);
Assert.assertEquals(jobEdge.getPartition().getClass(), RoundRobinPartition.class);
Assert.assertEquals(jobEdge.getPartition().getClass(), ForwardPartition.class);
JobVertex sinkVertex = jobVertexList.get(1);
JobVertex sourceVertex = jobVertexList.get(0);
@@ -50,8 +50,8 @@ public class JobGraphBuilderTest {
@Test
public void testKeyByJobGraph() {
JobGraph jobGraph = buildKeyByJobGraph();
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
List<JobEdge> jobEdgeList = jobGraph.getJobEdgeList();
List<JobVertex> jobVertexList = jobGraph.getJobVertices();
List<JobEdge> jobEdgeList = jobGraph.getJobEdges();
Assert.assertEquals(jobVertexList.size(), 3);
Assert.assertEquals(jobEdgeList.size(), 2);
@@ -68,7 +68,7 @@ public class JobGraphBuilderTest {
JobEdge source2KeyBy = jobEdgeList.get(1);
Assert.assertEquals(keyBy2Sink.getPartition().getClass(), KeyPartition.class);
Assert.assertEquals(source2KeyBy.getPartition().getClass(), RoundRobinPartition.class);
Assert.assertEquals(source2KeyBy.getPartition().getClass(), ForwardPartition.class);
}
public JobGraph buildKeyByJobGraph() {
@@ -88,8 +88,8 @@ public class JobGraphBuilderTest {
JobGraph jobGraph = buildKeyByJobGraph();
jobGraph.generateDigraph();
String diGraph = jobGraph.getDigraph();
System.out.println(diGraph);
Assert.assertTrue(diGraph.contains("1-SourceOperator -> 2-KeyByOperator"));
Assert.assertTrue(diGraph.contains("2-KeyByOperator -> 3-SinkOperator"));
LOG.info(diGraph);
Assert.assertTrue(diGraph.contains("\"1-SourceOperatorImpl\" -> \"2-KeyByOperator\""));
Assert.assertTrue(diGraph.contains("\"2-KeyByOperator\" -> \"3-SinkOperator\""));
}
}
@@ -0,0 +1,70 @@
package io.ray.streaming.jobgraph;
import static org.testng.Assert.assertEquals;
import com.google.common.collect.Lists;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.python.PythonFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;
public class JobGraphOptimizerTest {
private static final Logger LOG = LoggerFactory.getLogger( JobGraphOptimizerTest.class );
@Test
public void testOptimize() {
StreamingContext context = StreamingContext.buildContext();
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
Lists.newArrayList(1 ,2 ,3));
DataStream<String> source2 = DataStreamSource.fromCollection(context,
Lists.newArrayList("1", "2", "3"));
DataStream<String> source3 = DataStreamSource.fromCollection(context,
Lists.newArrayList("2", "3", "4"));
source1.filter(x -> x > 1)
.map(String::valueOf)
.union(source2)
.join(source3)
.sink(x -> System.out.println("Sink " + x));
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
LOG.info("Digraph {}", jobGraph.generateDigraph());
assertEquals(jobGraph.getJobVertices().size(), 8);
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
JobGraph optimizedJobGraph = graphOptimizer.optimize();
optimizedJobGraph.printJobGraph();
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
assertEquals(optimizedJobGraph.getJobVertices().size(), 5);
}
@Test
public void testOptimizeHybridStream() {
StreamingContext context = StreamingContext.buildContext();
DataStream<Integer> source1 = DataStreamSource.fromCollection(context,
Lists.newArrayList(1 ,2 ,3));
DataStream<String> source2 = DataStreamSource.fromCollection(context,
Lists.newArrayList("1", "2", "3"));
source1.asPythonStream()
.map(pyFunc(1))
.filter(pyFunc(2))
.union(source2.asPythonStream().filter(pyFunc(3)).map(pyFunc(4)))
.asJavaStream()
.sink(x -> System.out.println("Sink " + x));
JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build();
LOG.info("Digraph {}", jobGraph.generateDigraph());
assertEquals(jobGraph.getJobVertices().size(), 8);
JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph);
JobGraph optimizedJobGraph = graphOptimizer.optimize();
optimizedJobGraph.printJobGraph();
LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph());
assertEquals(optimizedJobGraph.getJobVertices().size(), 6);
}
private PythonFunction pyFunc(int number) {
return new PythonFunction("module", "func" + number);
}
}
@@ -2,9 +2,9 @@ package io.ray.streaming.runtime.core.processor;
import io.ray.streaming.operator.OneInputOperator;
import io.ray.streaming.operator.OperatorType;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.TwoInputOperator;
import io.ray.streaming.operator.impl.SourceOperator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -1,7 +1,7 @@
package io.ray.streaming.runtime.core.processor;
import io.ray.streaming.message.Record;
import io.ray.streaming.operator.impl.SourceOperator;
import io.ray.streaming.operator.SourceOperator;
/**
* The processor for the stream sources, containing a SourceOperator.
@@ -30,7 +30,7 @@ public class GraphManagerImpl implements GraphManager {
ExecutionGraph executionGraph = setupStructure(jobGraph);
// set max parallelism
int maxParallelism = jobGraph.getJobVertexList().stream()
int maxParallelism = jobGraph.getJobVertices().stream()
.map(JobVertex::getParallelism)
.max(Integer::compareTo).get();
executionGraph.setMaxParallelism(maxParallelism);
@@ -49,7 +49,7 @@ public class GraphManagerImpl implements GraphManager {
// create vertex
Map<Integer, ExecutionJobVertex> exeJobVertexMap = new LinkedHashMap<>();
long buildTime = executionGraph.getBuildTime();
for (JobVertex jobVertex : jobGraph.getJobVertexList()) {
for (JobVertex jobVertex : jobGraph.getJobVertices()) {
int jobVertexId = jobVertex.getVertexId();
exeJobVertexMap.put(jobVertexId,
new ExecutionJobVertex(
@@ -60,7 +60,7 @@ public class GraphManagerImpl implements GraphManager {
}
// connect vertex
jobGraph.getJobEdgeList().stream().forEach(jobEdge -> {
jobGraph.getJobEdges().forEach(jobEdge -> {
ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId());
ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId());
@@ -70,8 +70,8 @@ public class GraphManagerImpl implements GraphManager {
source.getOutputEdges().add(executionJobEdge);
target.getInputEdges().add(executionJobEdge);
source.getExecutionVertices().stream().forEach(vertex -> {
target.getExecutionVertices().stream().forEach(outputVertex -> {
source.getExecutionVertices().forEach(vertex -> {
target.getExecutionVertices().forEach(outputVertex -> {
ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge);
vertex.getOutputEdges().add(executionEdge);
outputVertex.getInputEdges().add(executionEdge);
@@ -7,6 +7,7 @@ import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonOperator.ChainedPythonOperator;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
@@ -77,6 +78,7 @@ public class GraphPbBuilder {
executionVertexBuilder.setOperator(
ByteString.copyFrom(
serializeOperator(executionVertex.getStreamOperator())));
executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator()));
executionVertexBuilder.setWorkerActor(
ByteString.copyFrom(
((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes()));
@@ -104,17 +106,35 @@ public class GraphPbBuilder {
private byte[] serializeOperator(Operator operator) {
if (operator instanceof PythonOperator) {
PythonOperator pythonOperator = (PythonOperator) operator;
return serializer.serialize(Arrays.asList(
serializeFunction(pythonOperator.getFunction()),
pythonOperator.getModuleName(),
pythonOperator.getClassName()
));
if (isPythonChainedOperator(operator)) {
return serializePythonChainedOperator((ChainedPythonOperator) operator);
} else {
PythonOperator pythonOperator = (PythonOperator) operator;
return serializer.serialize(Arrays.asList(
serializeFunction(pythonOperator.getFunction()),
pythonOperator.getModuleName(),
pythonOperator.getClassName()
));
}
} else {
return new byte[0];
}
}
private boolean isPythonChainedOperator(Operator operator) {
return operator instanceof ChainedPythonOperator;
}
private byte[] serializePythonChainedOperator(ChainedPythonOperator operator) {
List<byte[]> serializedOperators = operator.getOperators().stream()
.map(this::serializeOperator).collect(Collectors.toList());
return serializer.serialize(Arrays.asList(
serializedOperators,
operator.getConfigs()
));
}
private byte[] serializeFunction(Function function) {
if (function instanceof PythonFunction) {
PythonFunction pyFunc = (PythonFunction) function;
@@ -1,6 +1,6 @@
package io.ray.streaming.runtime.worker.tasks;
import io.ray.streaming.operator.impl.SourceOperator;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.core.processor.SourceProcessor;
import io.ray.streaming.runtime.worker.JobWorker;
@@ -7,6 +7,7 @@ import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.jobgraph.JobVertex;
import io.ray.streaming.runtime.BaseUnitTest;
import io.ray.streaming.runtime.config.StreamingConfig;
import io.ray.streaming.runtime.config.master.ResourceConfig;
@@ -40,10 +41,10 @@ public class ExecutionGraphTest extends BaseUnitTest {
ExecutionGraph executionGraph = buildExecutionGraph(graphManager, jobGraph);
List<ExecutionJobVertex> executionJobVertices = executionGraph.getExecutionJobVertexList();
Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertexList().size());
Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertices().size());
int totalVertexNum = jobGraph.getJobVertexList().stream()
.mapToInt(vertex -> vertex.getParallelism()).sum();
int totalVertexNum = jobGraph.getJobVertices().stream()
.mapToInt(JobVertex::getParallelism).sum();
Assert.assertEquals(executionGraph.getAllExecutionVertices().size(), totalVertexNum);
Assert.assertEquals(executionGraph.getAllExecutionVertices().size(),
executionGraph.getExecutionVertexIdGenerator().get());
@@ -66,7 +67,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
List<ExecutionVertex> downStreamVertices = downStream.getExecutionVertices();
upStreamVertices.forEach(vertex -> {
Assert.assertEquals(vertex.getResource().get(ResourceType.CPU.name()), 2.0);
vertex.getOutputEdges().stream().forEach(upStreamOutPutEdge -> {
vertex.getOutputEdges().forEach(upStreamOutPutEdge -> {
Assert.assertTrue(downStreamVertices.contains(upStreamOutPutEdge.getTargetExecutionVertex()));
});
});
+1 -1
View File
@@ -39,7 +39,7 @@ fi
if [ $exit_code -ne 2 ] && [ $exit_code -ne 0 ] ; then
if [ -d "/tmp/ray_streaming_java_test_output/" ] ; then
echo "all test output"
for f in /tmp/ray_streaming_java_test_output/*; do
for f in /tmp/ray_streaming_java_test_output/*.{log,xml}; do
if [ -f "$f" ]; then
echo "Cat file $f"
cat "$f"
+12
View File
@@ -94,6 +94,18 @@ class Stream(ABC):
def get_language(self):
pass
def forward(self):
"""Set the partition function of this {@link Stream} so that output
elements are forwarded to next operator locally."""
self._gateway_client().call_method(self._j_stream, "forward")
return self
def disable_chain(self):
"""Disable chain for this stream so that it will be run in a separate
task."""
self._gateway_client().call_method(self._j_stream, "disableChain")
return self
def _gateway_client(self):
return self.get_streaming_context()._gateway_client
+102 -9
View File
@@ -1,12 +1,16 @@
import enum
import importlib
import logging
from abc import ABC, abstractmethod
from ray import streaming
from ray.streaming import function
from ray.streaming import message
from ray.streaming.collector import Collector
from ray.streaming.runtime import gateway_client
logger = logging.getLogger(__name__)
class OperatorType(enum.Enum):
SOURCE = 0 # Sources are where your program reads its input from
@@ -227,15 +231,93 @@ class UnionOperator(StreamOperator, OneInputOperator):
self.collect(record)
_function_to_operator = {
function.SourceFunction: SourceOperator,
function.MapFunction: MapOperator,
function.FlatMapFunction: FlatMapOperator,
function.FilterFunction: FilterOperator,
function.KeyFunction: KeyByOperator,
function.ReduceFunction: ReduceOperator,
function.SinkFunction: SinkOperator,
}
class ChainedOperator(StreamOperator, ABC):
class ForwardCollector(Collector):
def __init__(self, succeeding_operator):
self.succeeding_operator = succeeding_operator
def collect(self, record):
self.succeeding_operator.process_element(record)
def __init__(self, operators, configs):
super().__init__(operators[0].func)
self.operators = operators
self.configs = configs
def open(self, collectors, runtime_context):
# Dont' call super.open() as we `open` every operator separately.
num_operators = len(self.operators)
succeeding_collectors = [
ChainedOperator.ForwardCollector(operator)
for operator in self.operators[1:]
]
for i in range(0, num_operators - 1):
forward_collectors = [succeeding_collectors[i]]
self.operators[i].open(
forward_collectors,
self.__create_runtime_context(runtime_context, i))
self.operators[-1].open(
collectors,
self.__create_runtime_context(runtime_context, num_operators - 1))
def operator_type(self) -> OperatorType:
return self.operators[0].operator_type()
def __create_runtime_context(self, runtime_context, index):
def get_config():
return self.configs[index]
runtime_context.get_config = get_config
return runtime_context
@staticmethod
def new_chained_operator(operators, configs):
operator_type = operators[0].operator_type()
logger.info(
"Building ChainedOperator from operators {} and configs {}."
.format(operators, configs))
if operator_type == OperatorType.SOURCE:
return ChainedSourceOperator(operators, configs)
elif operator_type == OperatorType.ONE_INPUT:
return ChainedOneInputOperator(operators, configs)
elif operator_type == OperatorType.TWO_INPUT:
return ChainedTwoInputOperator(operators, configs)
else:
raise Exception("Current operator type is not supported")
class ChainedSourceOperator(ChainedOperator):
def __init__(self, operators, configs):
super().__init__(operators, configs)
def run(self):
self.operators[0].run()
class ChainedOneInputOperator(ChainedOperator):
def __init__(self, operators, configs):
super().__init__(operators, configs)
def process_element(self, record):
self.operators[0].process_element(record)
class ChainedTwoInputOperator(ChainedOperator):
def __init__(self, operators, configs):
super().__init__(operators, configs)
def process_element(self, record1, record2):
self.operators[0].process_element(record1, record2)
def load_chained_operator(chained_operator_bytes: bytes):
"""Load chained operator from serialized operators and configs"""
serialized_operators, configs = gateway_client.deserialize(
chained_operator_bytes)
operators = [
load_operator(desc_bytes) for desc_bytes in serialized_operators
]
return ChainedOperator.new_chained_operator(operators, configs)
def load_operator(descriptor_operator_bytes: bytes):
@@ -267,6 +349,17 @@ def load_operator(descriptor_operator_bytes: bytes):
return cls()
_function_to_operator = {
function.SourceFunction: SourceOperator,
function.MapFunction: MapOperator,
function.FlatMapFunction: FlatMapOperator,
function.FilterFunction: FilterOperator,
function.KeyFunction: KeyByOperator,
function.ReduceFunction: ReduceOperator,
function.SinkFunction: SinkOperator,
}
def create_operator_with_func(func: function.Function):
"""Create an operator according to a :class:`function.Function`
+11
View File
@@ -60,6 +60,17 @@ class RoundRobinPartition(Partition):
return self.__partitions
class ForwardPartition(Partition):
"""Default partition for operator if the operator can be chained with
succeeding operators."""
def __init__(self):
self.__partitions = [0]
def partition(self, key_record, num_partition: int):
return self.__partitions
class SimplePartition(Partition):
"""Wrap a python function as subclass of :class:`Partition`"""
+10 -1
View File
@@ -1,4 +1,5 @@
import enum
import logging
import ray
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
@@ -6,6 +7,8 @@ import ray.streaming.operator as operator
import ray.streaming.partition as partition
from ray.streaming.generated.streaming_pb2 import Language
logger = logging.getLogger(__name__)
class NodeType(enum.Enum):
"""
@@ -43,7 +46,13 @@ class ExecutionVertex:
self.parallelism = vertex_pb.parallelism
if vertex_pb.language == Language.PYTHON:
operator_bytes = vertex_pb.operator # python operator descriptor
self.stream_operator = operator.load_operator(operator_bytes)
if vertex_pb.chained:
logger.info("Load chained operator")
self.stream_operator = operator.load_chained_operator(
operator_bytes)
else:
logger.info("Load operator")
self.stream_operator = operator.load_operator(operator_bytes)
self.worker_actor = ray.actor.ActorHandle. \
_deserialization_helper(vertex_pb.worker_actor)
self.container_id = vertex_pb.container_id
+7 -6
View File
@@ -30,12 +30,13 @@ message ExecutionVertexContext {
int32 parallelism = 5;
// serialized operator
bytes operator = 6;
bytes worker_actor = 7;
string container_id = 8;
uint64 build_time = 9;
Language language = 10;
map<string, string> config = 11;
map<string, double> resource = 12;
bool chained = 7;
bytes worker_actor = 8;
string container_id = 9;
uint64 build_time = 10;
Language language = 11;
map<string, string> config = 12;
map<string, double> resource = 13;
}
// vertices