[Streaming] Streaming Cross-Lang API (#7464)

This commit is contained in:
chaokunyang
2020-04-29 13:42:08 +08:00
committed by GitHub
parent 101255f782
commit 91f630f709
72 changed files with 1612 additions and 408 deletions
@@ -0,0 +1,129 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class ClusterStarter {
private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";
static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
Preconditions.checkArgument(Ray.internal() == null);
RayConfig.reset();
if (!isLocal) {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
} else {
System.clearProperty("ray.raylet.config.num_workers_per_process_java");
System.setProperty("ray.run-mode", "SINGLE_PROCESS");
}
if (!isCrossLanguage) {
Ray.init();
return;
}
// Delete existing socket files.
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
File file = new File(socket);
if (file.exists()) {
LOG.info("Delete existing socket file {}", file);
file.delete();
}
}
String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());
// jars in the `ray` wheel doesn't contains test classes, so we add test classes explicitly.
// Since mvn test classes contains `test` in path and bazel test classes is located at a jar
// with `test` included in the name, we can check classpath `test` to filter out test classes.
String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
.filter(s -> !s.contains(" ") && s.contains("test"))
.collect(Collectors.joining(":"));
String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath));
Map<String, String> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
config.put("num_workers_per_process_java", "1");
// Start ray cluster.
List<String> startCommand = ImmutableList.of(
"ray",
"start",
"--head",
"--redis-port=6379",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
String.format("--node-manager-port=%s", nodeManagerPort),
"--load-code-from-local",
"--include-java",
"--java-worker-options=" + workerOptions,
"--internal-config=" + new Gson().toJson(config)
);
if (!executeCommand(startCommand, 10)) {
throw new RuntimeException("Couldn't start ray cluster.");
}
// Connect to the cluster.
System.setProperty("ray.redis.address", "127.0.0.1:6379");
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
Ray.init();
}
public static synchronized void stopCluster(boolean isCrossLanguage) {
// Disconnect to the cluster.
Ray.shutdown();
System.clearProperty("ray.redis.address");
System.clearProperty("ray.object-store.socket-name");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.raylet.node-manager-port");
System.clearProperty("ray.raylet.config.num_workers_per_process_java");
System.clearProperty("ray.run-mode");
if (isCrossLanguage) {
// Stop ray cluster.
final List<String> stopCommand = ImmutableList.of(
"ray",
"stop"
);
if (!executeCommand(stopCommand, 10)) {
throw new RuntimeException("Couldn't stop ray cluster");
}
}
}
/**
* Execute an external command.
*
* @return Whether the command succeeded.
*/
private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
LOG.info("Executing command: {}", String.join(" ", command));
try {
ProcessBuilder processBuilder = new ProcessBuilder(command)
.redirectOutput(ProcessBuilder.Redirect.INHERIT)
.redirectError(ProcessBuilder.Redirect.INHERIT);
Process process = processBuilder.start();
boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
if (!exit) {
process.destroyForcibly();
}
return process.exitValue() == 0;
} catch (Exception e) {
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
}
}
}
@@ -1,10 +1,12 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
import io.ray.api.Ray;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.schedule.JobScheduler;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@@ -13,11 +15,14 @@ import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Encapsulate the context information of a streaming Job.
*/
public class StreamingContext implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class);
private transient AtomicInteger idGenerator;
@@ -54,6 +59,20 @@ public class StreamingContext implements Serializable {
this.jobGraph = jobGraphBuilder.build();
jobGraph.printJobGraph();
if (Ray.internal() == null) {
if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
ClusterStarter.startCluster(false, true);
LOG.info("Created local cluster for job {}.", jobName);
} else {
ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
LOG.info("Created multi process cluster for job {}.", jobName);
}
Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
} else {
LOG.info("Reuse existing cluster.");
}
ServiceLoader<JobScheduler> serviceLoader = ServiceLoader.load(JobScheduler.class);
Iterator<JobScheduler> iterator = serviceLoader.iterator();
Preconditions.checkArgument(iterator.hasNext(),
@@ -77,4 +96,10 @@ public class StreamingContext implements Serializable {
public void withConfig(Map<String, String> jobConfig) {
this.jobConfig = jobConfig;
}
public void stop() {
if (Ray.internal() != null) {
ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
}
}
}
@@ -1,6 +1,7 @@
package io.ray.streaming.api.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.FlatMapFunction;
@@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator;
import io.ray.streaming.operator.impl.KeyByOperator;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.operator.impl.SinkOperator;
import io.ray.streaming.python.stream.PythonDataStream;
/**
* Represents a stream of data.
*
* This class defines all the streaming operations.
* <p>This class defines all the streaming operations.
*
* @param <T> Type of data in the stream.
*/
public class DataStream<T> extends Stream<T> {
public class DataStream<T> extends Stream<DataStream<T>, T> {
public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) {
super(streamingContext, streamOperator);
}
public DataStream(DataStream input, StreamOperator streamOperator) {
public DataStream(StreamingContext streamingContext,
StreamOperator streamOperator,
Partition<T> partition) {
super(streamingContext, streamOperator, partition);
}
public <R> DataStream(DataStream<R> input, StreamOperator streamOperator) {
super(input, streamOperator);
}
public <R> DataStream(DataStream<R> input,
StreamOperator streamOperator,
Partition<T> partition) {
super(input, streamOperator, partition);
}
/**
* Create a java stream that reference passed python stream.
* Changes in new stream will be reflected in referenced stream and vice versa
*/
public DataStream(PythonDataStream referencedStream) {
super(referencedStream);
}
/**
* Apply a map function to this stream.
*
@@ -41,7 +62,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new DataStream.
*/
public <R> DataStream<R> map(MapFunction<T, R> mapFunction) {
return new DataStream<>(this, new MapOperator(mapFunction));
return new DataStream<>(this, new MapOperator<>(mapFunction));
}
/**
@@ -52,11 +73,11 @@ public class DataStream<T> extends Stream<T> {
* @return A new DataStream
*/
public <R> DataStream<R> flatMap(FlatMapFunction<T, R> flatMapFunction) {
return new DataStream(this, new FlatMapOperator(flatMapFunction));
return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction));
}
public DataStream<T> filter(FilterFunction<T> filterFunction) {
return new DataStream<T>(this, new FilterOperator(filterFunction));
return new DataStream<>(this, new FilterOperator<>(filterFunction));
}
/**
@@ -66,7 +87,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new UnionStream.
*/
public UnionStream<T> union(DataStream<T> other) {
return new UnionStream(this, null, other);
return new UnionStream<>(this, null, other);
}
/**
@@ -93,7 +114,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new StreamSink.
*/
public DataStreamSink<T> sink(SinkFunction<T> sinkFunction) {
return new DataStreamSink<>(this, new SinkOperator(sinkFunction));
return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction));
}
/**
@@ -104,7 +125,8 @@ public class DataStream<T> extends Stream<T> {
* @return A new KeyDataStream.
*/
public <K> KeyDataStream<K, T> keyBy(KeyFunction<T, K> keyFunction) {
return new KeyDataStream<>(this, new KeyByOperator(keyFunction));
checkPartitionCall();
return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction));
}
/**
@@ -113,8 +135,8 @@ public class DataStream<T> extends Stream<T> {
* @return This stream.
*/
public DataStream<T> broadcast() {
this.partition = new BroadcastPartition<>();
return this;
checkPartitionCall();
return setPartition(new BroadcastPartition<>());
}
/**
@@ -124,19 +146,32 @@ public class DataStream<T> extends Stream<T> {
* @return This stream.
*/
public DataStream<T> partitionBy(Partition<T> partition) {
this.partition = partition;
return this;
checkPartitionCall();
return setPartition(partition);
}
/**
* Set parallelism to current transformation.
*
* @param parallelism The parallelism to set.
* @return This stream.
* If parent stream is a python stream, we can't call partition related methods
* in the java stream.
*/
public DataStream<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
private void checkPartitionCall() {
if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) {
throw new RuntimeException("Partition related methods can't be called on a " +
"java stream if parent stream is a python stream.");
}
}
/**
* Convert this stream as a python stream.
* The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public PythonDataStream asPythonStream() {
return new PythonDataStream(this);
}
@Override
public Language getLanguage() {
return Language.JAVA;
}
}
@@ -1,5 +1,6 @@
package io.ray.streaming.api.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.operator.impl.SinkOperator;
/**
@@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator;
*/
public class DataStreamSink<T> extends StreamSink<T> {
public DataStreamSink(DataStream<T> input, SinkOperator sinkOperator) {
public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
super(input, sinkOperator);
this.streamingContext.addSink(this);
getStreamingContext().addSink(this);
}
public DataStreamSink<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
@Override
public Language getLanguage() {
return Language.JAVA;
}
}
@@ -14,27 +14,26 @@ import java.util.Collection;
*/
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
public DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction));
super.partition = new RoundRobinPartition<>();
private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
}
public static <T> DataStreamSource<T> fromSource(
StreamingContext context, SourceFunction<T> sourceFunction) {
return new DataStreamSource<>(context, sourceFunction);
}
/**
* Build a DataStreamSource source from a collection.
*
* @param context Stream context.
* @param values A collection of values.
* @param <T> The type of source data.
* @param values A collection of values.
* @param <T> The type of source data.
* @return A DataStreamSource.
*/
public static <T> DataStreamSource<T> buildSource(
public static <T> DataStreamSource<T> fromCollection(
StreamingContext context, Collection<T> values) {
return new DataStreamSource(context, new CollectionSourceFunction(values));
return new DataStreamSource<>(context, new CollectionSourceFunction<>(values));
}
@Override
public DataStreamSource<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
}
}
@@ -2,9 +2,12 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.function.impl.AggregateFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.impl.ReduceOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
/**
* Represents a DataStream returned by a key-by operation.
@@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator;
* @param <K> Type of the key.
* @param <T> Type of the data.
*/
@SuppressWarnings("unchecked")
public class KeyDataStream<K, T> extends DataStream<T> {
public KeyDataStream(DataStream<T> input, StreamOperator streamOperator) {
super(input, streamOperator);
this.partition = new KeyPartition();
super(input, streamOperator, (Partition<T>) new KeyPartition<K, T>());
}
/**
* Create a java stream that reference passed python stream.
* Changes in new stream will be reflected in referenced stream and vice versa
*/
public KeyDataStream(PythonDataStream referencedStream) {
super(referencedStream);
}
/**
@@ -41,8 +52,13 @@ public class KeyDataStream<K, T> extends DataStream<T> {
return new DataStream<>(this, null);
}
public KeyDataStream<K, T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
/**
* Convert this stream as a python stream.
* The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public PythonKeyDataStream asPythonStream() {
return new PythonKeyDataStream(this);
}
}
@@ -1,58 +1,99 @@
package io.ray.streaming.api.stream;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStream;
import java.io.Serializable;
/**
* Abstract base class of all stream types.
*
* @param <S> Type of stream class
* @param <T> Type of the data in the stream.
*/
public abstract class Stream<T> implements Serializable {
protected int id;
protected int parallelism = 1;
protected StreamOperator operator;
protected Stream<T> inputStream;
protected StreamingContext streamingContext;
protected Partition<T> partition;
public abstract class Stream<S extends Stream<S, T>, T>
implements Serializable {
private final int id;
private final StreamingContext streamingContext;
private final Stream inputStream;
private final StreamOperator operator;
private int parallelism = 1;
private Partition<T> partition;
private Stream originalStream;
@SuppressWarnings("unchecked")
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
this(streamingContext, null, streamOperator,
selectPartition(streamOperator));
}
public Stream(StreamingContext streamingContext,
StreamOperator streamOperator,
Partition<T> partition) {
this(streamingContext, null, streamOperator, partition);
}
public Stream(Stream inputStream, StreamOperator streamOperator) {
this(inputStream.getStreamingContext(), inputStream, streamOperator,
selectPartition(streamOperator));
}
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
this(inputStream.getStreamingContext(), inputStream, streamOperator, partition);
}
protected Stream(StreamingContext streamingContext,
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
this.streamingContext = streamingContext;
this.inputStream = inputStream;
this.operator = streamOperator;
this.partition = partition;
this.id = streamingContext.generateId();
if (streamOperator instanceof PythonOperator) {
this.partition = PythonPartition.RoundRobinPartition;
} else {
this.partition = new RoundRobinPartition<>();
if (inputStream != null) {
this.parallelism = inputStream.getParallelism();
}
}
public Stream(Stream<T> inputStream, StreamOperator streamOperator) {
this.inputStream = inputStream;
this.parallelism = inputStream.getParallelism();
this.streamingContext = this.inputStream.getStreamingContext();
this.operator = streamOperator;
this.id = streamingContext.generateId();
this.partition = selectPartition();
/**
* Create a proxy stream of original stream.
* Changes in new stream will be reflected in original stream and vice versa
*/
protected Stream(Stream originalStream) {
this.originalStream = originalStream;
this.id = originalStream.getId();
this.streamingContext = originalStream.getStreamingContext();
this.inputStream = originalStream.getInputStream();
this.operator = originalStream.getOperator();
}
@SuppressWarnings("unchecked")
private Partition<T> selectPartition() {
if (inputStream instanceof PythonStream) {
return PythonPartition.RoundRobinPartition;
} else {
return new RoundRobinPartition<>();
private static <T> Partition<T> selectPartition(Operator operator) {
switch (operator.getLanguage()) {
case PYTHON:
return (Partition<T>) PythonPartition.RoundRobinPartition;
case JAVA:
return new RoundRobinPartition<>();
default:
throw new UnsupportedOperationException(
"Unsupported language " + operator.getLanguage());
}
}
public Stream<T> getInputStream() {
public int getId() {
return id;
}
public StreamingContext getStreamingContext() {
return streamingContext;
}
public Stream getInputStream() {
return inputStream;
}
@@ -60,32 +101,47 @@ public abstract class Stream<T> implements Serializable {
return operator;
}
public void setOperator(StreamOperator operator) {
this.operator = operator;
}
public StreamingContext getStreamingContext() {
return streamingContext;
@SuppressWarnings("unchecked")
private S self() {
return (S) this;
}
public int getParallelism() {
return parallelism;
return originalStream != null ? originalStream.getParallelism() : parallelism;
}
public Stream<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
}
public int getId() {
return id;
public S setParallelism(int parallelism) {
if (originalStream != null) {
originalStream.setParallelism(parallelism);
} else {
this.parallelism = parallelism;
}
return self();
}
@SuppressWarnings("unchecked")
public Partition<T> getPartition() {
return partition;
return originalStream != null ? originalStream.getPartition() : partition;
}
public void setPartition(Partition<T> partition) {
this.partition = partition;
@SuppressWarnings("unchecked")
protected S setPartition(Partition<T> partition) {
if (originalStream != null) {
originalStream.setPartition(partition);
} else {
this.partition = partition;
}
return self();
}
public boolean isProxyStream() {
return originalStream != null;
}
public Stream getOriginalStream() {
Preconditions.checkArgument(isProxyStream());
return originalStream;
}
public abstract Language getLanguage();
}
@@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator;
*
* @param <T> Type of the input data of this sink.
*/
public class StreamSink<T> extends Stream<T> {
public StreamSink(Stream<T> inputStream, StreamOperator streamOperator) {
public abstract class StreamSink<T> extends Stream<StreamSink<T>, T> {
public StreamSink(Stream inputStream, StreamOperator streamOperator) {
super(inputStream, streamOperator);
}
}
@@ -11,15 +11,15 @@ import java.util.List;
*/
public class UnionStream<T> extends DataStream<T> {
private List<DataStream> unionStreams;
private List<DataStream<T>> unionStreams;
public UnionStream(DataStream input, StreamOperator streamOperator, DataStream<T> other) {
public UnionStream(DataStream<T> input, StreamOperator streamOperator, DataStream<T> other) {
super(input, streamOperator);
this.unionStreams = new ArrayList<>();
this.unionStreams.add(other);
}
public List<DataStream> getUnionStreams() {
public List<DataStream<T>> getUnionStreams() {
return unionStreams;
}
}
@@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;
import io.ray.streaming.api.Language;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@@ -97,4 +98,14 @@ public class JobGraph implements Serializable {
}
}
public boolean isCrossLanguageGraph() {
Language language = jobVertexList.get(0).getLanguage();
for (JobVertex jobVertex : jobVertexList) {
if (jobVertex.getLanguage() != language) {
return true;
}
}
return false;
}
}
@@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink;
@@ -10,8 +11,11 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class JobGraphBuilder {
private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class);
private JobGraph jobGraph;
@@ -41,12 +45,19 @@ public class JobGraphBuilder {
}
private void processStream(Stream stream) {
while (stream.isProxyStream()) {
// Proxy stream and original stream are the same logical stream, both refer to the
// same data flow transformation. We should skip proxy stream to avoid applying same
// transformation multiple times.
LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId());
stream = stream.getOriginalStream();
}
StreamOperator streamOperator = stream.getOperator();
Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(),
"Reference stream should be skipped.");
int vertexId = stream.getId();
int parallelism = stream.getParallelism();
StreamOperator streamOperator = stream.getOperator();
JobVertex jobVertex = null;
JobVertex jobVertex;
if (stream instanceof StreamSink) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
Stream parentStream = stream.getInputStream();
@@ -1,6 +1,8 @@
package io.ray.streaming.message;
import java.util.Objects;
public class KeyRecord<K, T> extends Record<T> {
private K key;
@@ -17,4 +19,24 @@ public class KeyRecord<K, T> extends Record<T> {
public void setKey(K key) {
this.key = key;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
KeyRecord<?, ?> keyRecord = (KeyRecord<?, ?>) o;
return Objects.equals(key, keyRecord.key);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), key);
}
}
@@ -1,64 +0,0 @@
package io.ray.streaming.message;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.List;
public class Message implements Serializable {
private int taskId;
private long batchId;
private String stream;
private List<Record> recordList;
public Message(int taskId, long batchId, String stream, List<Record> recordList) {
this.taskId = taskId;
this.batchId = batchId;
this.stream = stream;
this.recordList = recordList;
}
public Message(int taskId, long batchId, String stream, Record record) {
this.taskId = taskId;
this.batchId = batchId;
this.stream = stream;
this.recordList = Lists.newArrayList(record);
}
public int getTaskId() {
return taskId;
}
public void setTaskId(int taskId) {
this.taskId = taskId;
}
public long getBatchId() {
return batchId;
}
public void setBatchId(long batchId) {
this.batchId = batchId;
}
public String getStream() {
return stream;
}
public void setStream(String stream) {
this.stream = stream;
}
public List<Record> getRecordList() {
return recordList;
}
public void setRecordList(List<Record> recordList) {
this.recordList = recordList;
}
public Record getRecord(int index) {
return recordList.get(0);
}
}
@@ -1,6 +1,7 @@
package io.ray.streaming.message;
import java.io.Serializable;
import java.util.Objects;
public class Record<T> implements Serializable {
@@ -27,6 +28,24 @@ public class Record<T> implements Serializable {
this.stream = stream;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Record<?> record = (Record<?>) o;
return Objects.equals(stream, record.stream) &&
Objects.equals(value, record.value);
}
@Override
public int hashCode() {
return Objects.hash(stream, value);
}
@Override
public String toString() {
return value.toString();
@@ -1,6 +1,8 @@
package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.function.Function;
import org.apache.commons.lang3.StringUtils;
/**
* Represents a user defined python function.
@@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function;
*
* <p>If the python data stream api is invoked from python, `function` will be not null.</p>
* <p>If the python data stream api is invoked from java, `moduleName` and
* `className`/`functionName` will be not null.</p>
* `functionName` will be not null.</p>
* <p>
* TODO serialize to bytes using protobuf
*/
public class PythonFunction implements Function {
public enum FunctionInterface {
@@ -38,23 +39,43 @@ public class PythonFunction implements Function {
}
}
private byte[] function;
private String moduleName;
private String className;
private String functionName;
// null if this function is constructed from moduleName/functionName.
private final byte[] function;
// null if this function is constructed from serialized python function.
private final String moduleName;
// null if this function is constructed from serialized python function.
private final String functionName;
/**
* FunctionInterface can be used to validate python function,
* and look up operator class from FunctionInterface.
*/
private String functionInterface;
private PythonFunction(byte[] function,
String moduleName,
String className,
String functionName) {
/**
* Create a {@link PythonFunction} from a serialized streaming python function.
*
* @param function serialized streaming python function from python driver.
*/
public PythonFunction(byte[] function) {
Preconditions.checkNotNull(function);
this.function = function;
this.moduleName = null;
this.functionName = null;
}
/**
* Create a {@link PythonFunction} from a moduleName and streaming function name.
*
* @param moduleName module name of streaming function.
* @param functionName function name of streaming function. {@code functionName} is the name
* of a python function, or class name of subclass of `ray.streaming.function.`
*/
public PythonFunction(String moduleName,
String functionName) {
Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
this.function = null;
this.moduleName = moduleName;
this.className = className;
this.functionName = functionName;
}
@@ -70,10 +91,6 @@ public class PythonFunction implements Function {
return moduleName;
}
public String getClassName() {
return className;
}
public String getFunctionName() {
return functionName;
}
@@ -82,34 +99,4 @@ public class PythonFunction implements Function {
return functionInterface;
}
/**
* Create a {@link PythonFunction} using python serialized function
*
* @param function serialized python function sent from python driver
*/
public static PythonFunction fromFunction(byte[] function) {
return new PythonFunction(function, null, null, null);
}
/**
* Create a {@link PythonFunction} using <code>moduleName</code> and
* <code>className</code>.
*
* @param moduleName python module name
* @param className python class name
*/
public static PythonFunction fromClassName(String moduleName, String className) {
return new PythonFunction(null, moduleName, className, null);
}
/**
* Create a {@link PythonFunction} using <code>moduleName</code> and
* <code>functionName</code>.
*
* @param moduleName python module name
* @param functionName python function name
*/
public static PythonFunction fromFunctionName(String moduleName, String functionName) {
return new PythonFunction(null, moduleName, null, functionName);
}
}
@@ -1,6 +1,8 @@
package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.partition.Partition;
import org.apache.commons.lang3.StringUtils;
/**
* Represents a python partition function.
@@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition;
* If this object is constructed from moduleName and className/functionName,
* python worker will use `importlib` to load python partition function.
* <p>
* TODO serialize to bytes using protobuf
*/
public class PythonPartition implements Partition {
public class PythonPartition implements Partition<Object> {
public static final PythonPartition BroadcastPartition = new PythonPartition(
"ray.streaming.partition", "BroadcastPartition", null);
"ray.streaming.partition", "BroadcastPartition");
public static final PythonPartition KeyPartition = new PythonPartition(
"ray.streaming.partition", "KeyPartition", null);
"ray.streaming.partition", "KeyPartition");
public static final PythonPartition RoundRobinPartition = new PythonPartition(
"ray.streaming.partition", "RoundRobinPartition", null);
"ray.streaming.partition", "RoundRobinPartition");
private byte[] partition;
private String moduleName;
private String className;
private String functionName;
public PythonPartition(byte[] partition) {
Preconditions.checkNotNull(partition);
this.partition = partition;
}
public PythonPartition(String moduleName, String className, String functionName) {
/**
* Create a python partition from a moduleName and partition function name
* @param moduleName module name of python partition
* @param functionName function/class name of the partition function.
*/
public PythonPartition(String moduleName, String functionName) {
Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
this.moduleName = moduleName;
this.className = className;
this.functionName = functionName;
}
@@ -53,10 +60,6 @@ public class PythonPartition implements Partition {
return moduleName;
}
public String getClassName() {
return className;
}
public String getFunctionName() {
return functionName;
}
@@ -1,6 +1,9 @@
package io.ray.streaming.python.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
@@ -10,19 +13,39 @@ import io.ray.streaming.python.PythonPartition;
/**
* Represents a stream of data whose transformations will be executed in python.
*/
public class PythonDataStream extends Stream implements PythonStream {
public class PythonDataStream extends Stream<PythonDataStream, Object> implements PythonStream {
protected PythonDataStream(StreamingContext streamingContext,
PythonOperator pythonOperator) {
super(streamingContext, pythonOperator);
}
protected PythonDataStream(StreamingContext streamingContext,
PythonOperator pythonOperator,
Partition<Object> partition) {
super(streamingContext, pythonOperator, partition);
}
public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) {
super(input, pythonOperator);
}
protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) {
super(inputStream, pythonOperator);
public PythonDataStream(PythonDataStream input,
PythonOperator pythonOperator,
Partition<Object> partition) {
super(input, pythonOperator, partition);
}
/**
* Create a python stream that reference passed java stream.
* Changes in new stream will be reflected in referenced stream and vice versa
*/
public PythonDataStream(DataStream referencedStream) {
super(referencedStream);
}
public PythonDataStream map(String moduleName, String funcName) {
return map(new PythonFunction(moduleName, funcName));
}
/**
@@ -36,6 +59,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func));
}
public PythonDataStream flatMap(String moduleName, String funcName) {
return flatMap(new PythonFunction(moduleName, funcName));
}
/**
* Apply a flat-map function to this stream.
*
@@ -47,6 +74,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func));
}
public PythonDataStream filter(String moduleName, String funcName) {
return filter(new PythonFunction(moduleName, funcName));
}
/**
* Apply a filter function to this stream.
*
@@ -59,6 +90,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func));
}
public PythonStreamSink sink(String moduleName, String funcName) {
return sink(new PythonFunction(moduleName, funcName));
}
/**
* Apply a sink function and get a StreamSink.
*
@@ -70,6 +105,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonStreamSink(this, new PythonOperator(func));
}
public PythonKeyDataStream keyBy(String moduleName, String funcName) {
return keyBy(new PythonFunction(moduleName, funcName));
}
/**
* Apply a key-by function to this stream.
*
@@ -77,6 +116,7 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return A new KeyDataStream.
*/
public PythonKeyDataStream keyBy(PythonFunction func) {
checkPartitionCall();
func.setFunctionInterface(FunctionInterface.KEY_FUNCTION);
return new PythonKeyDataStream(this, new PythonOperator(func));
}
@@ -87,8 +127,8 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return This stream.
*/
public PythonDataStream broadcast() {
this.partition = PythonPartition.BroadcastPartition;
return this;
checkPartitionCall();
return setPartition(PythonPartition.BroadcastPartition);
}
/**
@@ -98,19 +138,33 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return This stream.
*/
public PythonDataStream partitionBy(PythonPartition partition) {
this.partition = partition;
return this;
checkPartitionCall();
return setPartition(partition);
}
/**
* Set parallelism to current transformation.
*
* @param parallelism The parallelism to set.
* @return This stream.
* If parent stream is a python stream, we can't call partition related methods
* in the java stream.
*/
public PythonDataStream setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
private void checkPartitionCall() {
if (getInputStream() != null && getInputStream().getLanguage() == Language.JAVA) {
throw new RuntimeException("Partition related methods can't be called on a " +
"python stream if parent stream is a java stream.");
}
}
/**
* Convert this stream as a java stream.
* The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public DataStream<Object> asJavaStream() {
return new DataStream<>(this);
}
@Override
public Language getLanguage() {
return Language.PYTHON;
}
}
@@ -1,5 +1,7 @@
package io.ray.streaming.python.stream;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.KeyDataStream;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;
@@ -8,11 +10,23 @@ import io.ray.streaming.python.PythonPartition;
/**
* Represents a python DataStream returned by a key-by operation.
*/
public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
@SuppressWarnings("unchecked")
public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) {
super(input, pythonOperator);
this.partition = PythonPartition.KeyPartition;
super(input, pythonOperator, PythonPartition.KeyPartition);
}
/**
* Create a python stream that reference passed python stream.
* Changes in new stream will be reflected in referenced stream and vice versa
*/
public PythonKeyDataStream(DataStream referencedStream) {
super(referencedStream);
}
public PythonDataStream reduce(String moduleName, String funcName) {
return reduce(new PythonFunction(moduleName, funcName));
}
/**
@@ -26,9 +40,13 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
return new PythonDataStream(this, new PythonOperator(func));
}
public PythonKeyDataStream setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
/**
* Convert this stream as a java stream.
* The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public KeyDataStream<Object, Object> asJavaStream() {
return new KeyDataStream(this);
}
}
@@ -1,5 +1,6 @@
package io.ray.streaming.python.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.python.PythonOperator;
@@ -9,12 +10,12 @@ import io.ray.streaming.python.PythonOperator;
public class PythonStreamSink extends StreamSink implements PythonStream {
public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) {
super(input, sinkOperator);
this.streamingContext.addSink(this);
getStreamingContext().addSink(this);
}
public PythonStreamSink setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
@Override
public Language getLanguage() {
return Language.PYTHON;
}
}
@@ -13,17 +13,12 @@ import io.ray.streaming.python.PythonPartition;
public class PythonStreamSource extends PythonDataStream implements StreamSource {
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
super(streamingContext, new PythonOperator(sourceFunction));
super.partition = PythonPartition.RoundRobinPartition;
}
public PythonStreamSource setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
super(streamingContext, new PythonOperator(sourceFunction),
PythonPartition.RoundRobinPartition);
}
public static PythonStreamSource from(StreamingContext streamingContext,
PythonFunction sourceFunction) {
PythonFunction sourceFunction) {
sourceFunction.setFunctionInterface(FunctionInterface.SOURCE_FUNCTION);
return new PythonStreamSource(streamingContext, sourceFunction);
}
@@ -21,7 +21,6 @@ public class Config {
public static final String CHANNEL_TYPE = "channel_type";
public static final String MEMORY_CHANNEL = "memory_channel";
public static final String NATIVE_CHANNEL = "native_channel";
public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL;
public static final String CHANNEL_SIZE = "channel_size";
public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8));
public static final String IS_RECREATE = "streaming.is_recreate";
@@ -0,0 +1,40 @@
package io.ray.streaming.api.stream;
import static org.testng.Assert.assertEquals;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
import org.testng.annotations.Test;
@SuppressWarnings("unchecked")
public class StreamTest {
@Test
public void testReferencedDataStream() {
DataStream dataStream = new DataStream(StreamingContext.buildContext(),
new MapOperator(value -> null));
PythonDataStream pythonDataStream = dataStream.asPythonStream();
DataStream javaStream = pythonDataStream.asJavaStream();
assertEquals(dataStream.getId(), pythonDataStream.getId());
assertEquals(dataStream.getId(), javaStream.getId());
javaStream.setParallelism(10);
assertEquals(dataStream.getParallelism(), pythonDataStream.getParallelism());
assertEquals(dataStream.getParallelism(), javaStream.getParallelism());
}
@Test
public void testReferencedKeyDataStream() {
DataStream dataStream = new DataStream(StreamingContext.buildContext(),
new MapOperator(value -> null));
KeyDataStream keyDataStream = dataStream.keyBy(value -> null);
PythonKeyDataStream pythonKeyDataStream = keyDataStream.asPythonStream();
KeyDataStream javaKeyDataStream = pythonKeyDataStream.asJavaStream();
assertEquals(keyDataStream.getId(), pythonKeyDataStream.getId());
assertEquals(keyDataStream.getId(), javaKeyDataStream.getId());
javaKeyDataStream.setParallelism(10);
assertEquals(keyDataStream.getParallelism(), pythonKeyDataStream.getParallelism());
assertEquals(keyDataStream.getParallelism(), javaKeyDataStream.getParallelism());
}
}
@@ -38,7 +38,7 @@ public class JobGraphBuilderTest {
public JobGraph buildDataSyncJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c"));
StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
@@ -73,7 +73,7 @@ public class JobGraphBuilderTest {
public JobGraph buildKeyByJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("1", "2", "3", "4"));
StreamSink streamSink = dataStream.keyBy(x -> x)
.sink(x -> LOG.info(x));