[Streaming] Streaming data transfer supports cross language. (#7961)

* add init parameters for java

* fix bug

* cython

* fix compile

* fix test_direct_transfer

* comment

* ChannelCreationParameter

* fix comment

* builder

* lint and fix tests

* fix single process test

* fix checkstyle and lint

* checkstyle

* lint python

Co-authored-by: wanxing <wanxing@B-458DMD6M-1753.local>
This commit is contained in:
wanxing
2020-04-16 15:16:48 +08:00
committed by GitHub
parent 5a7882bb44
commit 9345d03ffb
36 changed files with 618 additions and 333 deletions
@@ -0,0 +1,161 @@
package io.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActor;
import io.ray.api.id.ActorId;
import io.ray.runtime.actor.LocalModeRayActor;
import io.ray.runtime.actor.NativeRayJavaActor;
import io.ray.runtime.actor.NativeRayPyActor;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
 * Builds and holds the channel creation parameters needed by DataWriter/DataReader.
 *
 * <p>For every channel the builder records the id of the peer actor together with the
 * async/sync function descriptors of that peer's message entry points. Java descriptors
 * are used for Java (and local-mode) actors, Python descriptors for Python actors.
 * The resulting list is read from native code through JNI.
 *
 * <p>The Java descriptors are static and replaceable through the static setters, so a
 * replacement affects every builder that is used afterwards.
 */
public class ChannelCreationParametersBuilder {

  /**
   * Creation parameters of a single channel: the peer actor id plus the async and sync
   * function descriptors used to reach that peer. The getters are called from JNI.
   */
  public class Parameter {
    private ActorId actorId;
    private FunctionDescriptor asyncFunctionDescriptor;
    private FunctionDescriptor syncFunctionDescriptor;

    public void setActorId(ActorId actorId) {
      this.actorId = actorId;
    }

    public void setAsyncFunctionDescriptor(
        FunctionDescriptor asyncFunctionDescriptor) {
      this.asyncFunctionDescriptor = asyncFunctionDescriptor;
    }

    public void setSyncFunctionDescriptor(
        FunctionDescriptor syncFunctionDescriptor) {
      this.syncFunctionDescriptor = syncFunctionDescriptor;
    }

    /**
     * Returns a human readable description of this parameter.
     *
     * <p>Safe to call on a partially initialized parameter: unset descriptors are printed
     * as {@code null} and the language is reported as {@code Unknown}.
     */
    @Override
    public String toString() {
      String language;
      if (asyncFunctionDescriptor == null) {
        language = "Unknown";
      } else if (asyncFunctionDescriptor instanceof JavaFunctionDescriptor) {
        language = "Java";
      } else {
        language = "Python";
      }
      return "Language: " + language + " Desc: " + describe(asyncFunctionDescriptor) + " "
          + describe(syncFunctionDescriptor);
    }

    // Get actor id in bytes, called from jni.
    public byte[] getActorIdBytes() {
      return actorId.getBytes();
    }

    // Get async function descriptor, called from jni.
    public FunctionDescriptor getAsyncFunctionDescriptor() {
      return asyncFunctionDescriptor;
    }

    // Get sync function descriptor, called from jni.
    public FunctionDescriptor getSyncFunctionDescriptor() {
      return syncFunctionDescriptor;
    }
  }

  // Never null: an empty list until one of the build methods succeeds.
  private List<Parameter> parameters = new ArrayList<>();

  // Function descriptors of direct call entry point for Java workers.
  // "([B)V" is (byte[]) -> void, "([B)[B" is (byte[]) -> byte[].
  private static JavaFunctionDescriptor javaReaderAsyncFuncDesc = new JavaFunctionDescriptor(
      JobWorker.class.getName(),
      "onReaderMessage", "([B)V");
  private static JavaFunctionDescriptor javaReaderSyncFuncDesc = new JavaFunctionDescriptor(
      JobWorker.class.getName(),
      "onReaderMessageSync", "([B)[B");
  private static JavaFunctionDescriptor javaWriterAsyncFuncDesc = new JavaFunctionDescriptor(
      JobWorker.class.getName(),
      "onWriterMessage", "([B)V");
  private static JavaFunctionDescriptor javaWriterSyncFuncDesc = new JavaFunctionDescriptor(
      JobWorker.class.getName(),
      "onWriterMessageSync", "([B)[B");

  // Function descriptors of direct call entry point for Python workers.
  private static PyFunctionDescriptor pyReaderAsyncFunctionDesc = new PyFunctionDescriptor(
      "ray.streaming.runtime.worker",
      "JobWorker", "on_reader_message");
  private static PyFunctionDescriptor pyReaderSyncFunctionDesc = new PyFunctionDescriptor(
      "ray.streaming.runtime.worker",
      "JobWorker", "on_reader_message_sync");
  private static PyFunctionDescriptor pyWriterAsyncFunctionDesc = new PyFunctionDescriptor(
      "ray.streaming.runtime.worker",
      "JobWorker", "on_writer_message");
  private static PyFunctionDescriptor pyWriterSyncFunctionDesc = new PyFunctionDescriptor(
      "ray.streaming.runtime.worker",
      "JobWorker", "on_writer_message_sync");

  public ChannelCreationParametersBuilder() {
  }

  /**
   * Replaces the Java reader entry points used by {@link #buildOutputQueueParameters}.
   *
   * @param asyncFunc descriptor of the async reader entry point
   * @param syncFunc descriptor of the sync reader entry point
   */
  public static void setJavaReaderFunctionDesc(JavaFunctionDescriptor asyncFunc,
                                               JavaFunctionDescriptor syncFunc) {
    javaReaderAsyncFuncDesc = asyncFunc;
    javaReaderSyncFuncDesc = syncFunc;
  }

  /**
   * Replaces the Java writer entry points used by {@link #buildInputQueueParameters}.
   *
   * @param asyncFunc descriptor of the async writer entry point
   * @param syncFunc descriptor of the sync writer entry point
   */
  public static void setJavaWriterFunctionDesc(JavaFunctionDescriptor asyncFunc,
                                               JavaFunctionDescriptor syncFunc) {
    javaWriterAsyncFuncDesc = asyncFunc;
    javaWriterSyncFuncDesc = syncFunc;
  }

  /**
   * Builds the parameters for input queues.
   *
   * <p>The writer descriptors are used on purpose: the parameters describe the entry
   * points of the peer actor, and the peer of an input queue is the upstream writer.
   *
   * @param queues channel ids, one parameter is created per id, in this order
   * @param actors peer actor of every channel, keyed by channel id
   * @return this builder
   * @throws IllegalArgumentException if a channel has no actor or the actor type is unknown
   */
  public ChannelCreationParametersBuilder buildInputQueueParameters(List<String> queues,
      Map<String, BaseActor> actors) {
    return buildParameters(queues, actors, javaWriterAsyncFuncDesc, javaWriterSyncFuncDesc,
        pyWriterAsyncFunctionDesc, pyWriterSyncFunctionDesc);
  }

  /**
   * Builds the parameters for output queues.
   *
   * <p>The reader descriptors are used on purpose: the parameters describe the entry
   * points of the peer actor, and the peer of an output queue is the downstream reader.
   *
   * @param queues channel ids, one parameter is created per id, in this order
   * @param actors peer actor of every channel, keyed by channel id
   * @return this builder
   * @throws IllegalArgumentException if a channel has no actor or the actor type is unknown
   */
  public ChannelCreationParametersBuilder buildOutputQueueParameters(List<String> queues,
      Map<String, BaseActor> actors) {
    return buildParameters(queues, actors, javaReaderAsyncFuncDesc, javaReaderSyncFuncDesc,
        pyReaderAsyncFunctionDesc, pyReaderSyncFunctionDesc);
  }

  /**
   * Creates one {@link Parameter} per queue, choosing the Java or the Python descriptors
   * according to the type of the queue's peer actor. Any previously built parameters are
   * replaced, but only if the whole build succeeds.
   */
  private ChannelCreationParametersBuilder buildParameters(List<String> queues,
      Map<String, BaseActor> actors,
      JavaFunctionDescriptor javaAsyncFunctionDesc, JavaFunctionDescriptor javaSyncFunctionDesc,
      PyFunctionDescriptor pyAsyncFunctionDesc, PyFunctionDescriptor pySyncFunctionDesc
  ) {
    // Collect into a local list so a failure does not leave a half-filled result behind.
    List<Parameter> built = new ArrayList<>(queues.size());
    for (String queue : queues) {
      BaseActor actor = actors.get(queue);
      Preconditions.checkArgument(actor != null, "No actor found for channel %s", queue);

      Parameter parameter = new Parameter();
      parameter.setActorId(actor.getId());
      // LocalModeRayActor is used in single-process mode.
      if (actor instanceof NativeRayJavaActor || actor instanceof LocalModeRayActor) {
        parameter.setAsyncFunctionDescriptor(javaAsyncFunctionDesc);
        parameter.setSyncFunctionDescriptor(javaSyncFunctionDesc);
      } else if (actor instanceof NativeRayPyActor) {
        parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
        parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
      } else {
        throw new IllegalArgumentException(
            "Invalid actor type: " + actor.getClass().getName());
      }
      built.add(parameter);
    }
    parameters = built;
    return this;
  }

  /**
   * Returns the parameters created by the last successful build, or an empty list if
   * nothing has been built yet. Called from jni.
   */
  public List<Parameter> getParameters() {
    return parameters;
  }

  /** Concatenates the descriptions of all built parameters; empty if nothing was built. */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    for (Parameter param : parameters) {
      sb.append(param);
    }
    return sb.toString();
  }

  // Renders a descriptor the same way string concatenation would, but tolerates null.
  private static String describe(FunctionDescriptor descriptor) {
    return descriptor == null ? "null" : String.valueOf(descriptor.toList());
  }
}
@@ -1,7 +1,7 @@
package io.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.BaseActor;
import io.ray.streaming.runtime.util.Platform;
import io.ray.streaming.util.Config;
import java.nio.ByteBuffer;
@@ -24,14 +24,14 @@ public class DataReader {
private Queue<DataMessage> buf = new LinkedList<>();
public DataReader(List<String> inputChannels,
List<ActorId> fromActors,
Map<String, BaseActor> fromActors,
Map<String, String> conf) {
Preconditions.checkArgument(inputChannels.size() > 0);
Preconditions.checkArgument(inputChannels.size() == fromActors.size());
ChannelCreationParametersBuilder initialParameters =
new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors);
byte[][] inputChannelsBytes = inputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new);
byte[][] fromActorsBytes = fromActors.stream()
.map(ActorId::getBytes).toArray(byte[][]::new);
long[] seqIds = new long[inputChannels.size()];
long[] msgIds = new long[inputChannels.size()];
for (int i = 0; i < inputChannels.size(); i++) {
@@ -48,8 +48,8 @@ public class DataReader {
boolean isRecreate = Boolean.parseBoolean(
conf.getOrDefault(Config.IS_RECREATE, "false"));
this.nativeReaderPtr = createDataReaderNative(
initialParameters,
inputChannelsBytes,
fromActorsBytes,
seqIds,
msgIds,
timerInterval,
@@ -155,8 +155,8 @@ public class DataReader {
}
private static native long createDataReaderNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] inputChannels,
byte[][] inputActorIds,
long[] seqIds,
long[] msgIds,
long timerInterval,
@@ -1,7 +1,7 @@
package io.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.BaseActor;
import io.ray.streaming.runtime.util.Platform;
import io.ray.streaming.util.Config;
import java.nio.ByteBuffer;
@@ -33,14 +33,14 @@ public class DataWriter {
* @param conf configuration
*/
public DataWriter(List<String> outputChannels,
List<ActorId> toActors,
Map<String, BaseActor> toActors,
Map<String, String> conf) {
Preconditions.checkArgument(!outputChannels.isEmpty());
Preconditions.checkArgument(outputChannels.size() == toActors.size());
ChannelCreationParametersBuilder initialParameters =
new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
byte[][] outputChannelsBytes = outputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new);
byte[][] toActorsBytes = toActors.stream()
.map(ActorId::getBytes).toArray(byte[][]::new);
long channelSize = Long.parseLong(
conf.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT));
long[] msgIds = new long[outputChannels.size()];
@@ -53,8 +53,8 @@ public class DataWriter {
isMock = true;
}
this.nativeWriterPtr = createWriterNative(
initialParameters,
outputChannelsBytes,
toActorsBytes,
msgIds,
channelSize,
ChannelUtils.toNativeConf(conf),
@@ -123,8 +123,8 @@ public class DataWriter {
}
private static native long createWriterNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] outputQueueIds,
byte[][] outputActorIds,
long[] msgIds,
long channelSize,
byte[] confBytes,
@@ -1,8 +1,6 @@
package io.ray.streaming.runtime.transfer;
import io.ray.runtime.RayNativeRuntime;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.util.JniUtils;
/**
@@ -23,12 +21,9 @@ public class TransferHandler {
private long writerClientNative;
private long readerClientNative;
public TransferHandler(JavaFunctionDescriptor writerAsyncFunc,
JavaFunctionDescriptor writerSyncFunc,
JavaFunctionDescriptor readerAsyncFunc,
JavaFunctionDescriptor readerSyncFunc) {
writerClientNative = createWriterClientNative(writerAsyncFunc, writerSyncFunc);
readerClientNative = createReaderClientNative(readerAsyncFunc, readerSyncFunc);
public TransferHandler() {
writerClientNative = createWriterClientNative();
readerClientNative = createReaderClientNative();
}
public void onWriterMessage(byte[] buffer) {
@@ -47,13 +42,10 @@ public class TransferHandler {
return handleReaderMessageSyncNative(readerClientNative, buffer);
}
private native long createWriterClientNative(
FunctionDescriptor asyncFunc,
FunctionDescriptor syncFunc);
private native long createWriterClientNative();
private native long createReaderClientNative();
private native long createReaderClientNative(
FunctionDescriptor asyncFunc,
FunctionDescriptor syncFunc);
private native void handleWriterMessageNative(long handler, byte[] buffer);
@@ -1,6 +1,5 @@
package io.ray.streaming.runtime.worker;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.streaming.runtime.core.graph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
@@ -16,8 +15,10 @@ import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,17 +53,13 @@ public class JobWorker implements Serializable {
this.nodeType = executionNode.getNodeType();
this.streamProcessor = ProcessBuilder
.buildProcessor(executionNode.getStreamOperator());
.buildProcessor(executionNode.getStreamOperator());
LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
String channelType = (String) this.config.getOrDefault(
Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
if (channelType.equals(Config.NATIVE_CHANNEL)) {
transferHandler = new TransferHandler(
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessageSync", "([B)[B"));
transferHandler = new TransferHandler();
}
task = createStreamTask();
task.start();
@@ -2,7 +2,6 @@ package io.ray.streaming.runtime.worker.tasks;
import io.ray.api.BaseActor;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.partition.Partition;
@@ -17,10 +16,12 @@ import io.ray.streaming.runtime.transfer.DataWriter;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.runtime.worker.context.RayRuntimeContext;
import io.ray.streaming.util.Config;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,7 +42,7 @@ public abstract class StreamTask implements Runnable {
prepareTask();
this.thread = new Thread(Ray.wrapRunnable(this), this.getClass().getName()
+ "-" + System.currentTimeMillis());
+ "-" + System.currentTimeMillis());
this.thread.setDaemon(true);
}
@@ -64,22 +65,20 @@ public abstract class StreamTask implements Runnable {
List<ExecutionEdge> outputEdges = executionNode.getOutputEdges();
List<Collector> collectors = new ArrayList<>();
for (ExecutionEdge edge : outputEdges) {
Map<String, ActorId> outputActorIds = new HashMap<>();
Map<String, BaseActor> outputActors = new HashMap<>();
Map<Integer, BaseActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getTargetNodeId());
taskId2Worker.forEach((targetTaskId, targetActor) -> {
String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime());
outputActorIds.put(queueName, targetActor.getId());
outputActors.put(queueName, targetActor);
});
if (!outputActorIds.isEmpty()) {
if (!outputActors.isEmpty()) {
List<String> channelIDs = new ArrayList<>();
List<ActorId> toActorIds = new ArrayList<>();
outputActorIds.forEach((k, v) -> {
outputActors.forEach((k, v) -> {
channelIDs.add(k);
toActorIds.add(v);
});
DataWriter writer = new DataWriter(channelIDs, toActorIds, queueConf);
DataWriter writer = new DataWriter(channelIDs, outputActors, queueConf);
LOG.info("Create DataWriter succeed.");
writers.put(edge, writer);
Partition partition = edge.getPartition();
@@ -89,24 +88,22 @@ public abstract class StreamTask implements Runnable {
// consumer
List<ExecutionEdge> inputEdges = executionNode.getInputsEdges();
Map<String, ActorId> inputActorIds = new HashMap<>();
Map<String, BaseActor> inputActors = new HashMap<>();
for (ExecutionEdge edge : inputEdges) {
Map<Integer, BaseActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getSrcNodeId());
taskId2Worker.forEach((srcTaskId, srcActor) -> {
String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime());
inputActorIds.put(queueName, srcActor.getId());
inputActors.put(queueName, srcActor);
});
}
if (!inputActorIds.isEmpty()) {
if (!inputActors.isEmpty()) {
List<String> channelIDs = new ArrayList<>();
List<ActorId> fromActorIds = new ArrayList<>();
inputActorIds.forEach((k, v) -> {
inputActors.forEach((k, v) -> {
channelIDs.add(k);
fromActorIds.add(v);
});
LOG.info("Register queue consumer, queues {}.", channelIDs);
reader = new DataReader(channelIDs, fromActorIds, queueConf);
reader = new DataReader(channelIDs, inputActors, queueConf);
}
RuntimeContext runtimeContext = new RayRuntimeContext(
@@ -36,7 +36,6 @@ import org.testng.annotations.Test;
public class StreamingQueueTest extends BaseUnitTest implements Serializable {
private static Logger LOGGER = LoggerFactory.getLogger(StreamingQueueTest.class);
static {
EnvUtil.loadNativeLibraries();
}
@@ -62,7 +61,6 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@BeforeMethod
void beforeMethod() {
LOGGER.info("beforeTest");
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
@@ -144,6 +142,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@Test(timeOut = 60000)
public void testWordCount() {
LOGGER.info("testWordCount");
LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
System.getProperty("ray.run-mode"));
String resultFile = "/tmp/io.ray.streaming.runtime.streamingqueue.testWordCount.txt";
@@ -1,10 +1,11 @@
package io.ray.streaming.runtime.streamingqueue;
import io.ray.api.BaseActor;
import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.id.ActorId;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.streaming.runtime.transfer.ChannelID;
import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder;
import io.ray.streaming.runtime.transfer.DataMessage;
import io.ray.streaming.runtime.transfer.DataReader;
import io.ray.streaming.runtime.transfer.DataWriter;
@@ -12,7 +13,6 @@ import io.ray.streaming.runtime.transfer.TransferHandler;
import io.ray.streaming.util.Config;
import java.lang.management.ManagementFactory;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -27,15 +27,7 @@ public class Worker {
protected TransferHandler transferHandler = null;
public Worker() {
transferHandler = new TransferHandler(
new JavaFunctionDescriptor(Worker.class.getName(),
"onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(),
"onWriterMessageSync", "([B)[B"),
new JavaFunctionDescriptor(Worker.class.getName(),
"onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(),
"onReaderMessageSync", "([B)[B"));
transferHandler = new TransferHandler();
}
public void onReaderMessage(byte[] buffer) {
@@ -60,7 +52,7 @@ class ReaderWorker extends Worker {
private String name = null;
private List<String> inputQueueList = null;
private List<ActorId> inputActorIds = new ArrayList<>();
Map<String, BaseActor> fromActors = new HashMap<>();
private DataReader dataReader = null;
private long handler = 0;
private RayActor<WriterWorker> peerActor = null;
@@ -95,7 +87,7 @@ class ReaderWorker extends Worker {
LOGGER.info("java.library.path = {}", System.getProperty("java.library.path"));
for (String queue : this.inputQueueList) {
inputActorIds.add(this.peerActor.getId());
fromActors.put(queue, this.peerActor);
LOGGER.info("ReaderWorker actorId: {}", this.peerActor.getId());
}
@@ -104,7 +96,10 @@ class ReaderWorker extends Worker {
conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
conf.put(Config.CHANNEL_SIZE, "100000");
conf.put(Config.STREAMING_JOB_NAME, "integrationTest1");
dataReader = new DataReader(inputQueueList, inputActorIds, conf);
ChannelCreationParametersBuilder.setJavaWriterFunctionDesc(
new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B"));
dataReader = new DataReader(inputQueueList, fromActors, conf);
// Should not GetBundle in RayCall thread
Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() {
@@ -176,7 +171,7 @@ class WriterWorker extends Worker {
private String name = null;
private List<String> outputQueueList = null;
private List<ActorId> outputActorIds = new ArrayList<>();
Map<String, BaseActor> toActors = new HashMap<>();
DataWriter dataWriter = null;
RayActor<ReaderWorker> peerActor = null;
int msgCount = 0;
@@ -208,7 +203,7 @@ class WriterWorker extends Worker {
LOGGER.info("WriterWorker init:");
for (String queue : this.outputQueueList) {
outputActorIds.add(this.peerActor.getId());
toActors.put(queue, this.peerActor);
LOGGER.info("WriterWorker actorId: {}", this.peerActor.getId());
}
@@ -227,8 +222,10 @@ class WriterWorker extends Worker {
conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
conf.put(Config.CHANNEL_SIZE, "100000");
conf.put(Config.STREAMING_JOB_NAME, "integrationTest1");
dataWriter = new DataWriter(this.outputQueueList, this.outputActorIds, conf);
ChannelCreationParametersBuilder.setJavaReaderFunctionDesc(
new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B"));
dataWriter = new DataWriter(this.outputQueueList, this.toActors, conf);
Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() {
@Override
public void run() {