[Streaming] Streaming data transfer supports cross language. (#7961)

* add init parameters for java

* fix bug

* cython

* fix compile

* fix test_direct_tranfer

* comment

* ChannelCreationParameter

* fix comment

* builder

* lint and fix tests

* fix single process test

* fix checkstyle and lint

* checkstyle

* lint python

Co-authored-by: wanxing <wanxing@B-458DMD6M-1753.local>
This commit is contained in:
wanxing
2020-04-16 15:16:48 +08:00
committed by GitHub
parent 5a7882bb44
commit 9345d03ffb
36 changed files with 618 additions and 333 deletions
@@ -0,0 +1,161 @@
package io.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActor;
import io.ray.api.id.ActorId;
import io.ray.runtime.actor.LocalModeRayActor;
import io.ray.runtime.actor.NativeRayJavaActor;
import io.ray.runtime.actor.NativeRayPyActor;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Save channel initial parameters needed by DataWriter/DataReader.
*/
public class ChannelCreationParametersBuilder {
public class Parameter {
private ActorId actorId;
private FunctionDescriptor asyncFunctionDescriptor;
private FunctionDescriptor syncFunctionDescriptor;
public void setActorId(ActorId actorId) {
this.actorId = actorId;
}
public void setAsyncFunctionDescriptor(
FunctionDescriptor asyncFunctionDescriptor) {
this.asyncFunctionDescriptor = asyncFunctionDescriptor;
}
public void setSyncFunctionDescriptor(
FunctionDescriptor syncFunctionDescriptor) {
this.syncFunctionDescriptor = syncFunctionDescriptor;
}
public String toString() {
String language =
asyncFunctionDescriptor instanceof JavaFunctionDescriptor ? "Java" : "Python";
return "Language: " + language + " Desc: " + asyncFunctionDescriptor.toList() + " "
+ syncFunctionDescriptor.toList();
}
// Get actor id in bytes, called from jni.
public byte[] getActorIdBytes() {
return actorId.getBytes();
}
// Get async function descriptor, called from jni.
public FunctionDescriptor getAsyncFunctionDescriptor() {
return asyncFunctionDescriptor;
}
// Get sync function descriptor, called from jni.
public FunctionDescriptor getSyncFunctionDescriptor() {
return syncFunctionDescriptor;
}
}
private List<Parameter> parameters;
// function descriptors of direct call entry point for Java workers
private static JavaFunctionDescriptor javaReaderAsyncFuncDesc = new JavaFunctionDescriptor(
JobWorker.class.getName(),
"onReaderMessage", "([B)V");
private static JavaFunctionDescriptor javaReaderSyncFuncDesc = new JavaFunctionDescriptor(
JobWorker.class.getName(),
"onReaderMessageSync", "([B)[B");
private static JavaFunctionDescriptor javaWriterAsyncFuncDesc = new JavaFunctionDescriptor(
JobWorker.class.getName(),
"onWriterMessage", "([B)V");
private static JavaFunctionDescriptor javaWriterSyncFuncDesc = new JavaFunctionDescriptor(
JobWorker.class.getName(),
"onWriterMessageSync", "([B)[B");
// function descriptors of direct call entry point for Python workers
private static PyFunctionDescriptor pyReaderAsyncFunctionDesc = new PyFunctionDescriptor(
"ray.streaming.runtime.worker",
"JobWorker", "on_reader_message");
private static PyFunctionDescriptor pyReaderSyncFunctionDesc = new PyFunctionDescriptor(
"ray.streaming.runtime.worker",
"JobWorker", "on_reader_message_sync");
private static PyFunctionDescriptor pyWriterAsyncFunctionDesc = new PyFunctionDescriptor(
"ray.streaming.runtime.worker",
"JobWorker", "on_writer_message");
private static PyFunctionDescriptor pyWriterSyncFunctionDesc = new PyFunctionDescriptor(
"ray.streaming.runtime.worker",
"JobWorker", "on_writer_message_sync");
public ChannelCreationParametersBuilder() {
}
public static void setJavaReaderFunctionDesc(JavaFunctionDescriptor asyncFunc,
JavaFunctionDescriptor syncFunc) {
javaReaderAsyncFuncDesc = asyncFunc;
javaReaderSyncFuncDesc = syncFunc;
}
public static void setJavaWriterFunctionDesc(JavaFunctionDescriptor asyncFunc,
JavaFunctionDescriptor syncFunc) {
javaWriterAsyncFuncDesc = asyncFunc;
javaWriterSyncFuncDesc = syncFunc;
}
public ChannelCreationParametersBuilder buildInputQueueParameters(List<String> queues,
Map<String, BaseActor> actors) {
return buildParameters(queues, actors, javaWriterAsyncFuncDesc, javaWriterSyncFuncDesc,
pyWriterAsyncFunctionDesc, pyWriterSyncFunctionDesc);
}
public ChannelCreationParametersBuilder buildOutputQueueParameters(List<String> queues,
Map<String, BaseActor> actors) {
return buildParameters(queues, actors, javaReaderAsyncFuncDesc, javaReaderSyncFuncDesc,
pyReaderAsyncFunctionDesc, pyReaderSyncFunctionDesc);
}
private ChannelCreationParametersBuilder buildParameters(List<String> queues,
Map<String, BaseActor> actors,
JavaFunctionDescriptor javaAsyncFunctionDesc, JavaFunctionDescriptor javaSyncFunctionDesc,
PyFunctionDescriptor pyAsyncFunctionDesc, PyFunctionDescriptor pySyncFunctionDesc
) {
parameters = new ArrayList<>(queues.size());
for (String queue : queues) {
Parameter parameter = new Parameter();
BaseActor actor = actors.get(queue);
Preconditions.checkArgument(actor != null);
parameter.setActorId(actor.getId());
/// LocalModeRayActor used in single-process mode.
if (actor instanceof NativeRayJavaActor || actor instanceof LocalModeRayActor) {
parameter.setAsyncFunctionDescriptor(javaAsyncFunctionDesc);
parameter.setSyncFunctionDescriptor(javaSyncFunctionDesc);
} else if (actor instanceof NativeRayPyActor) {
parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
} else {
Preconditions.checkArgument(false, "Invalid actor type");
}
parameters.add(parameter);
}
return this;
}
// Called from jni
public List<Parameter> getParameters() {
return parameters;
}
public String toString() {
String str = "";
for (Parameter param : parameters) {
str += param.toString();
}
return str;
}
}
@@ -1,7 +1,7 @@
package io.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.BaseActor;
import io.ray.streaming.runtime.util.Platform;
import io.ray.streaming.util.Config;
import java.nio.ByteBuffer;
@@ -24,14 +24,14 @@ public class DataReader {
private Queue<DataMessage> buf = new LinkedList<>();
public DataReader(List<String> inputChannels,
List<ActorId> fromActors,
Map<String, BaseActor> fromActors,
Map<String, String> conf) {
Preconditions.checkArgument(inputChannels.size() > 0);
Preconditions.checkArgument(inputChannels.size() == fromActors.size());
ChannelCreationParametersBuilder initialParameters =
new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors);
byte[][] inputChannelsBytes = inputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new);
byte[][] fromActorsBytes = fromActors.stream()
.map(ActorId::getBytes).toArray(byte[][]::new);
long[] seqIds = new long[inputChannels.size()];
long[] msgIds = new long[inputChannels.size()];
for (int i = 0; i < inputChannels.size(); i++) {
@@ -48,8 +48,8 @@ public class DataReader {
boolean isRecreate = Boolean.parseBoolean(
conf.getOrDefault(Config.IS_RECREATE, "false"));
this.nativeReaderPtr = createDataReaderNative(
initialParameters,
inputChannelsBytes,
fromActorsBytes,
seqIds,
msgIds,
timerInterval,
@@ -155,8 +155,8 @@ public class DataReader {
}
private static native long createDataReaderNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] inputChannels,
byte[][] inputActorIds,
long[] seqIds,
long[] msgIds,
long timerInterval,
@@ -1,7 +1,7 @@
package io.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.BaseActor;
import io.ray.streaming.runtime.util.Platform;
import io.ray.streaming.util.Config;
import java.nio.ByteBuffer;
@@ -33,14 +33,14 @@ public class DataWriter {
* @param conf configuration
*/
public DataWriter(List<String> outputChannels,
List<ActorId> toActors,
Map<String, BaseActor> toActors,
Map<String, String> conf) {
Preconditions.checkArgument(!outputChannels.isEmpty());
Preconditions.checkArgument(outputChannels.size() == toActors.size());
ChannelCreationParametersBuilder initialParameters =
new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
byte[][] outputChannelsBytes = outputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new);
byte[][] toActorsBytes = toActors.stream()
.map(ActorId::getBytes).toArray(byte[][]::new);
long channelSize = Long.parseLong(
conf.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT));
long[] msgIds = new long[outputChannels.size()];
@@ -53,8 +53,8 @@ public class DataWriter {
isMock = true;
}
this.nativeWriterPtr = createWriterNative(
initialParameters,
outputChannelsBytes,
toActorsBytes,
msgIds,
channelSize,
ChannelUtils.toNativeConf(conf),
@@ -123,8 +123,8 @@ public class DataWriter {
}
private static native long createWriterNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] outputQueueIds,
byte[][] outputActorIds,
long[] msgIds,
long channelSize,
byte[] confBytes,
@@ -1,8 +1,6 @@
package io.ray.streaming.runtime.transfer;
import io.ray.runtime.RayNativeRuntime;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.util.JniUtils;
/**
@@ -23,12 +21,9 @@ public class TransferHandler {
private long writerClientNative;
private long readerClientNative;
public TransferHandler(JavaFunctionDescriptor writerAsyncFunc,
JavaFunctionDescriptor writerSyncFunc,
JavaFunctionDescriptor readerAsyncFunc,
JavaFunctionDescriptor readerSyncFunc) {
writerClientNative = createWriterClientNative(writerAsyncFunc, writerSyncFunc);
readerClientNative = createReaderClientNative(readerAsyncFunc, readerSyncFunc);
public TransferHandler() {
writerClientNative = createWriterClientNative();
readerClientNative = createReaderClientNative();
}
public void onWriterMessage(byte[] buffer) {
@@ -47,13 +42,10 @@ public class TransferHandler {
return handleReaderMessageSyncNative(readerClientNative, buffer);
}
private native long createWriterClientNative(
FunctionDescriptor asyncFunc,
FunctionDescriptor syncFunc);
private native long createWriterClientNative();
private native long createReaderClientNative();
private native long createReaderClientNative(
FunctionDescriptor asyncFunc,
FunctionDescriptor syncFunc);
private native void handleWriterMessageNative(long handler, byte[] buffer);
@@ -1,6 +1,5 @@
package io.ray.streaming.runtime.worker;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.streaming.runtime.core.graph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
@@ -16,8 +15,10 @@ import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,17 +53,13 @@ public class JobWorker implements Serializable {
this.nodeType = executionNode.getNodeType();
this.streamProcessor = ProcessBuilder
.buildProcessor(executionNode.getStreamOperator());
.buildProcessor(executionNode.getStreamOperator());
LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
String channelType = (String) this.config.getOrDefault(
Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
if (channelType.equals(Config.NATIVE_CHANNEL)) {
transferHandler = new TransferHandler(
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessageSync", "([B)[B"));
transferHandler = new TransferHandler();
}
task = createStreamTask();
task.start();
@@ -2,7 +2,6 @@ package io.ray.streaming.runtime.worker.tasks;
import io.ray.api.BaseActor;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.partition.Partition;
@@ -17,10 +16,12 @@ import io.ray.streaming.runtime.transfer.DataWriter;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.runtime.worker.context.RayRuntimeContext;
import io.ray.streaming.util.Config;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,7 +42,7 @@ public abstract class StreamTask implements Runnable {
prepareTask();
this.thread = new Thread(Ray.wrapRunnable(this), this.getClass().getName()
+ "-" + System.currentTimeMillis());
+ "-" + System.currentTimeMillis());
this.thread.setDaemon(true);
}
@@ -64,22 +65,20 @@ public abstract class StreamTask implements Runnable {
List<ExecutionEdge> outputEdges = executionNode.getOutputEdges();
List<Collector> collectors = new ArrayList<>();
for (ExecutionEdge edge : outputEdges) {
Map<String, ActorId> outputActorIds = new HashMap<>();
Map<String, BaseActor> outputActors = new HashMap<>();
Map<Integer, BaseActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getTargetNodeId());
taskId2Worker.forEach((targetTaskId, targetActor) -> {
String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime());
outputActorIds.put(queueName, targetActor.getId());
outputActors.put(queueName, targetActor);
});
if (!outputActorIds.isEmpty()) {
if (!outputActors.isEmpty()) {
List<String> channelIDs = new ArrayList<>();
List<ActorId> toActorIds = new ArrayList<>();
outputActorIds.forEach((k, v) -> {
outputActors.forEach((k, v) -> {
channelIDs.add(k);
toActorIds.add(v);
});
DataWriter writer = new DataWriter(channelIDs, toActorIds, queueConf);
DataWriter writer = new DataWriter(channelIDs, outputActors, queueConf);
LOG.info("Create DataWriter succeed.");
writers.put(edge, writer);
Partition partition = edge.getPartition();
@@ -89,24 +88,22 @@ public abstract class StreamTask implements Runnable {
// consumer
List<ExecutionEdge> inputEdges = executionNode.getInputsEdges();
Map<String, ActorId> inputActorIds = new HashMap<>();
Map<String, BaseActor> inputActors = new HashMap<>();
for (ExecutionEdge edge : inputEdges) {
Map<Integer, BaseActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getSrcNodeId());
taskId2Worker.forEach((srcTaskId, srcActor) -> {
String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime());
inputActorIds.put(queueName, srcActor.getId());
inputActors.put(queueName, srcActor);
});
}
if (!inputActorIds.isEmpty()) {
if (!inputActors.isEmpty()) {
List<String> channelIDs = new ArrayList<>();
List<ActorId> fromActorIds = new ArrayList<>();
inputActorIds.forEach((k, v) -> {
inputActors.forEach((k, v) -> {
channelIDs.add(k);
fromActorIds.add(v);
});
LOG.info("Register queue consumer, queues {}.", channelIDs);
reader = new DataReader(channelIDs, fromActorIds, queueConf);
reader = new DataReader(channelIDs, inputActors, queueConf);
}
RuntimeContext runtimeContext = new RayRuntimeContext(
@@ -36,7 +36,6 @@ import org.testng.annotations.Test;
public class StreamingQueueTest extends BaseUnitTest implements Serializable {
private static Logger LOGGER = LoggerFactory.getLogger(StreamingQueueTest.class);
static {
EnvUtil.loadNativeLibraries();
}
@@ -62,7 +61,6 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@BeforeMethod
void beforeMethod() {
LOGGER.info("beforeTest");
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
@@ -144,6 +142,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@Test(timeOut = 60000)
public void testWordCount() {
LOGGER.info("testWordCount");
LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
System.getProperty("ray.run-mode"));
String resultFile = "/tmp/io.ray.streaming.runtime.streamingqueue.testWordCount.txt";
@@ -1,10 +1,11 @@
package io.ray.streaming.runtime.streamingqueue;
import io.ray.api.BaseActor;
import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.id.ActorId;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.streaming.runtime.transfer.ChannelID;
import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder;
import io.ray.streaming.runtime.transfer.DataMessage;
import io.ray.streaming.runtime.transfer.DataReader;
import io.ray.streaming.runtime.transfer.DataWriter;
@@ -12,7 +13,6 @@ import io.ray.streaming.runtime.transfer.TransferHandler;
import io.ray.streaming.util.Config;
import java.lang.management.ManagementFactory;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -27,15 +27,7 @@ public class Worker {
protected TransferHandler transferHandler = null;
public Worker() {
transferHandler = new TransferHandler(
new JavaFunctionDescriptor(Worker.class.getName(),
"onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(),
"onWriterMessageSync", "([B)[B"),
new JavaFunctionDescriptor(Worker.class.getName(),
"onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(),
"onReaderMessageSync", "([B)[B"));
transferHandler = new TransferHandler();
}
public void onReaderMessage(byte[] buffer) {
@@ -60,7 +52,7 @@ class ReaderWorker extends Worker {
private String name = null;
private List<String> inputQueueList = null;
private List<ActorId> inputActorIds = new ArrayList<>();
Map<String, BaseActor> fromActors = new HashMap<>();
private DataReader dataReader = null;
private long handler = 0;
private RayActor<WriterWorker> peerActor = null;
@@ -95,7 +87,7 @@ class ReaderWorker extends Worker {
LOGGER.info("java.library.path = {}", System.getProperty("java.library.path"));
for (String queue : this.inputQueueList) {
inputActorIds.add(this.peerActor.getId());
fromActors.put(queue, this.peerActor);
LOGGER.info("ReaderWorker actorId: {}", this.peerActor.getId());
}
@@ -104,7 +96,10 @@ class ReaderWorker extends Worker {
conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
conf.put(Config.CHANNEL_SIZE, "100000");
conf.put(Config.STREAMING_JOB_NAME, "integrationTest1");
dataReader = new DataReader(inputQueueList, inputActorIds, conf);
ChannelCreationParametersBuilder.setJavaWriterFunctionDesc(
new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B"));
dataReader = new DataReader(inputQueueList, fromActors, conf);
// Should not GetBundle in RayCall thread
Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() {
@@ -176,7 +171,7 @@ class WriterWorker extends Worker {
private String name = null;
private List<String> outputQueueList = null;
private List<ActorId> outputActorIds = new ArrayList<>();
Map<String, BaseActor> toActors = new HashMap<>();
DataWriter dataWriter = null;
RayActor<ReaderWorker> peerActor = null;
int msgCount = 0;
@@ -208,7 +203,7 @@ class WriterWorker extends Worker {
LOGGER.info("WriterWorker init:");
for (String queue : this.outputQueueList) {
outputActorIds.add(this.peerActor.getId());
toActors.put(queue, this.peerActor);
LOGGER.info("WriterWorker actorId: {}", this.peerActor.getId());
}
@@ -227,8 +222,10 @@ class WriterWorker extends Worker {
conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
conf.put(Config.CHANNEL_SIZE, "100000");
conf.put(Config.STREAMING_JOB_NAME, "integrationTest1");
dataWriter = new DataWriter(this.outputQueueList, this.outputActorIds, conf);
ChannelCreationParametersBuilder.setJavaReaderFunctionDesc(
new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B"));
dataWriter = new DataWriter(this.outputQueueList, this.toActors, conf);
Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() {
@Override
public void run() {
+11 -6
View File
@@ -97,16 +97,21 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums,
c_list[shared_ptr[CStreamingMessage]] &msg_list);
cdef extern from "channel.h" namespace "ray::streaming" nogil:
cdef struct CChannelCreationParameter "ray::streaming::ChannelCreationParameter":
CChannelCreationParameter()
CActorID actor_id;
shared_ptr[CRayFunction] async_function;
shared_ptr[CRayFunction] sync_function;
cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil:
cdef cppclass CReaderClient "ray::streaming::ReaderClient":
CReaderClient(CRayFunction &async_func,
CRayFunction &sync_func)
CReaderClient()
void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer);
shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
cdef cppclass CWriterClient "ray::streaming::WriterClient":
CWriterClient(CRayFunction &async_func,
CRayFunction &sync_func)
CWriterClient()
void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer);
shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
@@ -122,7 +127,7 @@ cdef extern from "data_reader.h" namespace "ray::streaming" nogil:
cdef cppclass CDataReader "ray::streaming::DataReader"(CStreamingCommon):
CDataReader(shared_ptr[CRuntimeContext] &runtime_context)
void Init(const c_vector[CObjectID] &input_ids,
const c_vector[CActorID] &actor_ids,
const c_vector[CChannelCreationParameter] &params,
const c_vector[uint64_t] &seq_ids,
const c_vector[uint64_t] &msg_ids,
int64_t timer_interval);
@@ -135,7 +140,7 @@ cdef extern from "data_writer.h" namespace "ray::streaming" nogil:
cdef cppclass CDataWriter "ray::streaming::DataWriter"(CStreamingCommon):
CDataWriter(shared_ptr[CRuntimeContext] &runtime_context)
CStreamingStatus Init(const c_vector[CObjectID] &channel_ids,
const c_vector[CActorID] &actor_ids,
const c_vector[CChannelCreationParameter] &params,
const c_vector[uint64_t] &message_ids,
const c_vector[uint64_t] &queue_size_vec);
long WriteMessageToBufferRing(
+42 -28
View File
@@ -10,6 +10,7 @@ from libcpp.list cimport list as c_list
from ray.includes.common cimport (
CRayFunction,
LANGUAGE_PYTHON,
LANGUAGE_JAVA,
CBuffer
)
@@ -36,27 +37,43 @@ from ray.streaming.includes.libstreaming cimport (
CReaderClient,
CWriterClient,
CLocalMemoryBuffer,
CChannelCreationParameter,
)
from ray._raylet import JavaFunctionDescriptor
import logging
channel_logger = logging.getLogger(__name__)
cdef class ChannelCreationParameter:
cdef:
CChannelCreationParameter parameter
def __cinit__(self, ActorID actor_id, FunctionDescriptor async_func, FunctionDescriptor sync_func):
cdef:
shared_ptr[CRayFunction] async_func_ptr
shared_ptr[CRayFunction] sync_func_ptr
self.parameter = CChannelCreationParameter()
self.parameter.actor_id = (<ActorID>actor_id).data
if isinstance(async_func, JavaFunctionDescriptor):
self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_JAVA, async_func.descriptor)
else:
self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_PYTHON, async_func.descriptor)
if isinstance(sync_func, JavaFunctionDescriptor):
self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_JAVA, sync_func.descriptor)
else:
self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_PYTHON, sync_func.descriptor)
cdef CChannelCreationParameter get_parameter(self):
return self.parameter
cdef class ReaderClient:
cdef:
CReaderClient *client
def __cinit__(self,
FunctionDescriptor async_func,
FunctionDescriptor sync_func):
cdef:
CRayFunction async_native_func
CRayFunction sync_native_func
async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
self.client = new CReaderClient(async_native_func, sync_native_func)
def __cinit__(self):
self.client = new CReaderClient()
def __dealloc__(self):
del self.client
@@ -85,15 +102,8 @@ cdef class WriterClient:
cdef:
CWriterClient * client
def __cinit__(self,
FunctionDescriptor async_func,
FunctionDescriptor sync_func):
cdef:
CRayFunction async_native_func
CRayFunction sync_native_func
async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
self.client = new CWriterClient(async_native_func, sync_native_func)
def __cinit__(self):
self.client = new CWriterClient()
def __dealloc__(self):
del self.client
@@ -127,19 +137,21 @@ cdef class DataWriter:
@staticmethod
def create(list py_output_channels,
list output_actor_ids: list[ActorID],
list output_creation_parameters: list[ChannelCreationParameter],
uint64_t queue_size,
list py_msg_ids,
bytes config_bytes,
c_bool is_mock):
cdef:
c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
c_vector[CActorID] actor_ids
c_vector[CChannelCreationParameter] initial_parameters
c_vector[uint64_t] msg_ids
CDataWriter *c_writer
ChannelCreationParameter parameter
cdef const unsigned char[:] config_data
for actor_id in output_actor_ids:
actor_ids.push_back((<ActorID>actor_id).data)
for param in output_creation_parameters:
parameter = param
initial_parameters.push_back(parameter.get_parameter())
for py_msg_id in py_msg_ids:
msg_ids.push_back(<uint64_t>py_msg_id)
@@ -156,7 +168,7 @@ cdef class DataWriter:
c_vector[uint64_t] queue_size_vec
for i in range(channel_ids.size()):
queue_size_vec.push_back(queue_size)
cdef CStreamingStatus status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec)
cdef CStreamingStatus status = c_writer.Init(channel_ids, initial_parameters, msg_ids, queue_size_vec)
if remain_id_vec.size() != 0:
channel_logger.warning("failed queue amounts => %s", remain_id_vec.size())
if <uint32_t>status != <uint32_t> libstreaming.StatusOK:
@@ -205,7 +217,7 @@ cdef class DataReader:
@staticmethod
def create(list py_input_queues,
list input_actor_ids: list[ActorID],
list input_creation_parameters: list[ChannelCreationParameter],
list py_seq_ids,
list py_msg_ids,
int64_t timer_interval,
@@ -214,13 +226,15 @@ cdef class DataReader:
c_bool is_mock):
cdef:
c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
c_vector[CActorID] actor_ids
c_vector[CChannelCreationParameter] initial_parameters
c_vector[uint64_t] seq_ids
c_vector[uint64_t] msg_ids
CDataReader *c_reader
ChannelCreationParameter parameter
cdef const unsigned char[:] config_data
for actor_id in input_actor_ids:
actor_ids.push_back((<ActorID>actor_id).data)
for param in input_creation_parameters:
parameter = param
initial_parameters.push_back(parameter.get_parameter())
for py_seq_id in py_seq_ids:
seq_ids.push_back(<uint64_t>py_seq_id)
for py_msg_id in py_msg_ids:
@@ -233,7 +247,7 @@ cdef class DataReader:
if is_mock:
ctx.get().MarkMockTest()
c_reader = new CDataReader(ctx)
c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval)
c_reader.Init(queue_id_vec, initial_parameters, seq_ids, msg_ids, timer_interval)
channel_logger.info("create native reader succeed")
cdef DataReader reader = DataReader.__new__(DataReader)
reader.reader = c_reader
+90 -9
View File
@@ -6,9 +6,11 @@ from typing import List
import ray
import ray.streaming._streaming as _streaming
import ray.streaming.generated.streaming_pb2 as streaming_pb
from ray import ActorID
from ray.actor import ActorHandle
from ray.streaming.config import Config
from ray._raylet import JavaFunctionDescriptor
from ray._raylet import PythonFunctionDescriptor
from ray._raylet import Language
CHANNEL_ID_LEN = 20
@@ -140,6 +142,85 @@ class DataMessage:
return self.__message_id
class ChannelCreationParametersBuilder:
"""
wrap initial parameters needed by a streaming queue
"""
_java_reader_async_function_descriptor = JavaFunctionDescriptor(
"org.ray.streaming.runtime.worker",
"onReaderMessage", "([B)V")
_java_reader_sync_function_descriptor = JavaFunctionDescriptor(
"org.ray.streaming.runtime.worker",
"onReaderMessageSync", "([B)[B")
_java_writer_async_function_descriptor = JavaFunctionDescriptor(
"org.ray.streaming.runtime.worker",
"onWriterMessage", "([B)V")
_java_writer_sync_function_descriptor = JavaFunctionDescriptor(
"org.ray.streaming.runtime.worker",
"onWriterMessageSync", "([B)[B")
_python_reader_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.core.worker",
"on_reader_message", "JobWorker")
_python_reader_sync_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.core.worker",
"on_reader_message_sync", "JobWorker")
_python_writer_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.core.worker",
"on_writer_message", "JobWorker")
_python_writer_sync_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.core.worker",
"on_writer_message_sync", "JobWorker")
def get_parameters(self):
return self._parameters
def __init__(self):
self._parameters = []
def build_input_queue_parameters(self, queue_ids_dict):
self.build_parameters(queue_ids_dict,
self._java_writer_async_function_descriptor,
self._java_writer_sync_function_descriptor,
self._python_writer_async_function_descriptor,
self._python_writer_sync_function_descriptor)
return self
def build_output_queue_parameters(self, to_actors):
self.build_parameters(to_actors,
self._java_reader_async_function_descriptor,
self._java_reader_sync_function_descriptor,
self._python_reader_async_function_descriptor,
self._python_reader_sync_function_descriptor)
return self
def build_parameters(self, actors, java_async_func,
java_sync_func, py_async_func, py_sync_func):
for handle in actors:
parameter = None
if handle._ray_actor_language == Language.PYTHON:
parameter = _streaming.ChannelCreationParameter(
handle._ray_actor_id, py_async_func, py_sync_func)
else:
parameter = _streaming.ChannelCreationParameter(
handle._ray_actor_id, java_async_func, java_sync_func)
self._parameters.append(parameter)
return self
@staticmethod
def set_python_writer_function_descriptor(async_function, sync_function):
ChannelCreationParametersBuilder.\
_python_writer_async_function_descriptor = async_function
ChannelCreationParametersBuilder.\
_python_writer_sync_function_descriptor = sync_function
@staticmethod
def set_python_reader_function_descriptor(async_function, sync_function):
ChannelCreationParametersBuilder.\
_python_reader_async_function_descriptor = async_function
ChannelCreationParametersBuilder.\
_python_reader_sync_function_descriptor = sync_function
logger = logging.getLogger(__name__)
@@ -161,16 +242,16 @@ class DataWriter:
py_output_channels = [
channel_id_str_to_bytes(qid_str) for qid_str in output_channels
]
output_actor_ids: List[ActorID] = [
handle._ray_actor_id for handle in to_actors
]
creation_parameters = ChannelCreationParametersBuilder()
creation_parameters.build_output_queue_parameters(to_actors)
channel_size = conf.get(Config.CHANNEL_SIZE,
Config.CHANNEL_SIZE_DEFAULT)
py_msg_ids = [0 for _ in range(len(output_channels))]
config_bytes = _to_native_conf(conf)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.writer = _streaming.DataWriter.create(
py_output_channels, output_actor_ids, channel_size, py_msg_ids,
py_output_channels, creation_parameters.get_parameters(),
channel_size, py_msg_ids,
config_bytes, is_mock)
logger.info("create DataWriter succeed")
@@ -215,9 +296,8 @@ class DataReader:
py_input_channels = [
channel_id_str_to_bytes(qid_str) for qid_str in input_channels
]
input_actor_ids: List[ActorID] = [
handle._ray_actor_id for handle in from_actors
]
creation_parameters = ChannelCreationParametersBuilder()
creation_parameters.build_input_queue_parameters(from_actors)
py_seq_ids = [0 for _ in range(len(input_channels))]
py_msg_ids = [0 for _ in range(len(input_channels))]
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
@@ -226,7 +306,8 @@ class DataReader:
self.__queue = Queue(10000)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.reader = _streaming.DataReader.create(
py_input_channels, input_actor_ids, py_seq_ids, py_msg_ids,
py_input_channels, creation_parameters.get_parameters(),
py_seq_ids, py_msg_ids,
timer_interval, is_recreate, config_bytes, is_mock)
logger.info("create DataReader succeed")
+2 -17
View File
@@ -4,7 +4,6 @@ import ray
import ray.streaming._streaming as _streaming
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
import ray.streaming.runtime.processor as processor
from ray._raylet import PythonFunctionDescriptor
from ray.streaming.config import Config
from ray.streaming.runtime.graph import ExecutionGraph
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
@@ -48,22 +47,8 @@ class JobWorker(object):
self.task_id, self.stream_processor))
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
reader_async_func = PythonFunctionDescriptor(
__name__, self.on_reader_message.__name__,
self.__class__.__name__)
reader_sync_func = PythonFunctionDescriptor(
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
self.reader_client = _streaming.ReaderClient(
reader_async_func, reader_sync_func)
writer_async_func = PythonFunctionDescriptor(
__name__, self.on_writer_message.__name__,
self.__class__.__name__)
writer_sync_func = PythonFunctionDescriptor(
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
self.writer_client = _streaming.WriterClient(
writer_async_func, writer_sync_func)
self.reader_client = _streaming.ReaderClient()
self.writer_client = _streaming.WriterClient()
self.task = self.create_stream_task()
self.task.start()
+18 -14
View File
@@ -12,20 +12,8 @@ from ray.streaming.config import Config
@ray.remote
class Worker:
def __init__(self):
writer_async_func = PythonFunctionDescriptor(
__name__, self.on_writer_message.__name__, self.__class__.__name__)
writer_sync_func = PythonFunctionDescriptor(
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
self.writer_client = _streaming.WriterClient(writer_async_func,
writer_sync_func)
reader_async_func = PythonFunctionDescriptor(
__name__, self.on_reader_message.__name__, self.__class__.__name__)
reader_sync_func = PythonFunctionDescriptor(
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
self.reader_client = _streaming.ReaderClient(reader_async_func,
reader_sync_func)
self.writer_client = _streaming.WriterClient()
self.reader_client = _streaming.ReaderClient()
self.writer = None
self.output_channel_id = None
self.reader = None
@@ -35,6 +23,14 @@ class Worker:
Config.TASK_JOB_ID: ray.worker.global_worker.current_job_id,
Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL
}
reader_async_func = PythonFunctionDescriptor(
__name__, self.on_reader_message.__name__, self.__class__.__name__)
reader_sync_func = PythonFunctionDescriptor(
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
transfer.ChannelCreationParametersBuilder.\
set_python_reader_function_descriptor(
reader_async_func, reader_sync_func)
self.writer = transfer.DataWriter([output_channel],
[pickle.loads(reader_actor)], conf)
self.output_channel_id = transfer.ChannelID(output_channel)
@@ -44,6 +40,14 @@ class Worker:
Config.TASK_JOB_ID: ray.worker.global_worker.current_job_id,
Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL
}
writer_async_func = PythonFunctionDescriptor(
__name__, self.on_writer_message.__name__, self.__class__.__name__)
writer_sync_func = PythonFunctionDescriptor(
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
transfer.ChannelCreationParametersBuilder.\
set_python_writer_function_descriptor(
writer_async_func, writer_sync_func)
self.reader = transfer.DataReader([input_channel],
[pickle.loads(writer_actor)], conf)
+10 -5
View File
@@ -60,9 +60,12 @@ StreamingStatus StreamingQueueProducer::CreateQueue() {
return StreamingStatus::OK;
}
upstream_handler->SetPeerActorID(channel_info_.channel_id, channel_info_.actor_id);
queue_ = upstream_handler->CreateUpstreamQueue(
channel_info_.channel_id, channel_info_.actor_id, channel_info_.queue_size);
upstream_handler->SetPeerActorID(
channel_info_.channel_id, channel_info_.parameter.actor_id,
*channel_info_.parameter.async_function, *channel_info_.parameter.sync_function);
queue_ = upstream_handler->CreateUpstreamQueue(channel_info_.channel_id,
channel_info_.parameter.actor_id,
channel_info_.queue_size);
STREAMING_CHECK(queue_ != nullptr);
std::vector<ObjectID> queue_ids, failed_queues;
@@ -154,11 +157,13 @@ StreamingStatus StreamingQueueConsumer::CreateTransferChannel() {
return StreamingStatus::OK;
}
downstream_handler->SetPeerActorID(channel_info_.channel_id, channel_info_.actor_id);
downstream_handler->SetPeerActorID(
channel_info_.channel_id, channel_info_.parameter.actor_id,
*channel_info_.parameter.async_function, *channel_info_.parameter.sync_function);
STREAMING_LOG(INFO) << "Create ReaderQueue " << channel_info_.channel_id
<< " pull from start_seq_id: " << channel_info_.current_seq_id + 1;
queue_ = downstream_handler->CreateDownstreamQueue(channel_info_.channel_id,
channel_info_.actor_id);
channel_info_.parameter.actor_id);
return StreamingStatus::OK;
}
+8 -2
View File
@@ -17,6 +17,12 @@ struct StreamingQueueInfo {
uint64_t consumed_seq_id = 0;
};
struct ChannelCreationParameter {
ActorID actor_id;
std::shared_ptr<ray::RayFunction> async_function;
std::shared_ptr<ray::RayFunction> sync_function;
};
/// PrducerChannelinfo and ConsumerChannelInfo contains channel information and
/// its metrics that help us to debug or show important messages in logging.
struct ProducerChannelInfo {
@@ -28,7 +34,7 @@ struct ProducerChannelInfo {
StreamingQueueInfo queue_info;
uint32_t queue_size;
int64_t message_pass_by_ts;
ActorID actor_id;
ChannelCreationParameter parameter;
/// The following parameters are used for event driven to record different
/// input events.
@@ -55,7 +61,7 @@ struct ConsumerChannelInfo {
uint64_t last_queue_item_latency = 0;
uint64_t last_queue_target_diff = 0;
uint64_t get_queue_item_times = 0;
ActorID actor_id;
ChannelCreationParameter parameter;
// Total count of notify request.
uint64_t notify_cnt = 0;
};
+5 -4
View File
@@ -17,11 +17,11 @@ namespace streaming {
const uint32_t DataReader::kReadItemTimeout = 1000;
void DataReader::Init(const std::vector<ObjectID> &input_ids,
const std::vector<ActorID> &actor_ids,
const std::vector<ChannelCreationParameter> &init_params,
const std::vector<uint64_t> &queue_seq_ids,
const std::vector<uint64_t> &streaming_msg_ids,
int64_t timer_interval) {
Init(input_ids, actor_ids, timer_interval);
Init(input_ids, init_params, timer_interval);
for (size_t i = 0; i < input_ids.size(); ++i) {
auto &q_id = input_ids[i];
channel_info_map_[q_id].current_seq_id = queue_seq_ids[i];
@@ -30,7 +30,8 @@ void DataReader::Init(const std::vector<ObjectID> &input_ids,
}
void DataReader::Init(const std::vector<ObjectID> &input_ids,
const std::vector<ActorID> &actor_ids, int64_t timer_interval) {
const std::vector<ChannelCreationParameter> &init_params,
int64_t timer_interval) {
STREAMING_LOG(INFO) << input_ids.size() << " queue to init.";
transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, input_ids);
@@ -47,7 +48,7 @@ void DataReader::Init(const std::vector<ObjectID> &input_ids,
STREAMING_LOG(INFO) << "[Reader] Init queue id: " << q_id;
auto &channel_info = channel_info_map_[q_id];
channel_info.channel_id = q_id;
channel_info.actor_id = actor_ids[i];
channel_info.parameter = init_params[i];
channel_info.last_queue_item_delay = 0;
channel_info.last_queue_item_latency = 0;
channel_info.last_queue_target_diff = 0;
+4 -2
View File
@@ -78,11 +78,13 @@ class DataReader {
/// \param channel_seq_ids
/// \param msg_ids
/// \param timer_interval
void Init(const std::vector<ObjectID> &input_ids, const std::vector<ActorID> &actor_ids,
void Init(const std::vector<ObjectID> &input_ids,
const std::vector<ChannelCreationParameter> &init_params,
const std::vector<uint64_t> &channel_seq_ids,
const std::vector<uint64_t> &msg_ids, int64_t timer_interval);
void Init(const std::vector<ObjectID> &input_ids, const std::vector<ActorID> &actor_ids,
void Init(const std::vector<ObjectID> &input_ids,
const std::vector<ChannelCreationParameter> &init_params,
int64_t timer_interval);
/// Get latest message from input queues.
+5 -4
View File
@@ -102,13 +102,14 @@ uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *dat
return write_message_id;
}
StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &actor_id,
StreamingStatus DataWriter::InitChannel(const ObjectID &q_id,
const ChannelCreationParameter &param,
uint64_t channel_message_id,
uint64_t queue_size) {
ProducerChannelInfo &channel_info = channel_info_map_[q_id];
channel_info.current_message_id = channel_message_id;
channel_info.channel_id = q_id;
channel_info.actor_id = actor_id;
channel_info.parameter = param;
channel_info.queue_size = queue_size;
STREAMING_LOG(WARNING) << " Init queue [" << q_id << "]";
channel_info.writer_ring_buffer = std::make_shared<StreamingRingBuffer>(
@@ -129,7 +130,7 @@ StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &act
}
StreamingStatus DataWriter::Init(const std::vector<ObjectID> &queue_id_vec,
const std::vector<ActorID> &actor_ids,
const std::vector<ChannelCreationParameter> &init_params,
const std::vector<uint64_t> &channel_message_id_vec,
const std::vector<uint64_t> &queue_size_vec) {
STREAMING_CHECK(!queue_id_vec.empty() && !channel_message_id_vec.empty());
@@ -144,7 +145,7 @@ StreamingStatus DataWriter::Init(const std::vector<ObjectID> &queue_id_vec,
transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, queue_id_vec);
for (size_t i = 0; i < queue_id_vec.size(); ++i) {
StreamingStatus status = InitChannel(queue_id_vec[i], actor_ids[i],
StreamingStatus status = InitChannel(queue_id_vec[i], init_params[i],
channel_message_id_vec[i], queue_size_vec[i]);
if (status != StreamingStatus::OK) {
return status;
+3 -2
View File
@@ -37,10 +37,11 @@ class DataWriter {
/// Streaming writer client initialization.
/// \param queue_id_vec queue id vector
/// \param init_params some parameters for initializing channels
/// \param channel_message_id_vec channel seq id is related with message checkpoint
/// \param queue_size queue size (memory size not length)
StreamingStatus Init(const std::vector<ObjectID> &channel_ids,
const std::vector<ActorID> &actor_ids,
const std::vector<ChannelCreationParameter> &init_params,
const std::vector<uint64_t> &channel_message_id_vec,
const std::vector<uint64_t> &queue_size_vec);
@@ -91,7 +92,7 @@ class DataWriter {
StreamingStatus WriteChannelProcess(ProducerChannelInfo &channel_info,
bool *is_empty_message);
StreamingStatus InitChannel(const ObjectID &q_id, const ActorID &actor_id,
StreamingStatus InitChannel(const ObjectID &q_id, const ChannelCreationParameter &param,
uint64_t channel_message_id, uint64_t queue_size);
/// Write all messages to channel util ringbuffer is empty.
@@ -9,13 +9,15 @@ using namespace ray::streaming;
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids,
jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval,
jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) {
JNIEnv *env, jclass, jobject streaming_queue_initial_parameters,
jobjectArray input_channels, jlongArray seq_id_array, jlongArray msg_id_array,
jlong timer_interval, jboolean isRecreate, jbyteArray config_bytes,
jboolean is_mock) {
STREAMING_LOG(INFO) << "[JNI]: create DataReader.";
std::vector<ray::streaming::ChannelCreationParameter> parameter_vec;
ParseChannelInitParameters(env, streaming_queue_initial_parameters, parameter_vec);
std::vector<ray::ObjectID> input_channels_ids =
jarray_to_object_id_vec(env, input_channels);
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, input_actor_ids);
std::vector<uint64_t> seq_ids = LongVectorFromJLongArray(env, seq_id_array).data;
std::vector<uint64_t> msg_ids = LongVectorFromJLongArray(env, msg_id_array).data;
@@ -29,7 +31,7 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
ctx->MarkMockTest();
}
auto reader = new DataReader(ctx);
reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval);
reader->Init(input_channels_ids, parameter_vec, seq_ids, msg_ids, timer_interval);
STREAMING_LOG(INFO) << "create native DataReader succeed";
return reinterpret_cast<jlong>(reader);
}
@@ -72,17 +74,15 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBund
std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize);
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(
JNIEnv *env, jobject thisObj, jlong ptr) {
auto reader = reinterpret_cast<DataReader *>(ptr);
reader->Stop();
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
jobject thisObj,
jlong ptr) {
delete reinterpret_cast<DataReader *>(ptr);
}
@@ -10,34 +10,37 @@ extern "C" {
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: createDataReaderNative
* Signature: ([[B[[B[J[JJZ[BZ)J
* Signature: (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[J[JJZ[BZ)J
*/
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean);
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlongArray, jlong, jboolean,
jbyteArray, jboolean);
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: getBundleNative
* Signature: (JJJJ)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative
(JNIEnv *, jobject, jlong, jlong, jlong, jlong);
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative(
JNIEnv *, jobject, jlong, jlong, jlong, jlong);
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: stopReaderNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(
JNIEnv *, jobject, jlong);
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: closeReaderNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *, jobject,
jlong);
#ifdef __cplusplus
}
@@ -7,10 +7,13 @@ using namespace ray::streaming;
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids,
JNIEnv *env, jclass, jobject initial_parameters, jobjectArray output_queue_ids,
jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array,
jboolean is_mock) {
STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative.";
std::vector<ray::streaming::ChannelCreationParameter> parameter_vec;
ParseChannelInitParameters(env, initial_parameters, parameter_vec);
std::vector<ray::ObjectID> queue_id_vec =
jarray_to_object_id_vec(env, output_queue_ids);
for (auto id : queue_id_vec) {
@@ -22,9 +25,6 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
std::vector<uint64_t> msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data;
std::vector<uint64_t> queue_size_vec(long_array_obj.data.size(), channel_size);
std::vector<ray::ObjectID> remain_id_vec;
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, output_actor_ids);
STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0];
RawDataFromJByteArray conf(env, conf_bytes_array);
STREAMING_CHECK(conf.data != nullptr);
@@ -36,7 +36,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
runtime_context->MarkMockTest();
}
auto *data_writer = new DataWriter(runtime_context);
auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec);
auto status =
data_writer->Init(queue_id_vec, parameter_vec, msg_ids_vec, queue_size_vec);
if (status != StreamingStatus::OK) {
STREAMING_LOG(WARNING) << "DataWriter init failed.";
} else {
@@ -64,10 +65,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(
return result;
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(
JNIEnv *env, jobject thisObj, jlong ptr) {
STREAMING_LOG(INFO) << "jni: stop writer.";
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
data_writer->Stop();
@@ -75,8 +74,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
jobject thisObj,
jlong ptr) {
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
delete data_writer;
}
@@ -10,34 +10,38 @@ extern "C" {
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: createWriterNative
* Signature: ([[B[[B[JJ[BZ)J
* Signature: (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJ[BZ)J
*/
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: writeMessageNative
* Signature: (JJJI)J
*/
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative
(JNIEnv *, jobject, jlong, jlong, jlong, jint);
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(JNIEnv *, jobject,
jlong, jlong, jlong,
jint);
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: stopWriterNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(
JNIEnv *, jobject, jlong);
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: closeWriterNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *, jobject,
jlong);
#ifdef __cplusplus
}
@@ -14,29 +14,18 @@ static std::shared_ptr<ray::LocalMemoryBuffer> JByteArrayToBuffer(JNIEnv *env,
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *writer_client = new WriterClient(ray_async_func, ray_sync_func);
JNIEnv *env, jobject this_obj) {
auto *writer_client = new WriterClient();
return reinterpret_cast<jlong>(writer_client);
}
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *reader_client = new ReaderClient(ray_async_func, ray_sync_func);
JNIEnv *env, jobject this_obj) {
auto *reader_client = new ReaderClient();
return reinterpret_cast<jlong>(reader_client);
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
}
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
@@ -66,4 +55,4 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNa
env->SetByteArrayRegion(arr, 0, result_buffer->Size(),
reinterpret_cast<jbyte *>(result_buffer->Data()));
return arr;
}
}
@@ -12,22 +12,16 @@ extern "C" {
* Method: createWriterClientNative
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *,
jobject,
jobject,
jobject);
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative
(JNIEnv *, jobject);
/*
* Class: io_ray_streaming_runtime_transfer_TransferHandler
* Method: createReaderClientNative
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *,
jobject,
jobject,
jobject);
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative
(JNIEnv *, jobject);
/*
* Class: io_ray_streaming_runtime_transfer_TransferHandler
@@ -44,7 +38,7 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *, jobject, jlong, jbyteArray);
/*
@@ -62,10 +56,10 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
JNIEnv *, jobject, jlong, jbyteArray);
#ifdef __cplusplus
}
#endif
#endif
#endif
+64 -5
View File
@@ -88,6 +88,15 @@ void JavaListToNativeVector(JNIEnv *env, jobject java_list,
}
}
/// Convert a Java byte array to a C++ UniqueID.
template <typename ID>
inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) {
std::string id_str(ID::Size(), 0);
env->GetByteArrayRegion(bytes, 0, ID::Size(),
reinterpret_cast<jbyte *>(&id_str.front()));
return ID::FromBinary(id_str);
}
/// Convert a Java String to C++ std::string.
std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
const char *c_str = env->GetStringUTFChars(jstr, nullptr);
@@ -105,10 +114,9 @@ void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list,
});
}
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env,
jobject functionDescriptor) {
jclass java_language_class =
LoadClass(env, "io/ray/runtime/generated/Common$Language");
std::shared_ptr<ray::RayFunction> FunctionDescriptorToRayFunction(
JNIEnv *env, jobject functionDescriptor) {
jclass java_language_class = LoadClass(env, "io/ray/runtime/generated/Common$Language");
jclass java_function_descriptor_class =
LoadClass(env, "io/ray/runtime/functionmanager/FunctionDescriptor");
jmethodID java_language_get_number =
@@ -129,5 +137,56 @@ ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env,
ray::FunctionDescriptor function_descriptor =
ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list);
ray::RayFunction ray_function{language, function_descriptor};
return ray_function;
return std::make_shared<ray::RayFunction>(ray_function);
}
void ParseChannelInitParameters(
JNIEnv *env, jobject param_obj,
std::vector<ray::streaming::ChannelCreationParameter> &parameter_vec) {
jclass java_streaming_queue_initial_parameters_class =
LoadClass(env,
"io/ray/streaming/runtime/transfer/"
"ChannelCreationParametersBuilder");
jmethodID java_streaming_queue_initial_parameters_getParameters_method =
env->GetMethodID(java_streaming_queue_initial_parameters_class, "getParameters",
"()Ljava/util/List;");
STREAMING_CHECK(java_streaming_queue_initial_parameters_getParameters_method !=
nullptr);
jclass java_streaming_queue_initial_parameters_parameter_class =
LoadClass(env,
"io/ray/streaming/runtime/transfer/"
"ChannelCreationParametersBuilder$Parameter");
jmethodID java_getActorIdBytes_method = env->GetMethodID(
java_streaming_queue_initial_parameters_parameter_class, "getActorIdBytes", "()[B");
jmethodID java_getAsyncFunctionDescriptor_method =
env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class,
"getAsyncFunctionDescriptor",
"()Lio/ray/runtime/functionmanager/FunctionDescriptor;");
jmethodID java_getSyncFunctionDescriptor_method =
env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class,
"getSyncFunctionDescriptor",
"()Lio/ray/runtime/functionmanager/FunctionDescriptor;");
// Call getParameters method
jobject parameter_list = env->CallObjectMethod(
param_obj, java_streaming_queue_initial_parameters_getParameters_method);
JavaListToNativeVector<ray::streaming::ChannelCreationParameter>(
env, parameter_list, &parameter_vec,
[java_getActorIdBytes_method, java_getAsyncFunctionDescriptor_method,
java_getSyncFunctionDescriptor_method](JNIEnv *env, jobject jobject_parameter) {
ray::streaming::ChannelCreationParameter native_parameter;
jbyteArray jobject_actor_id_bytes = (jbyteArray)env->CallObjectMethod(
jobject_parameter, java_getActorIdBytes_method);
native_parameter.actor_id =
JavaByteArrayToId<ray::ActorID>(env, jobject_actor_id_bytes);
jobject jobject_async_func = env->CallObjectMethod(
jobject_parameter, java_getAsyncFunctionDescriptor_method);
native_parameter.async_function =
FunctionDescriptorToRayFunction(env, jobject_async_func);
jobject jobject_sync_func = env->CallObjectMethod(
jobject_parameter, java_getSyncFunctionDescriptor_method);
native_parameter.sync_function =
FunctionDescriptorToRayFunction(env, jobject_sync_func);
return native_parameter;
});
}
+13 -17
View File
@@ -3,6 +3,7 @@
#include <jni.h>
#include <string>
#include "channel.h"
#include "ray/core_worker/common.h"
#include "util/streaming_logging.h"
@@ -21,12 +22,10 @@ class UniqueIdFromJByteArray {
b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
PID = ray::ObjectID::FromBinary(
std::string(reinterpret_cast<const char*>(b), ray::ObjectID::Size()));
std::string(reinterpret_cast<const char *>(b), ray::ObjectID::Size()));
}
~UniqueIdFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, b, 0);
}
~UniqueIdFromJByteArray() { _env->ReleaseByteArrayElements(_bytes, b, 0); }
};
class RawDataFromJByteArray {
@@ -42,15 +41,13 @@ class RawDataFromJByteArray {
_env = env;
_bytes = bytes;
data_size = _env->GetArrayLength(_bytes);
jbyte *b =
reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
jbyte *b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
data = reinterpret_cast<uint8_t *>(b);
}
~RawDataFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, reinterpret_cast<jbyte *>(data), 0);
}
};
class StringFromJString {
@@ -69,10 +66,7 @@ class StringFromJString {
str = std::string(j_str);
}
~StringFromJString() {
_env->ReleaseStringUTFChars(jni_str, j_str);
}
~StringFromJString() { _env->ReleaseStringUTFChars(jni_str, j_str); }
};
class LongVectorFromJLongArray {
@@ -98,14 +92,16 @@ class LongVectorFromJLongArray {
}
};
std::vector<ray::ObjectID>
jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ActorID>
jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ObjectID> jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ActorID> jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr);
jint throwRuntimeException(JNIEnv *env, const char *message);
jint throwChannelInitException(JNIEnv *env, const char *message,
const std::vector<ray::ObjectID> &abnormal_queues);
jint throwChannelInterruptException(JNIEnv *env, const char *message);
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor);
#endif //RAY_STREAMING_JNI_COMMON_H
std::shared_ptr<ray::RayFunction> FunctionDescriptorToRayFunction(
JNIEnv *env, jobject functionDescriptor);
void ParseChannelInitParameters(
JNIEnv *env, jobject param_obj,
std::vector<ray::streaming::ChannelCreationParameter> &parameter_vec);
#endif // RAY_STREAMING_JNI_COMMON_H
+2 -3
View File
@@ -135,8 +135,7 @@ void WriterQueue::Send() {
item.IsRaw());
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
STREAMING_CHECK(transport_ != nullptr);
transport_->Send(std::move(buffer),
DownstreamQueueMessageHandler::peer_async_function_);
transport_->Send(std::move(buffer));
}
}
@@ -188,7 +187,7 @@ void ReaderQueue::Notify(uint64_t seq_id) {
NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id);
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
transport_->Send(std::move(buffer), UpstreamQueueMessageHandler::peer_async_function_);
transport_->Send(std::move(buffer));
}
void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector<TaskArg> &task_args) {}
+2 -6
View File
@@ -17,9 +17,7 @@ class ReaderClient {
/// \param[in] async_func DataReader's raycall function descriptor to be called by
/// DataWriter, asynchronous semantics \param[in] sync_func DataReader's raycall
/// function descriptor to be called by DataWriter, synchronous semantics
ReaderClient(RayFunction &async_func, RayFunction &sync_func) {
DownstreamQueueMessageHandler::peer_async_function_ = async_func;
DownstreamQueueMessageHandler::peer_sync_function_ = sync_func;
ReaderClient() {
downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService(
CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID());
}
@@ -38,9 +36,7 @@ class ReaderClient {
/// Interface of streaming queue for DataWriter. Similar to ReaderClient.
class WriterClient {
public:
WriterClient(RayFunction &async_func, RayFunction &sync_func) {
UpstreamQueueMessageHandler::peer_async_function_ = async_func;
UpstreamQueueMessageHandler::peer_sync_function_ = sync_func;
WriterClient() {
upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService(
CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID());
}
+5 -10
View File
@@ -12,11 +12,6 @@ std::shared_ptr<UpstreamQueueMessageHandler>
std::shared_ptr<DownstreamQueueMessageHandler>
DownstreamQueueMessageHandler::downstream_handler_ = nullptr;
RayFunction UpstreamQueueMessageHandler::peer_sync_function_;
RayFunction UpstreamQueueMessageHandler::peer_async_function_;
RayFunction DownstreamQueueMessageHandler::peer_sync_function_;
RayFunction DownstreamQueueMessageHandler::peer_async_function_;
std::shared_ptr<Message> QueueMessageHandler::ParseMessage(
std::shared_ptr<LocalMemoryBuffer> buffer) {
uint8_t *bytes = buffer->Data();
@@ -83,10 +78,11 @@ std::shared_ptr<Transport> QueueMessageHandler::GetOutTransport(
}
void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id,
const ActorID &actor_id) {
const ActorID &actor_id, RayFunction &async_func,
RayFunction &sync_func) {
actors_.emplace(queue_id, actor_id);
out_transports_.emplace(queue_id,
std::make_shared<ray::streaming::Transport>(actor_id));
out_transports_.emplace(queue_id, std::make_shared<ray::streaming::Transport>(
actor_id, async_func, sync_func));
}
ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) {
@@ -164,8 +160,7 @@ bool UpstreamQueueMessageHandler::CheckQueueSync(const ObjectID &queue_id) {
auto transport_it = GetOutTransport(queue_id);
STREAMING_CHECK(transport_it != nullptr);
std::shared_ptr<LocalMemoryBuffer> result_buffer = transport_it->SendForResultWithRetry(
std::move(buffer), DownstreamQueueMessageHandler::peer_sync_function_, 10,
COMMON_SYNC_CALL_TIMEOUTT_MS);
std::move(buffer), 10, COMMON_SYNC_CALL_TIMEOUTT_MS);
if (result_buffer == nullptr) {
return false;
}
+2 -6
View File
@@ -62,7 +62,8 @@ class QueueMessageHandler {
/// downstream queue with same queue_id, and vice versa.
/// \param[in] queue_id queue id of current queue.
/// \param[in] actor_id actor_id actor id of corresponded peer actor.
void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id);
void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id,
RayFunction &async_func, RayFunction &sync_func);
/// Obtain the actor id of the peer actor specified by queue_id.
/// \return actor id
@@ -133,9 +134,6 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler {
const ActorID &actor_id);
static std::shared_ptr<UpstreamQueueMessageHandler> GetService();
static RayFunction peer_sync_function_;
static RayFunction peer_async_function_;
private:
bool CheckQueueSync(const ObjectID &queue_ids);
@@ -170,8 +168,6 @@ class DownstreamQueueMessageHandler : public QueueMessageHandler {
static std::shared_ptr<DownstreamQueueMessageHandler> CreateService(
const ActorID &actor_id);
static std::shared_ptr<DownstreamQueueMessageHandler> GetService();
static RayFunction peer_sync_function_;
static RayFunction peer_async_function_;
private:
std::unordered_map<ObjectID, std::shared_ptr<streaming::ReaderQueue>>
+7 -10
View File
@@ -36,17 +36,16 @@ void Transport::SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
}
}
void Transport::Send(std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function) {
void Transport::Send(std::shared_ptr<LocalMemoryBuffer> buffer) {
STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size();
std::vector<ObjectID> return_ids;
SendInternal(std::move(buffer), function, TASK_OPTION_RETURN_NUM_0, return_ids);
SendInternal(std::move(buffer), async_func_, TASK_OPTION_RETURN_NUM_0, return_ids);
}
std::shared_ptr<LocalMemoryBuffer> Transport::SendForResult(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function,
int64_t timeout_ms) {
std::shared_ptr<LocalMemoryBuffer> buffer, int64_t timeout_ms) {
std::vector<ObjectID> return_ids;
SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids);
SendInternal(buffer, sync_func_, TASK_OPTION_RETURN_NUM_1, return_ids);
std::vector<std::shared_ptr<RayObject>> results;
Status get_st =
@@ -73,14 +72,12 @@ std::shared_ptr<LocalMemoryBuffer> Transport::SendForResult(
}
std::shared_ptr<LocalMemoryBuffer> Transport::SendForResultWithRetry(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function, int retry_cnt,
int64_t timeout_ms) {
std::shared_ptr<LocalMemoryBuffer> buffer, int retry_cnt, int64_t timeout_ms) {
STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt
<< " timeout_ms: " << timeout_ms
<< " function: " << function.GetFunctionDescriptor()->ToString();
<< " timeout_ms: " << timeout_ms;
std::shared_ptr<LocalMemoryBuffer> buffer_shared = std::move(buffer);
for (int cnt = 0; cnt < retry_cnt; cnt++) {
auto result = SendForResult(buffer_shared, function, timeout_ms);
auto result = SendForResult(buffer_shared, timeout_ms);
if (result != nullptr) {
return result;
}
+18 -11
View File
@@ -14,34 +14,39 @@ class Transport {
public:
/// Construct a Transport object.
/// \param[in] peer_actor_id actor id of peer actor.
Transport(const ActorID &peer_actor_id)
: worker_id_(CoreWorkerProcess::GetCoreWorker().GetWorkerID()),
peer_actor_id_(peer_actor_id) {}
Transport(const ActorID &peer_actor_id, RayFunction &async_func, RayFunction &sync_func)
: peer_actor_id_(peer_actor_id), async_func_(async_func), sync_func_(sync_func) {
STREAMING_LOG(INFO) << "Transport constructor:";
STREAMING_LOG(INFO) << "async_func lang: " << async_func_.GetLanguage();
STREAMING_LOG(INFO) << "async_func: "
<< async_func_.GetFunctionDescriptor()->ToString();
STREAMING_LOG(INFO) << "sync_func lang: " << sync_func_.GetLanguage();
STREAMING_LOG(INFO) << "sync_func: "
<< sync_func_.GetFunctionDescriptor()->ToString();
}
virtual ~Transport() = default;
/// Send buffer asynchronously, peer's `function` will be called.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
virtual void Send(std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function);
virtual void Send(std::shared_ptr<LocalMemoryBuffer> buffer);
/// Send buffer synchronously, peer's `function` will be called, and return the peer
/// function's return value.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
/// \param[in] timeout_ms max time to wait for result.
/// \return peer function's result.
virtual std::shared_ptr<LocalMemoryBuffer> SendForResult(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function,
int64_t timeout_ms);
std::shared_ptr<LocalMemoryBuffer> buffer, int64_t timeout_ms);
/// Send buffer and get result with retry.
/// return value.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
/// \param[in] max retry count
/// \param[in] timeout_ms max time to wait for result.
/// \return peer function's result.
std::shared_ptr<LocalMemoryBuffer> SendForResultWithRetry(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function, int retry_cnt,
int64_t timeout_ms);
std::shared_ptr<LocalMemoryBuffer> buffer, int retry_cnt, int64_t timeout_ms);
private:
/// Send buffer internal
@@ -56,6 +61,8 @@ class Transport {
private:
WorkerID worker_id_;
ActorID peer_actor_id_;
RayFunction async_func_;
RayFunction sync_func_;
};
} // namespace streaming
} // namespace ray
+30 -25
View File
@@ -94,8 +94,18 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
for (auto &queue_id : queue_ids_) {
STREAMING_LOG(INFO) << "queue_id: " << queue_id;
}
std::vector<ActorID> actor_ids(queue_ids_.size(), peer_actor_id_);
STREAMING_LOG(INFO) << "writer actor_ids size: " << actor_ids.size()
ChannelCreationParameter param{
peer_actor_id_,
std::make_shared<RayFunction>(
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::FromVector(
ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})),
std::make_shared<RayFunction>(
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::FromVector(
ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""}))};
std::vector<ChannelCreationParameter> params(queue_ids_.size(), param);
STREAMING_LOG(INFO) << "writer actor_ids size: " << params.size()
<< " actor_id: " << peer_actor_id_;
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
@@ -104,7 +114,7 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
std::shared_ptr<DataWriter> streaming_writer_client(new DataWriter(runtime_context));
uint64_t queue_size = 10 * 1000 * 1000;
std::vector<uint64_t> channel_seq_id_vec(queue_ids_.size(), 0);
streaming_writer_client->Init(queue_ids_, actor_ids, channel_seq_id_vec,
streaming_writer_client->Init(queue_ids_, params, channel_seq_id_vec,
std::vector<uint64_t>(queue_ids_.size(), queue_size));
STREAMING_LOG(INFO) << "streaming_writer_client Init done";
@@ -214,14 +224,24 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
}
void StreamingReaderStrategyTest(StreamingConfig &config) {
std::vector<ActorID> actor_ids(queue_ids_.size(), peer_actor_id_);
STREAMING_LOG(INFO) << "reader actor_ids size: " << actor_ids.size()
ChannelCreationParameter param{
peer_actor_id_,
std::make_shared<RayFunction>(
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::FromVector(
ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})),
std::make_shared<RayFunction>(
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::FromVector(
ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""}))};
std::vector<ChannelCreationParameter> params(queue_ids_.size(), param);
STREAMING_LOG(INFO) << "reader actor_ids size: " << params.size()
<< " actor_id: " << peer_actor_id_;
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
runtime_context->SetConfig(config);
std::shared_ptr<DataReader> reader(new DataReader(runtime_context));
reader->Init(queue_ids_, actor_ids, -1);
reader->Init(queue_ids_, params, -1);
ReaderLoopForward(reader, nullptr, queue_ids_);
STREAMING_LOG(INFO) << "Reader exit";
@@ -298,23 +318,8 @@ class StreamingWorker {
};
CoreWorkerProcess::Initialize(options);
RayFunction reader_async_call_func{ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython(
"reader_async_call_func", "", "", "")};
RayFunction reader_sync_call_func{
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython("reader_sync_call_func", "", "", "")};
RayFunction writer_async_call_func{ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython(
"writer_async_call_func", "", "", "")};
RayFunction writer_sync_call_func{
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython("writer_sync_call_func", "", "", "")};
reader_client_ =
std::make_shared<ReaderClient>(reader_async_call_func, reader_sync_call_func);
writer_client_ =
std::make_shared<WriterClient>(writer_async_call_func, writer_sync_call_func);
reader_client_ = std::make_shared<ReaderClient>();
writer_client_ = std::make_shared<WriterClient>();
STREAMING_LOG(INFO) << "StreamingWorker constructor";
}
@@ -338,9 +343,9 @@ class StreamingWorker {
ray::FunctionDescriptorType::kPythonFunctionDescriptor);
auto typed_descriptor = function_descriptor->As<ray::PythonFunctionDescriptor>();
STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask "
<< typed_descriptor->ModuleName();
<< typed_descriptor->ToString();
std::string func_name = typed_descriptor->ModuleName();
std::string func_name = typed_descriptor->FunctionName();
if (func_name == "init") {
std::shared_ptr<LocalMemoryBuffer> local_buffer =
std::make_shared<LocalMemoryBuffer>(args[0]->GetData()->Data(),
+3 -5
View File
@@ -51,11 +51,9 @@ class StreamingTransferTest : public ::testing::Test {
}
std::vector<uint64_t> channel_id_vec(queue_vec.size(), 0);
std::vector<uint64_t> queue_size_vec(queue_vec.size(), 10000);
// actor ids are not used in this test, so we can just use Nil.
std::vector<ActorID> actor_id_vec(queue_vec.size(),
ActorID::NilFromJob(JobID::FromInt(0)));
writer->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec);
reader->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec, -1);
std::vector<ChannelCreationParameter> params(queue_vec.size());
writer->Init(queue_vec, params, channel_id_vec, queue_size_vec);
reader->Init(queue_vec, params, channel_id_vec, queue_size_vec, -1);
}
void DestroyTransfer() {
writer.reset();
+4 -4
View File
@@ -175,7 +175,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
TaskOptions options{0, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython("init", "", "", "")};
ray::FunctionDescriptorBuilder::BuildPython("", "", "init", "")};
RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids));
}
@@ -191,7 +191,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
TaskOptions options{0, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython(
"execute_test", test, "", "")};
"", test, "execute_test", "")};
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
}
@@ -207,7 +207,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython(
"check_current_test_status", "", "", "")};
"", "", "check_current_test_status", "")};
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
@@ -267,7 +267,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython(
"actor creation task", "", "", "")};
"", "", "actor creation task", "")};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(
std::make_shared<RayObject>(buffer, nullptr, std::vector<ObjectID>())));