diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java
new file mode 100644
index 000000000..b152ca3b7
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java
@@ -0,0 +1,198 @@
+package io.ray.streaming.runtime.transfer;
+
+import com.google.common.base.Preconditions;
+import io.ray.api.BaseActor;
+import io.ray.api.id.ActorId;
+import io.ray.runtime.actor.LocalModeRayActor;
+import io.ray.runtime.actor.NativeRayJavaActor;
+import io.ray.runtime.actor.NativeRayPyActor;
+import io.ray.runtime.functionmanager.FunctionDescriptor;
+import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
+import io.ray.runtime.functionmanager.PyFunctionDescriptor;
+import io.ray.streaming.runtime.worker.JobWorker;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Saves the channel initial parameters needed by DataWriter/DataReader.
+ *
+ * <p>For every channel the builder records the id of the peer actor together with the async
+ * and sync function descriptors of that peer's message entry points. The descriptor pair is
+ * picked by the peer actor's type (Java or Python).
+ *
+ * <p>The Java descriptors are static and can be replaced through the static setters, which
+ * affects every builder in the process.
+ */
+public class ChannelCreationParametersBuilder {
+
+  /**
+   * Creation parameters of a single channel; the getters are called from jni.
+   *
+   * <p>Declared static because it never uses the enclosing builder instance.
+   */
+  public static class Parameter {
+
+    private ActorId actorId;
+    private FunctionDescriptor asyncFunctionDescriptor;
+    private FunctionDescriptor syncFunctionDescriptor;
+
+    public void setActorId(ActorId actorId) {
+      this.actorId = actorId;
+    }
+
+    public void setAsyncFunctionDescriptor(
+        FunctionDescriptor asyncFunctionDescriptor) {
+      this.asyncFunctionDescriptor = asyncFunctionDescriptor;
+    }
+
+    public void setSyncFunctionDescriptor(
+        FunctionDescriptor syncFunctionDescriptor) {
+      this.syncFunctionDescriptor = syncFunctionDescriptor;
+    }
+
+    @Override
+    public String toString() {
+      String language =
+          asyncFunctionDescriptor instanceof JavaFunctionDescriptor ? "Java" : "Python";
+      return "Language: " + language + " Desc: " + describe(asyncFunctionDescriptor) + " "
+          + describe(syncFunctionDescriptor);
+    }
+
+    // Null-safe so that printing a partially initialized parameter cannot throw.
+    private static String describe(FunctionDescriptor descriptor) {
+      return descriptor == null ? "null" : String.valueOf(descriptor.toList());
+    }
+
+    // Get actor id in bytes, called from jni.
+    public byte[] getActorIdBytes() {
+      return actorId.getBytes();
+    }
+
+    // Get async function descriptor, called from jni.
+    public FunctionDescriptor getAsyncFunctionDescriptor() {
+      return asyncFunctionDescriptor;
+    }
+
+    // Get sync function descriptor, called from jni.
+    public FunctionDescriptor getSyncFunctionDescriptor() {
+      return syncFunctionDescriptor;
+    }
+  }
+
+  // Never null, so getParameters()/toString() are safe before any build call.
+  private List<Parameter> parameters = new ArrayList<>();
+
+  // function descriptors of direct call entry point for Java workers
+  private static JavaFunctionDescriptor javaReaderAsyncFuncDesc = new JavaFunctionDescriptor(
+      JobWorker.class.getName(),
+      "onReaderMessage", "([B)V");
+  private static JavaFunctionDescriptor javaReaderSyncFuncDesc = new JavaFunctionDescriptor(
+      JobWorker.class.getName(),
+      "onReaderMessageSync", "([B)[B");
+  private static JavaFunctionDescriptor javaWriterAsyncFuncDesc = new JavaFunctionDescriptor(
+      JobWorker.class.getName(),
+      "onWriterMessage", "([B)V");
+  private static JavaFunctionDescriptor javaWriterSyncFuncDesc = new JavaFunctionDescriptor(
+      JobWorker.class.getName(),
+      "onWriterMessageSync", "([B)[B");
+  // function descriptors of direct call entry point for Python workers
+  private static PyFunctionDescriptor pyReaderAsyncFunctionDesc = new PyFunctionDescriptor(
+      "ray.streaming.runtime.worker",
+      "JobWorker", "on_reader_message");
+  private static PyFunctionDescriptor pyReaderSyncFunctionDesc = new PyFunctionDescriptor(
+      "ray.streaming.runtime.worker",
+      "JobWorker", "on_reader_message_sync");
+  private static PyFunctionDescriptor pyWriterAsyncFunctionDesc = new PyFunctionDescriptor(
+      "ray.streaming.runtime.worker",
+      "JobWorker", "on_writer_message");
+  private static PyFunctionDescriptor pyWriterSyncFunctionDesc = new PyFunctionDescriptor(
+      "ray.streaming.runtime.worker",
+      "JobWorker", "on_writer_message_sync");
+
+  public ChannelCreationParametersBuilder() {
+  }
+
+  /** Replaces the Java descriptors used by {@link #buildOutputQueueParameters}. */
+  public static void setJavaReaderFunctionDesc(JavaFunctionDescriptor asyncFunc,
+      JavaFunctionDescriptor syncFunc) {
+    javaReaderAsyncFuncDesc = asyncFunc;
+    javaReaderSyncFuncDesc = syncFunc;
+  }
+
+  /** Replaces the Java descriptors used by {@link #buildInputQueueParameters}. */
+  public static void setJavaWriterFunctionDesc(JavaFunctionDescriptor asyncFunc,
+      JavaFunctionDescriptor syncFunc) {
+    javaWriterAsyncFuncDesc = asyncFunc;
+    javaWriterSyncFuncDesc = syncFunc;
+  }
+
+  /**
+   * Builds the parameters for input queues; each entry carries the writer descriptors.
+   *
+   * @param queues input channel ids
+   * @param actors channel id to upstream actor; must contain every id in {@code queues}
+   * @return this builder
+   */
+  public ChannelCreationParametersBuilder buildInputQueueParameters(List<String> queues,
+      Map<String, BaseActor> actors) {
+    return buildParameters(queues, actors, javaWriterAsyncFuncDesc, javaWriterSyncFuncDesc,
+        pyWriterAsyncFunctionDesc, pyWriterSyncFunctionDesc);
+  }
+
+  /**
+   * Builds the parameters for output queues; each entry carries the reader descriptors.
+   *
+   * @param queues output channel ids
+   * @param actors channel id to downstream actor; must contain every id in {@code queues}
+   * @return this builder
+   */
+  public ChannelCreationParametersBuilder buildOutputQueueParameters(List<String> queues,
+      Map<String, BaseActor> actors) {
+    return buildParameters(queues, actors, javaReaderAsyncFuncDesc, javaReaderSyncFuncDesc,
+        pyReaderAsyncFunctionDesc, pyReaderSyncFunctionDesc);
+  }
+
+  private ChannelCreationParametersBuilder buildParameters(List<String> queues,
+      Map<String, BaseActor> actors,
+      JavaFunctionDescriptor javaAsyncFunctionDesc, JavaFunctionDescriptor javaSyncFunctionDesc,
+      PyFunctionDescriptor pyAsyncFunctionDesc, PyFunctionDescriptor pySyncFunctionDesc
+  ) {
+    parameters = new ArrayList<>(queues.size());
+    for (String queue : queues) {
+      BaseActor actor = actors.get(queue);
+      Preconditions.checkArgument(actor != null, "No actor found for channel %s", queue);
+      Parameter parameter = new Parameter();
+      parameter.setActorId(actor.getId());
+      // LocalModeRayActor used in single-process mode.
+      if (actor instanceof NativeRayJavaActor || actor instanceof LocalModeRayActor) {
+        parameter.setAsyncFunctionDescriptor(javaAsyncFunctionDesc);
+        parameter.setSyncFunctionDescriptor(javaSyncFunctionDesc);
+      } else if (actor instanceof NativeRayPyActor) {
+        parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
+        parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
+      } else {
+        throw new IllegalArgumentException(
+            "Invalid actor type " + actor.getClass().getName() + " for channel " + queue);
+      }
+      parameters.add(parameter);
+    }
+
+    return this;
+  }
+
+  // Called from jni
+  public List<Parameter> getParameters() {
+    return parameters;
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder str = new StringBuilder();
+    for (Parameter param : parameters) {
+      str.append(param);
+    }
+    return str.toString();
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java
index 243c8d036..64e17f59c 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java
@@ -1,7 +1,7 @@
 package io.ray.streaming.runtime.transfer;
 
 import com.google.common.base.Preconditions;
-import io.ray.api.id.ActorId;
+import io.ray.api.BaseActor;
 import io.ray.streaming.runtime.util.Platform;
 import io.ray.streaming.util.Config;
 import java.nio.ByteBuffer;
@@ -24,14 +24,14 @@ public class DataReader {
   private Queue buf = new LinkedList<>();
 
   public DataReader(List inputChannels,
-                    List fromActors,
+                    Map fromActors,
                     Map conf) {
     Preconditions.checkArgument(inputChannels.size() > 0);
     Preconditions.checkArgument(inputChannels.size() == fromActors.size());
+    ChannelCreationParametersBuilder initialParameters =
+        new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors);
     byte[][] inputChannelsBytes = inputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new); - byte[][] fromActorsBytes = fromActors.stream() - .map(ActorId::getBytes).toArray(byte[][]::new); long[] seqIds = new long[inputChannels.size()]; long[] msgIds = new long[inputChannels.size()]; for (int i = 0; i < inputChannels.size(); i++) { @@ -48,8 +48,8 @@ public class DataReader { boolean isRecreate = Boolean.parseBoolean( conf.getOrDefault(Config.IS_RECREATE, "false")); this.nativeReaderPtr = createDataReaderNative( + initialParameters, inputChannelsBytes, - fromActorsBytes, seqIds, msgIds, timerInterval, @@ -155,8 +155,8 @@ public class DataReader { } private static native long createDataReaderNative( + ChannelCreationParametersBuilder initialParameters, byte[][] inputChannels, - byte[][] inputActorIds, long[] seqIds, long[] msgIds, long timerInterval, diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java index 516f2c794..25e02940e 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java @@ -1,7 +1,7 @@ package io.ray.streaming.runtime.transfer; import com.google.common.base.Preconditions; -import io.ray.api.id.ActorId; +import io.ray.api.BaseActor; import io.ray.streaming.runtime.util.Platform; import io.ray.streaming.util.Config; import java.nio.ByteBuffer; @@ -33,14 +33,14 @@ public class DataWriter { * @param conf configuration */ public DataWriter(List outputChannels, - List toActors, + Map toActors, Map conf) { Preconditions.checkArgument(!outputChannels.isEmpty()); Preconditions.checkArgument(outputChannels.size() == toActors.size()); + ChannelCreationParametersBuilder initialParameters = + new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors); byte[][] 
outputChannelsBytes = outputChannels.stream() .map(ChannelID::idStrToBytes).toArray(byte[][]::new); - byte[][] toActorsBytes = toActors.stream() - .map(ActorId::getBytes).toArray(byte[][]::new); long channelSize = Long.parseLong( conf.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT)); long[] msgIds = new long[outputChannels.size()]; @@ -53,8 +53,8 @@ public class DataWriter { isMock = true; } this.nativeWriterPtr = createWriterNative( + initialParameters, outputChannelsBytes, - toActorsBytes, msgIds, channelSize, ChannelUtils.toNativeConf(conf), @@ -123,8 +123,8 @@ public class DataWriter { } private static native long createWriterNative( + ChannelCreationParametersBuilder initialParameters, byte[][] outputQueueIds, - byte[][] outputActorIds, long[] msgIds, long channelSize, byte[] confBytes, diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java index 7f06673a2..613c8490a 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java @@ -1,8 +1,6 @@ package io.ray.streaming.runtime.transfer; import io.ray.runtime.RayNativeRuntime; -import io.ray.runtime.functionmanager.FunctionDescriptor; -import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.runtime.util.JniUtils; /** @@ -23,12 +21,9 @@ public class TransferHandler { private long writerClientNative; private long readerClientNative; - public TransferHandler(JavaFunctionDescriptor writerAsyncFunc, - JavaFunctionDescriptor writerSyncFunc, - JavaFunctionDescriptor readerAsyncFunc, - JavaFunctionDescriptor readerSyncFunc) { - writerClientNative = createWriterClientNative(writerAsyncFunc, writerSyncFunc); - readerClientNative = 
createReaderClientNative(readerAsyncFunc, readerSyncFunc); + public TransferHandler() { + writerClientNative = createWriterClientNative(); + readerClientNative = createReaderClientNative(); } public void onWriterMessage(byte[] buffer) { @@ -47,13 +42,10 @@ public class TransferHandler { return handleReaderMessageSyncNative(readerClientNative, buffer); } - private native long createWriterClientNative( - FunctionDescriptor asyncFunc, - FunctionDescriptor syncFunc); + private native long createWriterClientNative(); + + private native long createReaderClientNative(); - private native long createReaderClientNative( - FunctionDescriptor asyncFunc, - FunctionDescriptor syncFunc); private native void handleWriterMessageNative(long handler, byte[] buffer); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java index d555f1446..75d587c5a 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java @@ -1,6 +1,5 @@ package io.ray.streaming.runtime.worker; -import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.streaming.runtime.core.graph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.ExecutionNode; import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType; @@ -16,8 +15,10 @@ import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask; import io.ray.streaming.runtime.worker.tasks.SourceStreamTask; import io.ray.streaming.runtime.worker.tasks.StreamTask; import io.ray.streaming.util.Config; + import java.io.Serializable; import java.util.Map; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,17 +53,13 @@ public class JobWorker implements Serializable { this.nodeType = executionNode.getNodeType(); this.streamProcessor = 
ProcessBuilder - .buildProcessor(executionNode.getStreamOperator()); + .buildProcessor(executionNode.getStreamOperator()); LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor); String channelType = (String) this.config.getOrDefault( Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); if (channelType.equals(Config.NATIVE_CHANNEL)) { - transferHandler = new TransferHandler( - new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"), - new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"), - new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"), - new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessageSync", "([B)[B")); + transferHandler = new TransferHandler(); } task = createStreamTask(); task.start(); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java index c0b1961fe..e3ba8d034 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java @@ -2,7 +2,6 @@ package io.ray.streaming.runtime.worker.tasks; import io.ray.api.BaseActor; import io.ray.api.Ray; -import io.ray.api.id.ActorId; import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.api.partition.Partition; @@ -17,10 +16,12 @@ import io.ray.streaming.runtime.transfer.DataWriter; import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.runtime.worker.context.RayRuntimeContext; import io.ray.streaming.util.Config; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; 
@@ -41,7 +42,7 @@ public abstract class StreamTask implements Runnable { prepareTask(); this.thread = new Thread(Ray.wrapRunnable(this), this.getClass().getName() - + "-" + System.currentTimeMillis()); + + "-" + System.currentTimeMillis()); this.thread.setDaemon(true); } @@ -64,22 +65,20 @@ public abstract class StreamTask implements Runnable { List outputEdges = executionNode.getOutputEdges(); List collectors = new ArrayList<>(); for (ExecutionEdge edge : outputEdges) { - Map outputActorIds = new HashMap<>(); + Map outputActors = new HashMap<>(); Map taskId2Worker = executionGraph .getTaskId2WorkerByNodeId(edge.getTargetNodeId()); taskId2Worker.forEach((targetTaskId, targetActor) -> { String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime()); - outputActorIds.put(queueName, targetActor.getId()); + outputActors.put(queueName, targetActor); }); - if (!outputActorIds.isEmpty()) { + if (!outputActors.isEmpty()) { List channelIDs = new ArrayList<>(); - List toActorIds = new ArrayList<>(); - outputActorIds.forEach((k, v) -> { + outputActors.forEach((k, v) -> { channelIDs.add(k); - toActorIds.add(v); }); - DataWriter writer = new DataWriter(channelIDs, toActorIds, queueConf); + DataWriter writer = new DataWriter(channelIDs, outputActors, queueConf); LOG.info("Create DataWriter succeed."); writers.put(edge, writer); Partition partition = edge.getPartition(); @@ -89,24 +88,22 @@ public abstract class StreamTask implements Runnable { // consumer List inputEdges = executionNode.getInputsEdges(); - Map inputActorIds = new HashMap<>(); + Map inputActors = new HashMap<>(); for (ExecutionEdge edge : inputEdges) { Map taskId2Worker = executionGraph .getTaskId2WorkerByNodeId(edge.getSrcNodeId()); taskId2Worker.forEach((srcTaskId, srcActor) -> { String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime()); - inputActorIds.put(queueName, srcActor.getId()); + inputActors.put(queueName, srcActor); }); } - if 
(!inputActorIds.isEmpty()) { + if (!inputActors.isEmpty()) { List channelIDs = new ArrayList<>(); - List fromActorIds = new ArrayList<>(); - inputActorIds.forEach((k, v) -> { + inputActors.forEach((k, v) -> { channelIDs.add(k); - fromActorIds.add(v); }); LOG.info("Register queue consumer, queues {}.", channelIDs); - reader = new DataReader(channelIDs, fromActorIds, queueConf); + reader = new DataReader(channelIDs, inputActors, queueConf); } RuntimeContext runtimeContext = new RayRuntimeContext( diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java index e2d617927..cfa34dd04 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -36,7 +36,6 @@ import org.testng.annotations.Test; public class StreamingQueueTest extends BaseUnitTest implements Serializable { private static Logger LOGGER = LoggerFactory.getLogger(StreamingQueueTest.class); - static { EnvUtil.loadNativeLibraries(); } @@ -62,7 +61,6 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { @BeforeMethod void beforeMethod() { - LOGGER.info("beforeTest"); Ray.shutdown(); System.setProperty("ray.resources", "CPU:4,RES-A:4"); @@ -144,6 +142,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { @Test(timeOut = 60000) public void testWordCount() { + LOGGER.info("testWordCount"); LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}", System.getProperty("ray.run-mode")); String resultFile = "/tmp/io.ray.streaming.runtime.streamingqueue.testWordCount.txt"; diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java 
b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java index a1fe04f86..ab95f6ab0 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java @@ -1,10 +1,11 @@ package io.ray.streaming.runtime.streamingqueue; +import io.ray.api.BaseActor; import io.ray.api.Ray; import io.ray.api.RayActor; -import io.ray.api.id.ActorId; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.streaming.runtime.transfer.ChannelID; +import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder; import io.ray.streaming.runtime.transfer.DataMessage; import io.ray.streaming.runtime.transfer.DataReader; import io.ray.streaming.runtime.transfer.DataWriter; @@ -12,7 +13,6 @@ import io.ray.streaming.runtime.transfer.TransferHandler; import io.ray.streaming.util.Config; import java.lang.management.ManagementFactory; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -27,15 +27,7 @@ public class Worker { protected TransferHandler transferHandler = null; public Worker() { - transferHandler = new TransferHandler( - new JavaFunctionDescriptor(Worker.class.getName(), - "onWriterMessage", "([B)V"), - new JavaFunctionDescriptor(Worker.class.getName(), - "onWriterMessageSync", "([B)[B"), - new JavaFunctionDescriptor(Worker.class.getName(), - "onReaderMessage", "([B)V"), - new JavaFunctionDescriptor(Worker.class.getName(), - "onReaderMessageSync", "([B)[B")); + transferHandler = new TransferHandler(); } public void onReaderMessage(byte[] buffer) { @@ -60,7 +52,7 @@ class ReaderWorker extends Worker { private String name = null; private List inputQueueList = null; - private List inputActorIds = new ArrayList<>(); + Map fromActors = new HashMap<>(); private DataReader dataReader = null; private 
long handler = 0; private RayActor peerActor = null; @@ -95,7 +87,7 @@ class ReaderWorker extends Worker { LOGGER.info("java.library.path = {}", System.getProperty("java.library.path")); for (String queue : this.inputQueueList) { - inputActorIds.add(this.peerActor.getId()); + fromActors.put(queue, this.peerActor); LOGGER.info("ReaderWorker actorId: {}", this.peerActor.getId()); } @@ -104,7 +96,10 @@ class ReaderWorker extends Worker { conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL); conf.put(Config.CHANNEL_SIZE, "100000"); conf.put(Config.STREAMING_JOB_NAME, "integrationTest1"); - dataReader = new DataReader(inputQueueList, inputActorIds, conf); + ChannelCreationParametersBuilder.setJavaWriterFunctionDesc( + new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"), + new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B")); + dataReader = new DataReader(inputQueueList, fromActors, conf); // Should not GetBundle in RayCall thread Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() { @@ -176,7 +171,7 @@ class WriterWorker extends Worker { private String name = null; private List outputQueueList = null; - private List outputActorIds = new ArrayList<>(); + Map toActors = new HashMap<>(); DataWriter dataWriter = null; RayActor peerActor = null; int msgCount = 0; @@ -208,7 +203,7 @@ class WriterWorker extends Worker { LOGGER.info("WriterWorker init:"); for (String queue : this.outputQueueList) { - outputActorIds.add(this.peerActor.getId()); + toActors.put(queue, this.peerActor); LOGGER.info("WriterWorker actorId: {}", this.peerActor.getId()); } @@ -227,8 +222,10 @@ class WriterWorker extends Worker { conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL); conf.put(Config.CHANNEL_SIZE, "100000"); conf.put(Config.STREAMING_JOB_NAME, "integrationTest1"); - - dataWriter = new DataWriter(this.outputQueueList, this.outputActorIds, conf); + ChannelCreationParametersBuilder.setJavaReaderFunctionDesc( + new 
JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"), + new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B")); + dataWriter = new DataWriter(this.outputQueueList, this.toActors, conf); Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() { @Override public void run() { diff --git a/streaming/python/includes/libstreaming.pxd b/streaming/python/includes/libstreaming.pxd index bf210b4a5..08a1ce129 100644 --- a/streaming/python/includes/libstreaming.pxd +++ b/streaming/python/includes/libstreaming.pxd @@ -97,16 +97,21 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums, c_list[shared_ptr[CStreamingMessage]] &msg_list); +cdef extern from "channel.h" namespace "ray::streaming" nogil: + cdef struct CChannelCreationParameter "ray::streaming::ChannelCreationParameter": + CChannelCreationParameter() + CActorID actor_id; + shared_ptr[CRayFunction] async_function; + shared_ptr[CRayFunction] sync_function; + cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil: cdef cppclass CReaderClient "ray::streaming::ReaderClient": - CReaderClient(CRayFunction &async_func, - CRayFunction &sync_func) + CReaderClient() void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer); shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); cdef cppclass CWriterClient "ray::streaming::WriterClient": - CWriterClient(CRayFunction &async_func, - CRayFunction &sync_func) + CWriterClient() void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer); shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); @@ -122,7 +127,7 @@ cdef extern from "data_reader.h" namespace "ray::streaming" nogil: cdef cppclass CDataReader "ray::streaming::DataReader"(CStreamingCommon): CDataReader(shared_ptr[CRuntimeContext] &runtime_context) void Init(const 
c_vector[CObjectID] &input_ids,
-                  const c_vector[CActorID] &actor_ids,
+                  const c_vector[CChannelCreationParameter] &params,
                   const c_vector[uint64_t] &seq_ids,
                   const c_vector[uint64_t] &msg_ids,
                   int64_t timer_interval);
@@ -135,7 +140,7 @@ cdef extern from "data_writer.h" namespace "ray::streaming" nogil:
     cdef cppclass CDataWriter "ray::streaming::DataWriter"(CStreamingCommon):
         CDataWriter(shared_ptr[CRuntimeContext] &runtime_context)
         CStreamingStatus Init(const c_vector[CObjectID] &channel_ids,
-                              const c_vector[CActorID] &actor_ids,
+                              const c_vector[CChannelCreationParameter] &params,
                               const c_vector[uint64_t] &message_ids,
                               const c_vector[uint64_t] &queue_size_vec);
         long WriteMessageToBufferRing(
diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi
index c317137d6..5beb261b4 100644
--- a/streaming/python/includes/transfer.pxi
+++ b/streaming/python/includes/transfer.pxi
@@ -10,6 +10,7 @@ from libcpp.list cimport list as c_list
 from ray.includes.common cimport (
     CRayFunction,
     LANGUAGE_PYTHON,
+    LANGUAGE_JAVA,
     CBuffer
 )
 
@@ -36,27 +37,43 @@ from ray.streaming.includes.libstreaming cimport (
     CReaderClient,
     CWriterClient,
     CLocalMemoryBuffer,
+    CChannelCreationParameter,
 )
+from ray._raylet import JavaFunctionDescriptor
 
 import logging
 
 
 channel_logger = logging.getLogger(__name__)
 
+cdef class ChannelCreationParameter:
+    cdef:
+        CChannelCreationParameter parameter
+
+    def __cinit__(self, ActorID actor_id, FunctionDescriptor async_func, FunctionDescriptor sync_func):
+        # The descriptor's own type selects the language of each native
+        # function: JavaFunctionDescriptor -> LANGUAGE_JAVA, anything else
+        # -> LANGUAGE_PYTHON.
+        self.parameter = CChannelCreationParameter()
+        self.parameter.actor_id = actor_id.data
+        if isinstance(async_func, JavaFunctionDescriptor):
+            self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_JAVA, async_func.descriptor)
+        else:
+            self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_PYTHON, async_func.descriptor)
+        if isinstance(sync_func, JavaFunctionDescriptor):
+            self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_JAVA, sync_func.descriptor)
+        else:
+            self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_PYTHON, sync_func.descriptor)
+
+    cdef CChannelCreationParameter get_parameter(self):
+        return self.parameter
 
 cdef class ReaderClient:
     cdef:
         CReaderClient *client
 
-    def __cinit__(self,
-                  FunctionDescriptor async_func,
-                  FunctionDescriptor sync_func):
-        cdef:
-            CRayFunction async_native_func
-            CRayFunction sync_native_func
-        async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
-        sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
-        self.client = new CReaderClient(async_native_func, sync_native_func)
+    def __cinit__(self):
+        self.client = new CReaderClient()
 
     def __dealloc__(self):
         del self.client
@@ -85,15 +102,8 @@ cdef class WriterClient:
     cdef:
         CWriterClient * client
 
-    def __cinit__(self,
-                  FunctionDescriptor async_func,
-                  FunctionDescriptor sync_func):
-        cdef:
-            CRayFunction async_native_func
-            CRayFunction sync_native_func
-        async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
-        sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
-        self.client = new CWriterClient(async_native_func, sync_native_func)
+    def __cinit__(self):
+        self.client = new CWriterClient()
 
     def __dealloc__(self):
         del self.client
@@ -127,19 +137,21 @@ cdef class DataWriter:
 
     @staticmethod
     def create(list py_output_channels,
-               list output_actor_ids: list[ActorID],
+               list output_creation_parameters: list[ChannelCreationParameter],
                uint64_t queue_size,
                list py_msg_ids,
                bytes config_bytes,
                c_bool is_mock):
         cdef:
             c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
-            c_vector[CActorID] actor_ids
+            c_vector[CChannelCreationParameter] initial_parameters
             c_vector[uint64_t] msg_ids
             CDataWriter *c_writer
+            ChannelCreationParameter parameter
             cdef const unsigned char[:] config_data
-        for actor_id in output_actor_ids:
- actor_ids.push_back((actor_id).data) + for param in output_creation_parameters: + parameter = param + initial_parameters.push_back(parameter.get_parameter()) for py_msg_id in py_msg_ids: msg_ids.push_back(py_msg_id) @@ -156,7 +168,7 @@ cdef class DataWriter: c_vector[uint64_t] queue_size_vec for i in range(channel_ids.size()): queue_size_vec.push_back(queue_size) - cdef CStreamingStatus status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec) + cdef CStreamingStatus status = c_writer.Init(channel_ids, initial_parameters, msg_ids, queue_size_vec) if remain_id_vec.size() != 0: channel_logger.warning("failed queue amounts => %s", remain_id_vec.size()) if status != libstreaming.StatusOK: @@ -205,7 +217,7 @@ cdef class DataReader: @staticmethod def create(list py_input_queues, - list input_actor_ids: list[ActorID], + list input_creation_parameters: list[ChannelCreationParameter], list py_seq_ids, list py_msg_ids, int64_t timer_interval, @@ -214,13 +226,15 @@ cdef class DataReader: c_bool is_mock): cdef: c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues) - c_vector[CActorID] actor_ids + c_vector[CChannelCreationParameter] initial_parameters c_vector[uint64_t] seq_ids c_vector[uint64_t] msg_ids CDataReader *c_reader + ChannelCreationParameter parameter cdef const unsigned char[:] config_data - for actor_id in input_actor_ids: - actor_ids.push_back((actor_id).data) + for param in input_creation_parameters: + parameter = param + initial_parameters.push_back(parameter.get_parameter()) for py_seq_id in py_seq_ids: seq_ids.push_back(py_seq_id) for py_msg_id in py_msg_ids: @@ -233,7 +247,7 @@ cdef class DataReader: if is_mock: ctx.get().MarkMockTest() c_reader = new CDataReader(ctx) - c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval) + c_reader.Init(queue_id_vec, initial_parameters, seq_ids, msg_ids, timer_interval) channel_logger.info("create native reader succeed") cdef DataReader reader = 
DataReader.__new__(DataReader)
         reader.reader = c_reader
diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py
index ba83de20f..9da4be35a 100644
--- a/streaming/python/runtime/transfer.py
+++ b/streaming/python/runtime/transfer.py
@@ -6,9 +6,11 @@ from typing import List
 import ray
 import ray.streaming._streaming as _streaming
 import ray.streaming.generated.streaming_pb2 as streaming_pb
-from ray import ActorID
 from ray.actor import ActorHandle
 from ray.streaming.config import Config
+from ray._raylet import JavaFunctionDescriptor
+from ray._raylet import PythonFunctionDescriptor
+from ray._raylet import Language
 
 CHANNEL_ID_LEN = 20
 
@@ -140,6 +142,85 @@ class DataMessage:
         return self.__message_id
 
 
+class ChannelCreationParametersBuilder:
+    """
+    Wrap the initial parameters needed by a streaming queue: per peer actor,
+    its id plus the async/sync entry-point descriptors for its language.
+    """
+    _java_reader_async_function_descriptor = JavaFunctionDescriptor(
+        "io.ray.streaming.runtime.worker.JobWorker",
+        "onReaderMessage", "([B)V")
+    _java_reader_sync_function_descriptor = JavaFunctionDescriptor(
+        "io.ray.streaming.runtime.worker.JobWorker",
+        "onReaderMessageSync", "([B)[B")
+    _java_writer_async_function_descriptor = JavaFunctionDescriptor(
+        "io.ray.streaming.runtime.worker.JobWorker",
+        "onWriterMessage", "([B)V")
+    _java_writer_sync_function_descriptor = JavaFunctionDescriptor(
+        "io.ray.streaming.runtime.worker.JobWorker",
+        "onWriterMessageSync", "([B)[B")
+    _python_reader_async_function_descriptor = PythonFunctionDescriptor(
+        "ray.streaming.runtime.worker",
+        "on_reader_message", "JobWorker")
+    _python_reader_sync_function_descriptor = PythonFunctionDescriptor(
+        "ray.streaming.runtime.worker",
+        "on_reader_message_sync", "JobWorker")
+    _python_writer_async_function_descriptor = PythonFunctionDescriptor(
+        "ray.streaming.runtime.worker",
+        "on_writer_message", "JobWorker")
+    _python_writer_sync_function_descriptor = PythonFunctionDescriptor(
+        "ray.streaming.runtime.worker",
+        "on_writer_message_sync", "JobWorker")
+
+    def get_parameters(self):
+        return self._parameters
+
+    def __init__(self):
+        self._parameters = []
+
+    def build_input_queue_parameters(self, queue_ids_dict):
+        self.build_parameters(queue_ids_dict,
+                              self._java_writer_async_function_descriptor,
+                              self._java_writer_sync_function_descriptor,
+                              self._python_writer_async_function_descriptor,
+                              self._python_writer_sync_function_descriptor)
+        return self
+
+    def build_output_queue_parameters(self, to_actors):
+        self.build_parameters(to_actors,
+                              self._java_reader_async_function_descriptor,
+                              self._java_reader_sync_function_descriptor,
+                              self._python_reader_async_function_descriptor,
+                              self._python_reader_sync_function_descriptor)
+        return self
+
+    def build_parameters(self, actors, java_async_func,
+                         java_sync_func, py_async_func, py_sync_func):
+        for handle in actors:
+            if handle._ray_actor_language == Language.PYTHON:
+                parameter = _streaming.ChannelCreationParameter(
+                    handle._ray_actor_id, py_async_func, py_sync_func)
+            else:
+                parameter = _streaming.ChannelCreationParameter(
+                    handle._ray_actor_id, java_async_func, java_sync_func)
+            self._parameters.append(parameter)
+        return self
+
+    @staticmethod
+    def set_python_writer_function_descriptor(async_function, sync_function):
+        ChannelCreationParametersBuilder.\
+            _python_writer_async_function_descriptor = async_function
+        ChannelCreationParametersBuilder.\
+            _python_writer_sync_function_descriptor = sync_function
+
+    @staticmethod
+    def set_python_reader_function_descriptor(async_function, sync_function):
+        ChannelCreationParametersBuilder.\
+            _python_reader_async_function_descriptor = async_function
+        ChannelCreationParametersBuilder.\
+            _python_reader_sync_function_descriptor = sync_function
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -161,16 +242,16 @@ class DataWriter:
         py_output_channels = [
             channel_id_str_to_bytes(qid_str) for qid_str in output_channels
         ]
-        output_actor_ids: List[ActorID] = [
-            handle._ray_actor_id for handle in to_actors
-        ]
+
creation_parameters = ChannelCreationParametersBuilder() + creation_parameters.build_output_queue_parameters(to_actors) channel_size = conf.get(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT) py_msg_ids = [0 for _ in range(len(output_channels))] config_bytes = _to_native_conf(conf) is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL self.writer = _streaming.DataWriter.create( - py_output_channels, output_actor_ids, channel_size, py_msg_ids, + py_output_channels, creation_parameters.get_parameters(), + channel_size, py_msg_ids, config_bytes, is_mock) logger.info("create DataWriter succeed") @@ -215,9 +296,8 @@ class DataReader: py_input_channels = [ channel_id_str_to_bytes(qid_str) for qid_str in input_channels ] - input_actor_ids: List[ActorID] = [ - handle._ray_actor_id for handle in from_actors - ] + creation_parameters = ChannelCreationParametersBuilder() + creation_parameters.build_input_queue_parameters(from_actors) py_seq_ids = [0 for _ in range(len(input_channels))] py_msg_ids = [0 for _ in range(len(input_channels))] timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1)) @@ -226,7 +306,8 @@ class DataReader: self.__queue = Queue(10000) is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL self.reader = _streaming.DataReader.create( - py_input_channels, input_actor_ids, py_seq_ids, py_msg_ids, + py_input_channels, creation_parameters.get_parameters(), + py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes, is_mock) logger.info("create DataReader succeed") diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py index 3af4fcd78..9743205ef 100644 --- a/streaming/python/runtime/worker.py +++ b/streaming/python/runtime/worker.py @@ -4,7 +4,6 @@ import ray import ray.streaming._streaming as _streaming import ray.streaming.generated.remote_call_pb2 as remote_call_pb import ray.streaming.runtime.processor as processor -from ray._raylet import PythonFunctionDescriptor from ray.streaming.config import 
Config from ray.streaming.runtime.graph import ExecutionGraph from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask @@ -48,22 +47,8 @@ class JobWorker(object): self.task_id, self.stream_processor)) if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): - reader_async_func = PythonFunctionDescriptor( - __name__, self.on_reader_message.__name__, - self.__class__.__name__) - reader_sync_func = PythonFunctionDescriptor( - __name__, self.on_reader_message_sync.__name__, - self.__class__.__name__) - self.reader_client = _streaming.ReaderClient( - reader_async_func, reader_sync_func) - writer_async_func = PythonFunctionDescriptor( - __name__, self.on_writer_message.__name__, - self.__class__.__name__) - writer_sync_func = PythonFunctionDescriptor( - __name__, self.on_writer_message_sync.__name__, - self.__class__.__name__) - self.writer_client = _streaming.WriterClient( - writer_async_func, writer_sync_func) + self.reader_client = _streaming.ReaderClient() + self.writer_client = _streaming.WriterClient() self.task = self.create_stream_task() self.task.start() diff --git a/streaming/python/tests/test_direct_transfer.py b/streaming/python/tests/test_direct_transfer.py index 7bc389e93..12d311528 100644 --- a/streaming/python/tests/test_direct_transfer.py +++ b/streaming/python/tests/test_direct_transfer.py @@ -12,20 +12,8 @@ from ray.streaming.config import Config @ray.remote class Worker: def __init__(self): - writer_async_func = PythonFunctionDescriptor( - __name__, self.on_writer_message.__name__, self.__class__.__name__) - writer_sync_func = PythonFunctionDescriptor( - __name__, self.on_writer_message_sync.__name__, - self.__class__.__name__) - self.writer_client = _streaming.WriterClient(writer_async_func, - writer_sync_func) - reader_async_func = PythonFunctionDescriptor( - __name__, self.on_reader_message.__name__, self.__class__.__name__) - reader_sync_func = PythonFunctionDescriptor( - __name__, self.on_reader_message_sync.__name__, - 
self.__class__.__name__) - self.reader_client = _streaming.ReaderClient(reader_async_func, - reader_sync_func) + self.writer_client = _streaming.WriterClient() + self.reader_client = _streaming.ReaderClient() self.writer = None self.output_channel_id = None self.reader = None @@ -35,6 +23,14 @@ class Worker: Config.TASK_JOB_ID: ray.worker.global_worker.current_job_id, Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL } + reader_async_func = PythonFunctionDescriptor( + __name__, self.on_reader_message.__name__, self.__class__.__name__) + reader_sync_func = PythonFunctionDescriptor( + __name__, self.on_reader_message_sync.__name__, + self.__class__.__name__) + transfer.ChannelCreationParametersBuilder.\ + set_python_reader_function_descriptor( + reader_async_func, reader_sync_func) self.writer = transfer.DataWriter([output_channel], [pickle.loads(reader_actor)], conf) self.output_channel_id = transfer.ChannelID(output_channel) @@ -44,6 +40,14 @@ class Worker: Config.TASK_JOB_ID: ray.worker.global_worker.current_job_id, Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL } + writer_async_func = PythonFunctionDescriptor( + __name__, self.on_writer_message.__name__, self.__class__.__name__) + writer_sync_func = PythonFunctionDescriptor( + __name__, self.on_writer_message_sync.__name__, + self.__class__.__name__) + transfer.ChannelCreationParametersBuilder.\ + set_python_writer_function_descriptor( + writer_async_func, writer_sync_func) self.reader = transfer.DataReader([input_channel], [pickle.loads(writer_actor)], conf) diff --git a/streaming/src/channel.cc b/streaming/src/channel.cc index 2d3c1271b..cbca454f1 100644 --- a/streaming/src/channel.cc +++ b/streaming/src/channel.cc @@ -60,9 +60,12 @@ StreamingStatus StreamingQueueProducer::CreateQueue() { return StreamingStatus::OK; } - upstream_handler->SetPeerActorID(channel_info_.channel_id, channel_info_.actor_id); - queue_ = upstream_handler->CreateUpstreamQueue( - channel_info_.channel_id, channel_info_.actor_id, 
channel_info_.queue_size); + upstream_handler->SetPeerActorID( + channel_info_.channel_id, channel_info_.parameter.actor_id, + *channel_info_.parameter.async_function, *channel_info_.parameter.sync_function); + queue_ = upstream_handler->CreateUpstreamQueue(channel_info_.channel_id, + channel_info_.parameter.actor_id, + channel_info_.queue_size); STREAMING_CHECK(queue_ != nullptr); std::vector queue_ids, failed_queues; @@ -154,11 +157,13 @@ StreamingStatus StreamingQueueConsumer::CreateTransferChannel() { return StreamingStatus::OK; } - downstream_handler->SetPeerActorID(channel_info_.channel_id, channel_info_.actor_id); + downstream_handler->SetPeerActorID( + channel_info_.channel_id, channel_info_.parameter.actor_id, + *channel_info_.parameter.async_function, *channel_info_.parameter.sync_function); STREAMING_LOG(INFO) << "Create ReaderQueue " << channel_info_.channel_id << " pull from start_seq_id: " << channel_info_.current_seq_id + 1; queue_ = downstream_handler->CreateDownstreamQueue(channel_info_.channel_id, - channel_info_.actor_id); + channel_info_.parameter.actor_id); return StreamingStatus::OK; } diff --git a/streaming/src/channel.h b/streaming/src/channel.h index 6f0fbe0e2..e5b2454bd 100644 --- a/streaming/src/channel.h +++ b/streaming/src/channel.h @@ -17,6 +17,12 @@ struct StreamingQueueInfo { uint64_t consumed_seq_id = 0; }; +struct ChannelCreationParameter { + ActorID actor_id; + std::shared_ptr async_function; + std::shared_ptr sync_function; +}; + /// PrducerChannelinfo and ConsumerChannelInfo contains channel information and /// its metrics that help us to debug or show important messages in logging. struct ProducerChannelInfo { @@ -28,7 +34,7 @@ struct ProducerChannelInfo { StreamingQueueInfo queue_info; uint32_t queue_size; int64_t message_pass_by_ts; - ActorID actor_id; + ChannelCreationParameter parameter; /// The following parameters are used for event driven to record different /// input events. 
@@ -55,7 +61,7 @@ struct ConsumerChannelInfo { uint64_t last_queue_item_latency = 0; uint64_t last_queue_target_diff = 0; uint64_t get_queue_item_times = 0; - ActorID actor_id; + ChannelCreationParameter parameter; // Total count of notify request. uint64_t notify_cnt = 0; }; diff --git a/streaming/src/data_reader.cc b/streaming/src/data_reader.cc index f53b2152b..4a96b18d6 100644 --- a/streaming/src/data_reader.cc +++ b/streaming/src/data_reader.cc @@ -17,11 +17,11 @@ namespace streaming { const uint32_t DataReader::kReadItemTimeout = 1000; void DataReader::Init(const std::vector &input_ids, - const std::vector &actor_ids, + const std::vector &init_params, const std::vector &queue_seq_ids, const std::vector &streaming_msg_ids, int64_t timer_interval) { - Init(input_ids, actor_ids, timer_interval); + Init(input_ids, init_params, timer_interval); for (size_t i = 0; i < input_ids.size(); ++i) { auto &q_id = input_ids[i]; channel_info_map_[q_id].current_seq_id = queue_seq_ids[i]; @@ -30,7 +30,8 @@ void DataReader::Init(const std::vector &input_ids, } void DataReader::Init(const std::vector &input_ids, - const std::vector &actor_ids, int64_t timer_interval) { + const std::vector &init_params, + int64_t timer_interval) { STREAMING_LOG(INFO) << input_ids.size() << " queue to init."; transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, input_ids); @@ -47,7 +48,7 @@ void DataReader::Init(const std::vector &input_ids, STREAMING_LOG(INFO) << "[Reader] Init queue id: " << q_id; auto &channel_info = channel_info_map_[q_id]; channel_info.channel_id = q_id; - channel_info.actor_id = actor_ids[i]; + channel_info.parameter = init_params[i]; channel_info.last_queue_item_delay = 0; channel_info.last_queue_item_latency = 0; channel_info.last_queue_target_diff = 0; diff --git a/streaming/src/data_reader.h b/streaming/src/data_reader.h index b1e94005b..e371472c8 100644 --- a/streaming/src/data_reader.h +++ b/streaming/src/data_reader.h @@ -78,11 +78,13 @@ class DataReader { /// \param 
channel_seq_ids /// \param msg_ids /// \param timer_interval - void Init(const std::vector &input_ids, const std::vector &actor_ids, + void Init(const std::vector &input_ids, + const std::vector &init_params, const std::vector &channel_seq_ids, const std::vector &msg_ids, int64_t timer_interval); - void Init(const std::vector &input_ids, const std::vector &actor_ids, + void Init(const std::vector &input_ids, + const std::vector &init_params, int64_t timer_interval); /// Get latest message from input queues. diff --git a/streaming/src/data_writer.cc b/streaming/src/data_writer.cc index ab929bf41..a7578f219 100644 --- a/streaming/src/data_writer.cc +++ b/streaming/src/data_writer.cc @@ -102,13 +102,14 @@ uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *dat return write_message_id; } -StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &actor_id, +StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, + const ChannelCreationParameter ¶m, uint64_t channel_message_id, uint64_t queue_size) { ProducerChannelInfo &channel_info = channel_info_map_[q_id]; channel_info.current_message_id = channel_message_id; channel_info.channel_id = q_id; - channel_info.actor_id = actor_id; + channel_info.parameter = param; channel_info.queue_size = queue_size; STREAMING_LOG(WARNING) << " Init queue [" << q_id << "]"; channel_info.writer_ring_buffer = std::make_shared( @@ -129,7 +130,7 @@ StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &act } StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, - const std::vector &actor_ids, + const std::vector &init_params, const std::vector &channel_message_id_vec, const std::vector &queue_size_vec) { STREAMING_CHECK(!queue_id_vec.empty() && !channel_message_id_vec.empty()); @@ -144,7 +145,7 @@ StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, queue_id_vec); for (size_t i = 0; i < 
queue_id_vec.size(); ++i) { - StreamingStatus status = InitChannel(queue_id_vec[i], actor_ids[i], + StreamingStatus status = InitChannel(queue_id_vec[i], init_params[i], channel_message_id_vec[i], queue_size_vec[i]); if (status != StreamingStatus::OK) { return status; diff --git a/streaming/src/data_writer.h b/streaming/src/data_writer.h index 6682e4311..673720a77 100644 --- a/streaming/src/data_writer.h +++ b/streaming/src/data_writer.h @@ -37,10 +37,11 @@ class DataWriter { /// Streaming writer client initialization. /// \param queue_id_vec queue id vector + /// \param init_params some parameters for initializing channels /// \param channel_message_id_vec channel seq id is related with message checkpoint /// \param queue_size queue size (memory size not length) StreamingStatus Init(const std::vector &channel_ids, - const std::vector &actor_ids, + const std::vector &init_params, const std::vector &channel_message_id_vec, const std::vector &queue_size_vec); @@ -91,7 +92,7 @@ class DataWriter { StreamingStatus WriteChannelProcess(ProducerChannelInfo &channel_info, bool *is_empty_message); - StreamingStatus InitChannel(const ObjectID &q_id, const ActorID &actor_id, + StreamingStatus InitChannel(const ObjectID &q_id, const ChannelCreationParameter ¶m, uint64_t channel_message_id, uint64_t queue_size); /// Write all messages to channel util ringbuffer is empty. 
diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc index 3caa7182a..260b50515 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc @@ -9,13 +9,15 @@ using namespace ray::streaming; JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( - JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids, - jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval, - jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) { + JNIEnv *env, jclass, jobject streaming_queue_initial_parameters, + jobjectArray input_channels, jlongArray seq_id_array, jlongArray msg_id_array, + jlong timer_interval, jboolean isRecreate, jbyteArray config_bytes, + jboolean is_mock) { STREAMING_LOG(INFO) << "[JNI]: create DataReader."; + std::vector parameter_vec; + ParseChannelInitParameters(env, streaming_queue_initial_parameters, parameter_vec); std::vector input_channels_ids = jarray_to_object_id_vec(env, input_channels); - std::vector actor_ids = jarray_to_actor_id_vec(env, input_actor_ids); std::vector seq_ids = LongVectorFromJLongArray(env, seq_id_array).data; std::vector msg_ids = LongVectorFromJLongArray(env, msg_id_array).data; @@ -29,7 +31,7 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( ctx->MarkMockTest(); } auto reader = new DataReader(ctx); - reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval); + reader->Init(input_channels_ids, parameter_vec, seq_ids, msg_ids, timer_interval); STREAMING_LOG(INFO) << "create native DataReader succeed"; return reinterpret_cast(reader); } @@ -72,17 +74,15 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBund std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), 
kUniqueIDSize); } -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env, - jobject thisObj, - jlong ptr) { +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative( + JNIEnv *env, jobject thisObj, jlong ptr) { auto reader = reinterpret_cast(ptr); reader->Stop(); } JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env, - jobject thisObj, - jlong ptr) { + jobject thisObj, + jlong ptr) { delete reinterpret_cast(ptr); } diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h index 19c02ac4a..b9f8a0196 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h @@ -10,34 +10,37 @@ extern "C" { /* * Class: io_ray_streaming_runtime_transfer_DataReader * Method: createDataReaderNative - * Signature: ([[B[[B[J[JJZ[BZ)J + * Signature: (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[J[JJZ[BZ)J */ -JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative - (JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean); +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( + JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlongArray, jlong, jboolean, + jbyteArray, jboolean); /* * Class: io_ray_streaming_runtime_transfer_DataReader * Method: getBundleNative * Signature: (JJJJ)V */ -JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative - (JNIEnv *, jobject, jlong, jlong, jlong, jlong); +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative( + JNIEnv *, jobject, jlong, jlong, jlong, jlong); /* * Class: 
io_ray_streaming_runtime_transfer_DataReader * Method: stopReaderNative * Signature: (J)V */ -JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative - (JNIEnv *, jobject, jlong); +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative( + JNIEnv *, jobject, jlong); /* * Class: io_ray_streaming_runtime_transfer_DataReader * Method: closeReaderNative * Signature: (J)V */ -JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative - (JNIEnv *, jobject, jlong); +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *, jobject, + jlong); #ifdef __cplusplus } diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc index 84e77153a..f7aafc5d4 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc @@ -7,10 +7,13 @@ using namespace ray::streaming; JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative( - JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids, + JNIEnv *env, jclass, jobject initial_parameters, jobjectArray output_queue_ids, jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array, jboolean is_mock) { STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative."; + + std::vector parameter_vec; + ParseChannelInitParameters(env, initial_parameters, parameter_vec); std::vector queue_id_vec = jarray_to_object_id_vec(env, output_queue_ids); for (auto id : queue_id_vec) { @@ -22,9 +25,6 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative( std::vector msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data; std::vector queue_size_vec(long_array_obj.data.size(), channel_size); std::vector remain_id_vec; - std::vector actor_ids = 
jarray_to_actor_id_vec(env, output_actor_ids); - - STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0]; RawDataFromJByteArray conf(env, conf_bytes_array); STREAMING_CHECK(conf.data != nullptr); @@ -36,7 +36,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative( runtime_context->MarkMockTest(); } auto *data_writer = new DataWriter(runtime_context); - auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec); + auto status = + data_writer->Init(queue_id_vec, parameter_vec, msg_ids_vec, queue_size_vec); if (status != StreamingStatus::OK) { STREAMING_LOG(WARNING) << "DataWriter init failed."; } else { @@ -64,10 +65,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative( return result; } -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env, - jobject thisObj, - jlong ptr) { +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative( + JNIEnv *env, jobject thisObj, jlong ptr) { STREAMING_LOG(INFO) << "jni: stop writer."; auto *data_writer = reinterpret_cast(ptr); data_writer->Stop(); @@ -75,8 +74,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env, JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env, - jobject thisObj, - jlong ptr) { + jobject thisObj, + jlong ptr) { auto *data_writer = reinterpret_cast(ptr); delete data_writer; } \ No newline at end of file diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h index b54e900b1..dddcafdf7 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h @@ -10,34 +10,38 @@ extern "C" { /* * Class: io_ray_streaming_runtime_transfer_DataWriter * Method: createWriterNative - * Signature: ([[B[[B[JJ[BZ)J + 
* Signature: (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJ[BZ)J */ -JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative - (JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean); +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative( + JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jbyteArray, jboolean); /* * Class: io_ray_streaming_runtime_transfer_DataWriter * Method: writeMessageNative * Signature: (JJJI)J */ -JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative - (JNIEnv *, jobject, jlong, jlong, jlong, jint); +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(JNIEnv *, jobject, + jlong, jlong, jlong, + jint); /* * Class: io_ray_streaming_runtime_transfer_DataWriter * Method: stopWriterNative * Signature: (J)V */ -JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative - (JNIEnv *, jobject, jlong); +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative( + JNIEnv *, jobject, jlong); /* * Class: io_ray_streaming_runtime_transfer_DataWriter * Method: closeWriterNative * Signature: (J)V */ -JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative - (JNIEnv *, jobject, jlong); +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *, jobject, + jlong); #ifdef __cplusplus } diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc index a3298f081..e2cb2e861 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc @@ -14,29 +14,18 @@ static std::shared_ptr 
JByteArrayToBuffer(JNIEnv *env, JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative( - JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) { - auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func); - auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func); - auto *writer_client = new WriterClient(ray_async_func, ray_sync_func); + JNIEnv *env, jobject this_obj) { + auto *writer_client = new WriterClient(); return reinterpret_cast(writer_client); } JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative( - JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) { - ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func); - ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func); - auto *reader_client = new ReaderClient(ray_async_func, ray_sync_func); + JNIEnv *env, jobject this_obj) { + auto *reader_client = new ReaderClient(); return reinterpret_cast(reader_client); } -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative( - JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { - auto *writer_client = reinterpret_cast(ptr); - writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes)); -} - JNIEXPORT jbyteArray JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { @@ -66,4 +55,4 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNa env->SetByteArrayRegion(arr, 0, result_buffer->Size(), reinterpret_cast(result_buffer->Data())); return arr; -} \ No newline at end of file +} diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h index 63d284c41..320b5c009 
100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h @@ -12,22 +12,16 @@ extern "C" { * Method: createWriterClientNative * Signature: (J)J */ -JNIEXPORT jlong JNICALL -Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *, - jobject, - jobject, - jobject); +JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative + (JNIEnv *, jobject); /* * Class: io_ray_streaming_runtime_transfer_TransferHandler * Method: createReaderClientNative * Signature: (J)J */ -JNIEXPORT jlong JNICALL -Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *, - jobject, - jobject, - jobject); +JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative + (JNIEnv *, jobject); /* * Class: io_ray_streaming_runtime_transfer_TransferHandler @@ -44,7 +38,7 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative * Signature: (J[B)[B */ JNIEXPORT jbyteArray JNICALL - Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( JNIEnv *, jobject, jlong, jbyteArray); /* @@ -62,10 +56,10 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative * Signature: (J[B)[B */ JNIEXPORT jbyteArray JNICALL - Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative( +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative( JNIEnv *, jobject, jlong, jbyteArray); #ifdef __cplusplus } #endif -#endif \ No newline at end of file +#endif diff --git a/streaming/src/lib/java/streaming_jni_common.cc b/streaming/src/lib/java/streaming_jni_common.cc index cc807af72..197f7c251 100644 --- a/streaming/src/lib/java/streaming_jni_common.cc 
+++ b/streaming/src/lib/java/streaming_jni_common.cc @@ -88,6 +88,15 @@ void JavaListToNativeVector(JNIEnv *env, jobject java_list, } } +/// Convert a Java byte array to a C++ UniqueID. +template +inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { + std::string id_str(ID::Size(), 0); + env->GetByteArrayRegion(bytes, 0, ID::Size(), + reinterpret_cast(&id_str.front())); + return ID::FromBinary(id_str); +} + /// Convert a Java String to C++ std::string. std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) { const char *c_str = env->GetStringUTFChars(jstr, nullptr); @@ -105,10 +114,9 @@ void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, }); } -ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, - jobject functionDescriptor) { - jclass java_language_class = - LoadClass(env, "io/ray/runtime/generated/Common$Language"); +std::shared_ptr FunctionDescriptorToRayFunction( + JNIEnv *env, jobject functionDescriptor) { + jclass java_language_class = LoadClass(env, "io/ray/runtime/generated/Common$Language"); jclass java_function_descriptor_class = LoadClass(env, "io/ray/runtime/functionmanager/FunctionDescriptor"); jmethodID java_language_get_number = @@ -129,5 +137,56 @@ ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, ray::FunctionDescriptor function_descriptor = ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list); ray::RayFunction ray_function{language, function_descriptor}; - return ray_function; + return std::make_shared(ray_function); +} + +void ParseChannelInitParameters( + JNIEnv *env, jobject param_obj, + std::vector ¶meter_vec) { + jclass java_streaming_queue_initial_parameters_class = + LoadClass(env, + "io/ray/streaming/runtime/transfer/" + "ChannelCreationParametersBuilder"); + jmethodID java_streaming_queue_initial_parameters_getParameters_method = + env->GetMethodID(java_streaming_queue_initial_parameters_class, "getParameters", + "()Ljava/util/List;"); + 
STREAMING_CHECK(java_streaming_queue_initial_parameters_getParameters_method != + nullptr); + jclass java_streaming_queue_initial_parameters_parameter_class = + LoadClass(env, + "io/ray/streaming/runtime/transfer/" + "ChannelCreationParametersBuilder$Parameter"); + jmethodID java_getActorIdBytes_method = env->GetMethodID( + java_streaming_queue_initial_parameters_parameter_class, "getActorIdBytes", "()[B"); + jmethodID java_getAsyncFunctionDescriptor_method = + env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class, + "getAsyncFunctionDescriptor", + "()Lio/ray/runtime/functionmanager/FunctionDescriptor;"); + jmethodID java_getSyncFunctionDescriptor_method = + env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class, + "getSyncFunctionDescriptor", + "()Lio/ray/runtime/functionmanager/FunctionDescriptor;"); + // Call getParameters method + jobject parameter_list = env->CallObjectMethod( + param_obj, java_streaming_queue_initial_parameters_getParameters_method); + + JavaListToNativeVector( + env, parameter_list, ¶meter_vec, + [java_getActorIdBytes_method, java_getAsyncFunctionDescriptor_method, + java_getSyncFunctionDescriptor_method](JNIEnv *env, jobject jobject_parameter) { + ray::streaming::ChannelCreationParameter native_parameter; + jbyteArray jobject_actor_id_bytes = (jbyteArray)env->CallObjectMethod( + jobject_parameter, java_getActorIdBytes_method); + native_parameter.actor_id = + JavaByteArrayToId(env, jobject_actor_id_bytes); + jobject jobject_async_func = env->CallObjectMethod( + jobject_parameter, java_getAsyncFunctionDescriptor_method); + native_parameter.async_function = + FunctionDescriptorToRayFunction(env, jobject_async_func); + jobject jobject_sync_func = env->CallObjectMethod( + jobject_parameter, java_getSyncFunctionDescriptor_method); + native_parameter.sync_function = + FunctionDescriptorToRayFunction(env, jobject_sync_func); + return native_parameter; + }); } diff --git 
a/streaming/src/lib/java/streaming_jni_common.h b/streaming/src/lib/java/streaming_jni_common.h index 921def7d0..acf6c13e8 100644 --- a/streaming/src/lib/java/streaming_jni_common.h +++ b/streaming/src/lib/java/streaming_jni_common.h @@ -3,6 +3,7 @@ #include #include +#include "channel.h" #include "ray/core_worker/common.h" #include "util/streaming_logging.h" @@ -21,12 +22,10 @@ class UniqueIdFromJByteArray { b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); PID = ray::ObjectID::FromBinary( - std::string(reinterpret_cast(b), ray::ObjectID::Size())); + std::string(reinterpret_cast(b), ray::ObjectID::Size())); } - ~UniqueIdFromJByteArray() { - _env->ReleaseByteArrayElements(_bytes, b, 0); - } + ~UniqueIdFromJByteArray() { _env->ReleaseByteArrayElements(_bytes, b, 0); } }; class RawDataFromJByteArray { @@ -42,15 +41,13 @@ class RawDataFromJByteArray { _env = env; _bytes = bytes; data_size = _env->GetArrayLength(_bytes); - jbyte *b = - reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); + jbyte *b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); data = reinterpret_cast(b); } ~RawDataFromJByteArray() { _env->ReleaseByteArrayElements(_bytes, reinterpret_cast(data), 0); } - }; class StringFromJString { @@ -69,10 +66,7 @@ class StringFromJString { str = std::string(j_str); } - ~StringFromJString() { - _env->ReleaseStringUTFChars(jni_str, j_str); - } - + ~StringFromJString() { _env->ReleaseStringUTFChars(jni_str, j_str); } }; class LongVectorFromJLongArray { @@ -98,14 +92,16 @@ class LongVectorFromJLongArray { } }; -std::vector -jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr); -std::vector -jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr); +std::vector jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr); +std::vector jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr); jint throwRuntimeException(JNIEnv *env, const char *message); jint throwChannelInitException(JNIEnv *env, const char *message, const 
std::vector &abnormal_queues); jint throwChannelInterruptException(JNIEnv *env, const char *message); -ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor); -#endif //RAY_STREAMING_JNI_COMMON_H +std::shared_ptr FunctionDescriptorToRayFunction( + JNIEnv *env, jobject functionDescriptor); +void ParseChannelInitParameters( + JNIEnv *env, jobject param_obj, + std::vector &parameter_vec); +#endif // RAY_STREAMING_JNI_COMMON_H diff --git a/streaming/src/queue/queue.cc b/streaming/src/queue/queue.cc index d2a9814eb..abd63723b 100644 --- a/streaming/src/queue/queue.cc +++ b/streaming/src/queue/queue.cc @@ -135,8 +135,7 @@ void WriterQueue::Send() { item.IsRaw()); std::unique_ptr buffer = msg.ToBytes(); STREAMING_CHECK(transport_ != nullptr); - transport_->Send(std::move(buffer), - DownstreamQueueMessageHandler::peer_async_function_); + transport_->Send(std::move(buffer)); } } @@ -188,7 +187,7 @@ void ReaderQueue::Notify(uint64_t seq_id) { NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id); std::unique_ptr buffer = msg.ToBytes(); - transport_->Send(std::move(buffer), UpstreamQueueMessageHandler::peer_async_function_); + transport_->Send(std::move(buffer)); } void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector &task_args) {} diff --git a/streaming/src/queue/queue_client.h b/streaming/src/queue/queue_client.h index 5d191b3fc..ea48093de 100644 --- a/streaming/src/queue/queue_client.h +++ b/streaming/src/queue/queue_client.h @@ -17,9 +17,7 @@ class ReaderClient { /// \param[in] async_func DataReader's raycall function descriptor to be called by /// DataWriter, asynchronous semantics \param[in] sync_func DataReader's raycall /// function descriptor to be called by DataWriter, synchronous semantics - ReaderClient(RayFunction &async_func, RayFunction &sync_func) { - DownstreamQueueMessageHandler::peer_async_function_ = async_func; - DownstreamQueueMessageHandler::peer_sync_function_ = sync_func; + ReaderClient() { 
downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService( CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID()); } @@ -38,9 +36,7 @@ class ReaderClient { /// Interface of streaming queue for DataWriter. Similar to ReaderClient. class WriterClient { public: - WriterClient(RayFunction &async_func, RayFunction &sync_func) { - UpstreamQueueMessageHandler::peer_async_function_ = async_func; - UpstreamQueueMessageHandler::peer_sync_function_ = sync_func; + WriterClient() { upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService( CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID()); } diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc index 6ac0cf0ed..e0a4433a5 100644 --- a/streaming/src/queue/queue_handler.cc +++ b/streaming/src/queue/queue_handler.cc @@ -12,11 +12,6 @@ std::shared_ptr std::shared_ptr DownstreamQueueMessageHandler::downstream_handler_ = nullptr; -RayFunction UpstreamQueueMessageHandler::peer_sync_function_; -RayFunction UpstreamQueueMessageHandler::peer_async_function_; -RayFunction DownstreamQueueMessageHandler::peer_sync_function_; -RayFunction DownstreamQueueMessageHandler::peer_async_function_; - std::shared_ptr QueueMessageHandler::ParseMessage( std::shared_ptr buffer) { uint8_t *bytes = buffer->Data(); @@ -83,10 +78,11 @@ std::shared_ptr QueueMessageHandler::GetOutTransport( } void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id, - const ActorID &actor_id) { + const ActorID &actor_id, RayFunction &async_func, + RayFunction &sync_func) { actors_.emplace(queue_id, actor_id); - out_transports_.emplace(queue_id, - std::make_shared(actor_id)); + out_transports_.emplace(queue_id, std::make_shared( + actor_id, async_func, sync_func)); } ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) { @@ -164,8 +160,7 @@ bool UpstreamQueueMessageHandler::CheckQueueSync(const ObjectID &queue_id) { auto 
transport_it = GetOutTransport(queue_id); STREAMING_CHECK(transport_it != nullptr); std::shared_ptr result_buffer = transport_it->SendForResultWithRetry( - std::move(buffer), DownstreamQueueMessageHandler::peer_sync_function_, 10, - COMMON_SYNC_CALL_TIMEOUTT_MS); + std::move(buffer), 10, COMMON_SYNC_CALL_TIMEOUTT_MS); if (result_buffer == nullptr) { return false; } diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h index 0525ec8a6..d3c497f3b 100644 --- a/streaming/src/queue/queue_handler.h +++ b/streaming/src/queue/queue_handler.h @@ -62,7 +62,8 @@ class QueueMessageHandler { /// downstream queue with same queue_id, and vice versa. /// \param[in] queue_id queue id of current queue. /// \param[in] actor_id actor_id actor id of corresponded peer actor. - void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id); + void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id, + RayFunction &async_func, RayFunction &sync_func); /// Obtain the actor id of the peer actor specified by queue_id. 
/// \return actor id @@ -133,9 +134,6 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler { const ActorID &actor_id); static std::shared_ptr GetService(); - static RayFunction peer_sync_function_; - static RayFunction peer_async_function_; - private: bool CheckQueueSync(const ObjectID &queue_ids); @@ -170,8 +168,6 @@ class DownstreamQueueMessageHandler : public QueueMessageHandler { static std::shared_ptr CreateService( const ActorID &actor_id); static std::shared_ptr GetService(); - static RayFunction peer_sync_function_; - static RayFunction peer_async_function_; private: std::unordered_map> diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index 6341fce0a..427e5cab5 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -36,17 +36,16 @@ void Transport::SendInternal(std::shared_ptr buffer, } } -void Transport::Send(std::shared_ptr buffer, RayFunction &function) { +void Transport::Send(std::shared_ptr buffer) { STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size(); std::vector return_ids; - SendInternal(std::move(buffer), function, TASK_OPTION_RETURN_NUM_0, return_ids); + SendInternal(std::move(buffer), async_func_, TASK_OPTION_RETURN_NUM_0, return_ids); } std::shared_ptr Transport::SendForResult( - std::shared_ptr buffer, RayFunction &function, - int64_t timeout_ms) { + std::shared_ptr buffer, int64_t timeout_ms) { std::vector return_ids; - SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids); + SendInternal(buffer, sync_func_, TASK_OPTION_RETURN_NUM_1, return_ids); std::vector> results; Status get_st = @@ -73,14 +72,12 @@ std::shared_ptr Transport::SendForResult( } std::shared_ptr Transport::SendForResultWithRetry( - std::shared_ptr buffer, RayFunction &function, int retry_cnt, - int64_t timeout_ms) { + std::shared_ptr buffer, int retry_cnt, int64_t timeout_ms) { STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt - << " 
timeout_ms: " << timeout_ms - << " function: " << function.GetFunctionDescriptor()->ToString(); + << " timeout_ms: " << timeout_ms; std::shared_ptr buffer_shared = std::move(buffer); for (int cnt = 0; cnt < retry_cnt; cnt++) { - auto result = SendForResult(buffer_shared, function, timeout_ms); + auto result = SendForResult(buffer_shared, timeout_ms); if (result != nullptr) { return result; } diff --git a/streaming/src/queue/transport.h b/streaming/src/queue/transport.h index 8d702f41b..2a3969064 100644 --- a/streaming/src/queue/transport.h +++ b/streaming/src/queue/transport.h @@ -14,34 +14,39 @@ class Transport { public: /// Construct a Transport object. /// \param[in] peer_actor_id actor id of peer actor. - Transport(const ActorID &peer_actor_id) - : worker_id_(CoreWorkerProcess::GetCoreWorker().GetWorkerID()), - peer_actor_id_(peer_actor_id) {} + Transport(const ActorID &peer_actor_id, RayFunction &async_func, RayFunction &sync_func) + : peer_actor_id_(peer_actor_id), async_func_(async_func), sync_func_(sync_func) { + STREAMING_LOG(INFO) << "Transport constructor:"; + STREAMING_LOG(INFO) << "async_func lang: " << async_func_.GetLanguage(); + STREAMING_LOG(INFO) << "async_func: " + << async_func_.GetFunctionDescriptor()->ToString(); + STREAMING_LOG(INFO) << "sync_func lang: " << sync_func_.GetLanguage(); + STREAMING_LOG(INFO) << "sync_func: " + << sync_func_.GetFunctionDescriptor()->ToString(); + } + virtual ~Transport() = default; /// Send buffer asynchronously, peer's `function` will be called. /// \param[in] buffer buffer to be sent. - /// \param[in] function the function descriptor of peer's function. - virtual void Send(std::shared_ptr buffer, RayFunction &function); + virtual void Send(std::shared_ptr buffer); + /// Send buffer synchronously, peer's `function` will be called, and return the peer /// function's return value. /// \param[in] buffer buffer to be sent. - /// \param[in] function the function descriptor of peer's function. 
/// \param[in] timeout_ms max time to wait for result. /// \return peer function's result. virtual std::shared_ptr SendForResult( - std::shared_ptr buffer, RayFunction &function, - int64_t timeout_ms); + std::shared_ptr buffer, int64_t timeout_ms); + /// Send buffer and get result with retry. /// return value. /// \param[in] buffer buffer to be sent. - /// \param[in] function the function descriptor of peer's function. /// \param[in] max retry count /// \param[in] timeout_ms max time to wait for result. /// \return peer function's result. std::shared_ptr SendForResultWithRetry( - std::shared_ptr buffer, RayFunction &function, int retry_cnt, - int64_t timeout_ms); + std::shared_ptr buffer, int retry_cnt, int64_t timeout_ms); private: /// Send buffer internal @@ -56,6 +61,8 @@ class Transport { private: WorkerID worker_id_; ActorID peer_actor_id_; + RayFunction async_func_; + RayFunction sync_func_; }; } // namespace streaming } // namespace ray diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 1ef0bc0d8..f5997d8d6 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -94,8 +94,18 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { for (auto &queue_id : queue_ids_) { STREAMING_LOG(INFO) << "queue_id: " << queue_id; } - std::vector actor_ids(queue_ids_.size(), peer_actor_id_); - STREAMING_LOG(INFO) << "writer actor_ids size: " << actor_ids.size() + ChannelCreationParameter param{ + peer_actor_id_, + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})), + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""}))}; + std::vector params(queue_ids_.size(), param); + STREAMING_LOG(INFO) << "writer actor_ids size: " << params.size() << " actor_id: " << peer_actor_id_; std::shared_ptr 
runtime_context(new RuntimeContext()); @@ -104,7 +114,7 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { std::shared_ptr streaming_writer_client(new DataWriter(runtime_context)); uint64_t queue_size = 10 * 1000 * 1000; std::vector channel_seq_id_vec(queue_ids_.size(), 0); - streaming_writer_client->Init(queue_ids_, actor_ids, channel_seq_id_vec, + streaming_writer_client->Init(queue_ids_, params, channel_seq_id_vec, std::vector(queue_ids_.size(), queue_size)); STREAMING_LOG(INFO) << "streaming_writer_client Init done"; @@ -214,14 +224,24 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { } void StreamingReaderStrategyTest(StreamingConfig &config) { - std::vector actor_ids(queue_ids_.size(), peer_actor_id_); - STREAMING_LOG(INFO) << "reader actor_ids size: " << actor_ids.size() + ChannelCreationParameter param{ + peer_actor_id_, + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})), + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""}))}; + std::vector params(queue_ids_.size(), param); + STREAMING_LOG(INFO) << "reader actor_ids size: " << params.size() << " actor_id: " << peer_actor_id_; std::shared_ptr runtime_context(new RuntimeContext()); runtime_context->SetConfig(config); std::shared_ptr reader(new DataReader(runtime_context)); - reader->Init(queue_ids_, actor_ids, -1); + reader->Init(queue_ids_, params, -1); ReaderLoopForward(reader, nullptr, queue_ids_); STREAMING_LOG(INFO) << "Reader exit"; @@ -298,23 +318,8 @@ class StreamingWorker { }; CoreWorkerProcess::Initialize(options); - RayFunction reader_async_call_func{ray::Language::PYTHON, - ray::FunctionDescriptorBuilder::BuildPython( - "reader_async_call_func", "", "", "")}; - RayFunction reader_sync_call_func{ - ray::Language::PYTHON, - 
ray::FunctionDescriptorBuilder::BuildPython("reader_sync_call_func", "", "", "")}; - RayFunction writer_async_call_func{ray::Language::PYTHON, - ray::FunctionDescriptorBuilder::BuildPython( - "writer_async_call_func", "", "", "")}; - RayFunction writer_sync_call_func{ - ray::Language::PYTHON, - ray::FunctionDescriptorBuilder::BuildPython("writer_sync_call_func", "", "", "")}; - - reader_client_ = - std::make_shared(reader_async_call_func, reader_sync_call_func); - writer_client_ = - std::make_shared(writer_async_call_func, writer_sync_call_func); + reader_client_ = std::make_shared(); + writer_client_ = std::make_shared(); STREAMING_LOG(INFO) << "StreamingWorker constructor"; } @@ -338,9 +343,9 @@ class StreamingWorker { ray::FunctionDescriptorType::kPythonFunctionDescriptor); auto typed_descriptor = function_descriptor->As(); STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " - << typed_descriptor->ModuleName(); + << typed_descriptor->ToString(); - std::string func_name = typed_descriptor->ModuleName(); + std::string func_name = typed_descriptor->FunctionName(); if (func_name == "init") { std::shared_ptr local_buffer = std::make_shared(args[0]->GetData()->Data(), diff --git a/streaming/src/test/mock_transfer_tests.cc b/streaming/src/test/mock_transfer_tests.cc index 012c9edd5..f6268ec81 100644 --- a/streaming/src/test/mock_transfer_tests.cc +++ b/streaming/src/test/mock_transfer_tests.cc @@ -51,11 +51,9 @@ class StreamingTransferTest : public ::testing::Test { } std::vector channel_id_vec(queue_vec.size(), 0); std::vector queue_size_vec(queue_vec.size(), 10000); - // actor ids are not used in this test, so we can just use Nil. 
- std::vector actor_id_vec(queue_vec.size(), - ActorID::NilFromJob(JobID::FromInt(0))); - writer->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec); - reader->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec, -1); + std::vector params(queue_vec.size()); + writer->Init(queue_vec, params, channel_id_vec, queue_size_vec); + reader->Init(queue_vec, params, channel_id_vec, queue_size_vec, -1); } void DestroyTransfer() { writer.reset(); diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index fb4fa845f..f1b2029c6 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -175,7 +175,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { TaskOptions options{0, resources}; std::vector return_ids; RayFunction func{ray::Language::PYTHON, - ray::FunctionDescriptorBuilder::BuildPython("init", "", "", "")}; + ray::FunctionDescriptorBuilder::BuildPython("", "", "init", "")}; RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids)); } @@ -191,7 +191,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { TaskOptions options{0, resources}; std::vector return_ids; RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( - "execute_test", test, "", "")}; + "", test, "execute_test", "")}; RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); } @@ -207,7 +207,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { TaskOptions options{1, resources}; std::vector return_ids; RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( - "check_current_test_status", "", "", "")}; + "", "", "check_current_test_status", "")}; RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); @@ -267,7 +267,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { auto buffer = std::make_shared(array, 
sizeof(array)); RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( - "actor creation task", "", "", "")}; + "", "", "actor creation task", "")}; std::vector args; args.emplace_back(TaskArg::PassByValue( std::make_shared(buffer, nullptr, std::vector())));