mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 21:04:35 +08:00
[Streaming] Streaming data transfer java (#6474)
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
#include "org_ray_streaming_runtime_transfer_ChannelID.h"
|
||||
#include "streaming_jni_common.h"
|
||||
using namespace ray::streaming;
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID(
|
||||
JNIEnv *env, jclass cls, jlong qid_address) {
|
||||
auto id = ray::ObjectID::FromBinary(
|
||||
std::string(reinterpret_cast<const char *>(qid_address), ray::ObjectID::Size()));
|
||||
return reinterpret_cast<jlong>(new ray::ObjectID(id));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID(
|
||||
JNIEnv *env, jclass cls, jlong native_id_ptr) {
|
||||
auto id = reinterpret_cast<ray::ObjectID *>(native_id_ptr);
|
||||
STREAMING_CHECK(id != nullptr);
|
||||
delete id;
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_streaming_runtime_transfer_ChannelID */
|
||||
|
||||
#ifndef _Included_org_ray_streaming_runtime_transfer_ChannelID
|
||||
#define _Included_org_ray_streaming_runtime_transfer_ChannelID
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#undef org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH
|
||||
#define org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH 20L
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_ChannelID
|
||||
* Method: createNativeID
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_ChannelID
|
||||
* Method: destroyNativeID
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -0,0 +1,88 @@
|
||||
#include "org_ray_streaming_runtime_transfer_DataReader.h"
|
||||
#include <cstdlib>
|
||||
#include "data_reader.h"
|
||||
#include "runtime_context.h"
|
||||
#include "streaming_jni_common.h"
|
||||
|
||||
using namespace ray;
|
||||
using namespace ray::streaming;
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
|
||||
JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids,
|
||||
jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval,
|
||||
jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) {
|
||||
STREAMING_LOG(INFO) << "[JNI]: create DataReader.";
|
||||
std::vector<ray::ObjectID> input_channels_ids =
|
||||
jarray_to_object_id_vec(env, input_channels);
|
||||
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, input_actor_ids);
|
||||
std::vector<uint64_t> seq_ids = LongVectorFromJLongArray(env, seq_id_array).data;
|
||||
std::vector<uint64_t> msg_ids = LongVectorFromJLongArray(env, msg_id_array).data;
|
||||
|
||||
auto ctx = std::make_shared<RuntimeContext>();
|
||||
RawDataFromJByteArray conf(env, config_bytes);
|
||||
if (conf.data_size > 0) {
|
||||
STREAMING_LOG(INFO) << "load config, config bytes size: " << conf.data_size;
|
||||
ctx->SetConfig(conf.data, conf.data_size);
|
||||
}
|
||||
if (is_mock) {
|
||||
ctx->MarkMockTest();
|
||||
}
|
||||
auto reader = new DataReader(ctx);
|
||||
reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval);
|
||||
STREAMING_LOG(INFO) << "create native DataReader succeed";
|
||||
return reinterpret_cast<jlong>(reader);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative(
|
||||
JNIEnv *env, jobject, jlong reader_ptr, jlong timeout_millis, jlong out,
|
||||
jlong meta_addr) {
|
||||
std::shared_ptr<ray::streaming::DataBundle> bundle;
|
||||
auto reader = reinterpret_cast<ray::streaming::DataReader *>(reader_ptr);
|
||||
auto status = reader->GetBundle((uint32_t)timeout_millis, bundle);
|
||||
|
||||
// over timeout, return empty array.
|
||||
if (StreamingStatus::Interrupted == status) {
|
||||
throwChannelInterruptException(env, "reader interrupted.");
|
||||
} else if (StreamingStatus::GetBundleTimeOut == status) {
|
||||
} else if (StreamingStatus::InitQueueFailed == status) {
|
||||
throwRuntimeException(env, "init channel failed");
|
||||
} else if (StreamingStatus::WaitQueueTimeOut == status) {
|
||||
throwRuntimeException(env, "wait channel object timeout");
|
||||
}
|
||||
|
||||
if (StreamingStatus::OK != status) {
|
||||
*reinterpret_cast<uint64_t *>(out) = 0;
|
||||
*reinterpret_cast<uint32_t *>(out + 8) = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
// bundle data
|
||||
// In streaming queue, bundle data and metadata will be different args of direct call,
|
||||
// so we separate it here for future extensibility.
|
||||
*reinterpret_cast<uint64_t *>(out) =
|
||||
reinterpret_cast<uint64_t>(bundle->data + kMessageBundleHeaderSize);
|
||||
*reinterpret_cast<uint32_t *>(out + 8) = bundle->data_size - kMessageBundleHeaderSize;
|
||||
|
||||
// bundle metadata
|
||||
auto meta = reinterpret_cast<uint8_t *>(meta_addr);
|
||||
// bundle header written by writer
|
||||
std::memcpy(meta, bundle->data, kMessageBundleHeaderSize);
|
||||
// append qid
|
||||
std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env,
|
||||
jobject thisObj,
|
||||
jlong ptr) {
|
||||
auto reader = reinterpret_cast<DataReader *>(ptr);
|
||||
reader->Stop();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env,
|
||||
jobject thisObj,
|
||||
jlong ptr) {
|
||||
delete reinterpret_cast<DataReader *>(ptr);
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_streaming_runtime_transfer_DataReader */
|
||||
|
||||
#ifndef _Included_org_ray_streaming_runtime_transfer_DataReader
|
||||
#define _Included_org_ray_streaming_runtime_transfer_DataReader
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataReader
|
||||
* Method: createDataReaderNative
|
||||
* Signature: ([[B[[B[J[JJZ[BZ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative
|
||||
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataReader
|
||||
* Method: getBundleNative
|
||||
* Signature: (JJJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative
|
||||
(JNIEnv *, jobject, jlong, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataReader
|
||||
* Method: stopReaderNative
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative
|
||||
(JNIEnv *, jobject, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataReader
|
||||
* Method: closeReaderNative
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative
|
||||
(JNIEnv *, jobject, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -0,0 +1,82 @@
|
||||
#include "org_ray_streaming_runtime_transfer_DataWriter.h"
|
||||
#include "config/streaming_config.h"
|
||||
#include "data_writer.h"
|
||||
#include "streaming_jni_common.h"
|
||||
|
||||
using namespace ray::streaming;
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
|
||||
JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids,
|
||||
jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array,
|
||||
jboolean is_mock) {
|
||||
STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative.";
|
||||
std::vector<ray::ObjectID> queue_id_vec =
|
||||
jarray_to_object_id_vec(env, output_queue_ids);
|
||||
for (auto id : queue_id_vec) {
|
||||
STREAMING_LOG(INFO) << "output channel id: " << id.Hex();
|
||||
}
|
||||
STREAMING_LOG(INFO) << "total channel size: " << channel_size << "*"
|
||||
<< queue_id_vec.size() << "=" << queue_id_vec.size() * channel_size;
|
||||
LongVectorFromJLongArray long_array_obj(env, msg_ids);
|
||||
std::vector<uint64_t> msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data;
|
||||
std::vector<uint64_t> queue_size_vec(long_array_obj.data.size(), channel_size);
|
||||
std::vector<ray::ObjectID> remain_id_vec;
|
||||
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, output_actor_ids);
|
||||
|
||||
STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0];
|
||||
|
||||
RawDataFromJByteArray conf(env, conf_bytes_array);
|
||||
STREAMING_CHECK(conf.data != nullptr);
|
||||
auto runtime_context = std::make_shared<RuntimeContext>();
|
||||
if (conf.data_size > 0) {
|
||||
runtime_context->SetConfig(conf.data, conf.data_size);
|
||||
}
|
||||
if (is_mock) {
|
||||
runtime_context->MarkMockTest();
|
||||
}
|
||||
auto *data_writer = new DataWriter(runtime_context);
|
||||
auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec);
|
||||
if (status != StreamingStatus::OK) {
|
||||
STREAMING_LOG(WARNING) << "DataWriter init failed.";
|
||||
} else {
|
||||
STREAMING_LOG(INFO) << "DataWriter init success";
|
||||
}
|
||||
|
||||
data_writer->Run();
|
||||
return reinterpret_cast<jlong>(data_writer);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(
|
||||
JNIEnv *env, jobject, jlong writer_ptr, jlong qid_ptr, jlong address, jint size) {
|
||||
auto *data_writer = reinterpret_cast<DataWriter *>(writer_ptr);
|
||||
auto qid = *reinterpret_cast<ray::ObjectID *>(qid_ptr);
|
||||
auto data = reinterpret_cast<uint8_t *>(address);
|
||||
auto data_size = static_cast<uint32_t>(size);
|
||||
jlong result = data_writer->WriteMessageToBufferRing(qid, data, data_size,
|
||||
StreamingMessageType::Message);
|
||||
|
||||
if (result == 0) {
|
||||
STREAMING_LOG(INFO) << "writer interrupted, return 0.";
|
||||
throwChannelInterruptException(env, "writer interrupted.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
|
||||
jobject thisObj,
|
||||
jlong ptr) {
|
||||
STREAMING_LOG(INFO) << "jni: stop writer.";
|
||||
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
|
||||
data_writer->Stop();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env,
|
||||
jobject thisObj,
|
||||
jlong ptr) {
|
||||
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
|
||||
delete data_writer;
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_streaming_runtime_transfer_DataWriter */
|
||||
|
||||
#ifndef _Included_org_ray_streaming_runtime_transfer_DataWriter
|
||||
#define _Included_org_ray_streaming_runtime_transfer_DataWriter
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataWriter
|
||||
* Method: createWriterNative
|
||||
* Signature: ([[B[[B[JJ[BZ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative
|
||||
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataWriter
|
||||
* Method: writeMessageNative
|
||||
* Signature: (JJJI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative
|
||||
(JNIEnv *, jobject, jlong, jlong, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataWriter
|
||||
* Method: stopWriterNative
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative
|
||||
(JNIEnv *, jobject, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_DataWriter
|
||||
* Method: closeWriterNative
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative
|
||||
(JNIEnv *, jobject, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -0,0 +1,75 @@
|
||||
#include "org_ray_streaming_runtime_transfer_TransferHandler.h"
|
||||
#include "queue/queue_client.h"
|
||||
#include "streaming_jni_common.h"
|
||||
|
||||
using namespace ray::streaming;
|
||||
|
||||
static std::shared_ptr<ray::LocalMemoryBuffer> JByteArrayToBuffer(JNIEnv *env,
|
||||
jbyteArray bytes) {
|
||||
RawDataFromJByteArray buf(env, bytes);
|
||||
STREAMING_CHECK(buf.data != nullptr);
|
||||
|
||||
return std::make_shared<ray::LocalMemoryBuffer>(buf.data, buf.data_size, true);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
|
||||
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
|
||||
jobject sync_func) {
|
||||
auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
|
||||
auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
|
||||
auto *writer_client =
|
||||
new WriterClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
|
||||
ray_async_func, ray_sync_func);
|
||||
return reinterpret_cast<jlong>(writer_client);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
|
||||
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
|
||||
jobject sync_func) {
|
||||
ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
|
||||
ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
|
||||
auto *reader_client =
|
||||
new ReaderClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
|
||||
ray_async_func, ray_sync_func);
|
||||
return reinterpret_cast<jlong>(reader_client);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
|
||||
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
|
||||
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
|
||||
writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
|
||||
}
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
|
||||
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
|
||||
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
|
||||
std::shared_ptr<ray::LocalMemoryBuffer> result_buffer =
|
||||
writer_client->OnWriterMessageSync(JByteArrayToBuffer(env, bytes));
|
||||
jbyteArray arr = env->NewByteArray(result_buffer->Size());
|
||||
env->SetByteArrayRegion(arr, 0, result_buffer->Size(),
|
||||
reinterpret_cast<jbyte *>(result_buffer->Data()));
|
||||
return arr;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative(
|
||||
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
|
||||
auto *reader_client = reinterpret_cast<ReaderClient *>(ptr);
|
||||
reader_client->OnReaderMessage(JByteArrayToBuffer(env, bytes));
|
||||
}
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
|
||||
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
|
||||
auto *reader_client = reinterpret_cast<ReaderClient *>(ptr);
|
||||
auto result_buffer = reader_client->OnReaderMessageSync(JByteArrayToBuffer(env, bytes));
|
||||
|
||||
jbyteArray arr = env->NewByteArray(result_buffer->Size());
|
||||
env->SetByteArrayRegion(arr, 0, result_buffer->Size(),
|
||||
reinterpret_cast<jbyte *>(result_buffer->Data()));
|
||||
return arr;
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_streaming_runtime_transfer_TransferHandler */
|
||||
|
||||
#ifndef _Included_org_ray_streaming_runtime_transfer_TransferHandler
|
||||
#define _Included_org_ray_streaming_runtime_transfer_TransferHandler
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: createWriterClientNative
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative
|
||||
(JNIEnv *, jobject, jlong, jobject, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: createReaderClientNative
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative
|
||||
(JNIEnv *, jobject, jlong, jobject, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleWriterMessageNative
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleWriterMessageSyncNative
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleReaderMessageNative
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleReaderMessageSyncNative
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -0,0 +1,123 @@
|
||||
#include "streaming_jni_common.h"
|
||||
|
||||
std::vector<ray::ObjectID>
|
||||
jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr) {
|
||||
int stringCount = env->GetArrayLength(jarr);
|
||||
std::vector<ray::ObjectID> object_id_vec;
|
||||
for (int i = 0; i < stringCount; i++) {
|
||||
auto jstr = (jbyteArray) (env->GetObjectArrayElement(jarr, i));
|
||||
UniqueIdFromJByteArray idFromJByteArray(env, jstr);
|
||||
object_id_vec.push_back(idFromJByteArray.PID);
|
||||
}
|
||||
return object_id_vec;
|
||||
}
|
||||
|
||||
std::vector<ray::ActorID>
|
||||
jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr) {
|
||||
int count = env->GetArrayLength(jarr);
|
||||
std::vector<ray::ActorID> actor_id_vec;
|
||||
for (int i = 0; i < count; i++) {
|
||||
auto bytes = (jbyteArray)(env->GetObjectArrayElement(jarr, i));
|
||||
std::string id_str(ray::ActorID::Size(), 0);
|
||||
env->GetByteArrayRegion(bytes, 0, ray::ActorID::Size(),
|
||||
reinterpret_cast<jbyte *>(&id_str.front()));
|
||||
actor_id_vec.push_back(ActorID::FromBinary(id_str));
|
||||
}
|
||||
|
||||
return actor_id_vec;
|
||||
}
|
||||
|
||||
jint throwRuntimeException(JNIEnv *env, const char *message) {
|
||||
jclass exClass;
|
||||
char className[] = "java/lang/RuntimeException";
|
||||
exClass = env->FindClass(className);
|
||||
return env->ThrowNew(exClass, message);
|
||||
}
|
||||
|
||||
jint throwChannelInitException(JNIEnv *env, const char *message,
|
||||
const std::vector<ray::ObjectID> &abnormal_queues) {
|
||||
jclass array_list_class = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_constructor = env->GetMethodID(array_list_class, "<init>", "()V");
|
||||
jmethodID array_list_add = env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");
|
||||
jobject array_list = env->NewObject(array_list_class, array_list_constructor);
|
||||
|
||||
for (auto &q_id : abnormal_queues) {
|
||||
jbyteArray jbyte_array = env->NewByteArray(kUniqueIDSize);
|
||||
env->SetByteArrayRegion(jbyte_array, 0, kUniqueIDSize, const_cast<jbyte*>(reinterpret_cast<const jbyte *>(q_id.Data())));
|
||||
env->CallBooleanMethod(array_list, array_list_add, jbyte_array);
|
||||
}
|
||||
|
||||
jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInitException");
|
||||
jmethodID ex_constructor = env->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;Ljava/util/List;)V");
|
||||
jstring message_jstr = env->NewStringUTF(message);
|
||||
jobject ex_obj = env->NewObject(ex_class, ex_constructor, message_jstr, array_list);
|
||||
env->DeleteLocalRef(message_jstr);
|
||||
return env->Throw((jthrowable)ex_obj);
|
||||
}
|
||||
|
||||
jint throwChannelInterruptException(JNIEnv *env, const char *message) {
|
||||
jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInterruptException");
|
||||
return env->ThrowNew(ex_class, message);
|
||||
}
|
||||
|
||||
jclass LoadClass(JNIEnv *env, const char *class_name) {
|
||||
jclass tempLocalClassRef = env->FindClass(class_name);
|
||||
jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef);
|
||||
STREAMING_CHECK(ret) << "Can't load Java class " << class_name;
|
||||
env->DeleteLocalRef(tempLocalClassRef);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void JavaListToNativeVector(
|
||||
JNIEnv *env, jobject java_list, std::vector<NativeT> *native_vector,
|
||||
std::function<NativeT(JNIEnv *, jobject)> element_converter) {
|
||||
jclass java_list_class = LoadClass(env, "java/util/List");
|
||||
jmethodID java_list_size = env->GetMethodID(java_list_class, "size", "()I");
|
||||
jmethodID java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;");
|
||||
int size = env->CallIntMethod(java_list, java_list_size);
|
||||
native_vector->clear();
|
||||
for (int i = 0; i < size; i++) {
|
||||
native_vector->emplace_back(
|
||||
element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i)));
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a Java String to C++ std::string.
|
||||
std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
|
||||
const char *c_str = env->GetStringUTFChars(jstr, nullptr);
|
||||
std::string result(c_str);
|
||||
env->ReleaseStringUTFChars(static_cast<jstring>(jstr), c_str);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Convert a Java List<String> to C++ std::vector<std::string>.
|
||||
void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list,
|
||||
std::vector<std::string> *native_vector) {
|
||||
JavaListToNativeVector<std::string>(
|
||||
env, java_list, native_vector, [](JNIEnv *env, jobject jstr) {
|
||||
return JavaStringToNativeString(env, static_cast<jstring>(jstr));
|
||||
});
|
||||
}
|
||||
|
||||
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor) {
|
||||
jclass java_language_class = LoadClass(env, "org/ray/runtime/generated/Common$Language");
|
||||
jclass java_function_descriptor_class =
|
||||
LoadClass(env, "org/ray/runtime/functionmanager/FunctionDescriptor");
|
||||
jmethodID java_language_get_number = env->GetMethodID(java_language_class, "getNumber", "()I");
|
||||
jmethodID java_function_descriptor_get_language =
|
||||
env->GetMethodID(java_function_descriptor_class, "getLanguage",
|
||||
"()Lorg/ray/runtime/generated/Common$Language;");
|
||||
jobject java_language =
|
||||
env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language);
|
||||
int language = env->CallIntMethod(java_language, java_language_get_number);
|
||||
std::vector<std::string> function_descriptor;
|
||||
jmethodID java_function_descriptor_to_list =
|
||||
env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;");
|
||||
JavaStringListToNativeStringVector(
|
||||
env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list),
|
||||
&function_descriptor);
|
||||
ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor};
|
||||
return ray_function;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
#ifndef RAY_STREAMING_JNI_COMMON_H
|
||||
#define RAY_STREAMING_JNI_COMMON_H
|
||||
|
||||
#include <jni.h>
|
||||
#include <string>
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
class UniqueIdFromJByteArray {
|
||||
private:
|
||||
JNIEnv *_env;
|
||||
jbyteArray _bytes;
|
||||
jbyte *b;
|
||||
|
||||
public:
|
||||
ray::ObjectID PID;
|
||||
|
||||
UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) {
|
||||
_env = env;
|
||||
_bytes = wid;
|
||||
|
||||
b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
|
||||
PID = ray::ObjectID::FromBinary(
|
||||
std::string(reinterpret_cast<const char*>(b), ray::ObjectID::Size()));
|
||||
}
|
||||
|
||||
~UniqueIdFromJByteArray() {
|
||||
_env->ReleaseByteArrayElements(_bytes, b, 0);
|
||||
}
|
||||
};
|
||||
|
||||
class RawDataFromJByteArray {
|
||||
private:
|
||||
JNIEnv *_env;
|
||||
jbyteArray _bytes;
|
||||
|
||||
public:
|
||||
uint8_t *data;
|
||||
uint32_t data_size;
|
||||
|
||||
RawDataFromJByteArray(JNIEnv *env, jbyteArray bytes) {
|
||||
_env = env;
|
||||
_bytes = bytes;
|
||||
data_size = _env->GetArrayLength(_bytes);
|
||||
jbyte *b =
|
||||
reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
|
||||
data = reinterpret_cast<uint8_t *>(b);
|
||||
}
|
||||
|
||||
~RawDataFromJByteArray() {
|
||||
_env->ReleaseByteArrayElements(_bytes, reinterpret_cast<jbyte *>(data), 0);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class StringFromJString {
|
||||
private:
|
||||
JNIEnv *_env;
|
||||
const char *j_str;
|
||||
jstring jni_str;
|
||||
|
||||
public:
|
||||
std::string str;
|
||||
|
||||
StringFromJString(JNIEnv *env, jstring jni_str_) {
|
||||
jni_str = jni_str_;
|
||||
_env = env;
|
||||
j_str = env->GetStringUTFChars(jni_str, nullptr);
|
||||
str = std::string(j_str);
|
||||
}
|
||||
|
||||
~StringFromJString() {
|
||||
_env->ReleaseStringUTFChars(jni_str, j_str);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class LongVectorFromJLongArray {
|
||||
private:
|
||||
JNIEnv *_env;
|
||||
jlongArray long_array;
|
||||
jlong *long_array_ptr = nullptr;
|
||||
|
||||
public:
|
||||
std::vector<uint64_t> data;
|
||||
|
||||
LongVectorFromJLongArray(JNIEnv *env, jlongArray long_array_) {
|
||||
_env = env;
|
||||
long_array = long_array_;
|
||||
|
||||
long_array_ptr = env->GetLongArrayElements(long_array, nullptr);
|
||||
jsize seq_id_size = env->GetArrayLength(long_array);
|
||||
data = std::vector<uint64_t>(long_array_ptr, long_array_ptr + seq_id_size);
|
||||
}
|
||||
|
||||
~LongVectorFromJLongArray() {
|
||||
_env->ReleaseLongArrayElements(long_array, long_array_ptr, 0);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<ray::ObjectID>
|
||||
jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr);
|
||||
std::vector<ray::ActorID>
|
||||
jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr);
|
||||
|
||||
jint throwRuntimeException(JNIEnv *env, const char *message);
|
||||
jint throwChannelInitException(JNIEnv *env, const char *message,
|
||||
const std::vector<ray::ObjectID> &abnormal_queues);
|
||||
jint throwChannelInterruptException(JNIEnv *env, const char *message);
|
||||
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor);
|
||||
#endif //RAY_STREAMING_JNI_COMMON_H
|
||||
Reference in New Issue
Block a user