[Streaming] Streaming data transfer java (#6474)

This commit is contained in:
Chaokun Yang
2019-12-22 10:56:05 +08:00
committed by Hao Chen
parent 1b14fbe179
commit 7bbfa85c66
146 changed files with 3923 additions and 786 deletions
@@ -0,0 +1,17 @@
#include "org_ray_streaming_runtime_transfer_ChannelID.h"
#include "streaming_jni_common.h"
using namespace ray::streaming;
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID(
JNIEnv *env, jclass cls, jlong qid_address) {
auto id = ray::ObjectID::FromBinary(
std::string(reinterpret_cast<const char *>(qid_address), ray::ObjectID::Size()));
return reinterpret_cast<jlong>(new ray::ObjectID(id));
}
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID(
JNIEnv *env, jclass cls, jlong native_id_ptr) {
auto id = reinterpret_cast<ray::ObjectID *>(native_id_ptr);
STREAMING_CHECK(id != nullptr);
delete id;
}
@@ -0,0 +1,31 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_ray_streaming_runtime_transfer_ChannelID */
#ifndef _Included_org_ray_streaming_runtime_transfer_ChannelID
#define _Included_org_ray_streaming_runtime_transfer_ChannelID
#ifdef __cplusplus
extern "C" {
#endif
#undef org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH
#define org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH 20L
/*
* Class: org_ray_streaming_runtime_transfer_ChannelID
* Method: createNativeID
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID
(JNIEnv *, jclass, jlong);
/*
* Class: org_ray_streaming_runtime_transfer_ChannelID
* Method: destroyNativeID
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID
(JNIEnv *, jclass, jlong);
#ifdef __cplusplus
}
#endif
#endif
@@ -0,0 +1,88 @@
#include "org_ray_streaming_runtime_transfer_DataReader.h"
#include <cstdlib>
#include "data_reader.h"
#include "runtime_context.h"
#include "streaming_jni_common.h"
using namespace ray;
using namespace ray::streaming;
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids,
jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval,
jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) {
STREAMING_LOG(INFO) << "[JNI]: create DataReader.";
std::vector<ray::ObjectID> input_channels_ids =
jarray_to_object_id_vec(env, input_channels);
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, input_actor_ids);
std::vector<uint64_t> seq_ids = LongVectorFromJLongArray(env, seq_id_array).data;
std::vector<uint64_t> msg_ids = LongVectorFromJLongArray(env, msg_id_array).data;
auto ctx = std::make_shared<RuntimeContext>();
RawDataFromJByteArray conf(env, config_bytes);
if (conf.data_size > 0) {
STREAMING_LOG(INFO) << "load config, config bytes size: " << conf.data_size;
ctx->SetConfig(conf.data, conf.data_size);
}
if (is_mock) {
ctx->MarkMockTest();
}
auto reader = new DataReader(ctx);
reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval);
STREAMING_LOG(INFO) << "create native DataReader succeed";
return reinterpret_cast<jlong>(reader);
}
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative(
JNIEnv *env, jobject, jlong reader_ptr, jlong timeout_millis, jlong out,
jlong meta_addr) {
std::shared_ptr<ray::streaming::DataBundle> bundle;
auto reader = reinterpret_cast<ray::streaming::DataReader *>(reader_ptr);
auto status = reader->GetBundle((uint32_t)timeout_millis, bundle);
// over timeout, return empty array.
if (StreamingStatus::Interrupted == status) {
throwChannelInterruptException(env, "reader interrupted.");
} else if (StreamingStatus::GetBundleTimeOut == status) {
} else if (StreamingStatus::InitQueueFailed == status) {
throwRuntimeException(env, "init channel failed");
} else if (StreamingStatus::WaitQueueTimeOut == status) {
throwRuntimeException(env, "wait channel object timeout");
}
if (StreamingStatus::OK != status) {
*reinterpret_cast<uint64_t *>(out) = 0;
*reinterpret_cast<uint32_t *>(out + 8) = 0;
return;
}
// bundle data
// In streaming queue, bundle data and metadata will be different args of direct call,
// so we separate it here for future extensibility.
*reinterpret_cast<uint64_t *>(out) =
reinterpret_cast<uint64_t>(bundle->data + kMessageBundleHeaderSize);
*reinterpret_cast<uint32_t *>(out + 8) = bundle->data_size - kMessageBundleHeaderSize;
// bundle metadata
auto meta = reinterpret_cast<uint8_t *>(meta_addr);
// bundle header written by writer
std::memcpy(meta, bundle->data, kMessageBundleHeaderSize);
// append qid
std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize);
}
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
auto reader = reinterpret_cast<DataReader *>(ptr);
reader->Stop();
}
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
delete reinterpret_cast<DataReader *>(ptr);
}
@@ -0,0 +1,45 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_ray_streaming_runtime_transfer_DataReader */
#ifndef _Included_org_ray_streaming_runtime_transfer_DataReader
#define _Included_org_ray_streaming_runtime_transfer_DataReader
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_streaming_runtime_transfer_DataReader
* Method: createDataReaderNative
* Signature: ([[B[[B[J[JJZ[BZ)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean);
/*
* Class: org_ray_streaming_runtime_transfer_DataReader
* Method: getBundleNative
* Signature: (JJJJ)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative
(JNIEnv *, jobject, jlong, jlong, jlong, jlong);
/*
* Class: org_ray_streaming_runtime_transfer_DataReader
* Method: stopReaderNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative
(JNIEnv *, jobject, jlong);
/*
* Class: org_ray_streaming_runtime_transfer_DataReader
* Method: closeReaderNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative
(JNIEnv *, jobject, jlong);
#ifdef __cplusplus
}
#endif
#endif
@@ -0,0 +1,82 @@
#include "org_ray_streaming_runtime_transfer_DataWriter.h"
#include "config/streaming_config.h"
#include "data_writer.h"
#include "streaming_jni_common.h"
using namespace ray::streaming;
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids,
jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array,
jboolean is_mock) {
STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative.";
std::vector<ray::ObjectID> queue_id_vec =
jarray_to_object_id_vec(env, output_queue_ids);
for (auto id : queue_id_vec) {
STREAMING_LOG(INFO) << "output channel id: " << id.Hex();
}
STREAMING_LOG(INFO) << "total channel size: " << channel_size << "*"
<< queue_id_vec.size() << "=" << queue_id_vec.size() * channel_size;
LongVectorFromJLongArray long_array_obj(env, msg_ids);
std::vector<uint64_t> msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data;
std::vector<uint64_t> queue_size_vec(long_array_obj.data.size(), channel_size);
std::vector<ray::ObjectID> remain_id_vec;
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, output_actor_ids);
STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0];
RawDataFromJByteArray conf(env, conf_bytes_array);
STREAMING_CHECK(conf.data != nullptr);
auto runtime_context = std::make_shared<RuntimeContext>();
if (conf.data_size > 0) {
runtime_context->SetConfig(conf.data, conf.data_size);
}
if (is_mock) {
runtime_context->MarkMockTest();
}
auto *data_writer = new DataWriter(runtime_context);
auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec);
if (status != StreamingStatus::OK) {
STREAMING_LOG(WARNING) << "DataWriter init failed.";
} else {
STREAMING_LOG(INFO) << "DataWriter init success";
}
data_writer->Run();
return reinterpret_cast<jlong>(data_writer);
}
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(
JNIEnv *env, jobject, jlong writer_ptr, jlong qid_ptr, jlong address, jint size) {
auto *data_writer = reinterpret_cast<DataWriter *>(writer_ptr);
auto qid = *reinterpret_cast<ray::ObjectID *>(qid_ptr);
auto data = reinterpret_cast<uint8_t *>(address);
auto data_size = static_cast<uint32_t>(size);
jlong result = data_writer->WriteMessageToBufferRing(qid, data, data_size,
StreamingMessageType::Message);
if (result == 0) {
STREAMING_LOG(INFO) << "writer interrupted, return 0.";
throwChannelInterruptException(env, "writer interrupted.");
}
return result;
}
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
STREAMING_LOG(INFO) << "jni: stop writer.";
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
data_writer->Stop();
}
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
delete data_writer;
}
@@ -0,0 +1,45 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_ray_streaming_runtime_transfer_DataWriter */
#ifndef _Included_org_ray_streaming_runtime_transfer_DataWriter
#define _Included_org_ray_streaming_runtime_transfer_DataWriter
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_streaming_runtime_transfer_DataWriter
* Method: createWriterNative
* Signature: ([[B[[B[JJ[BZ)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
/*
* Class: org_ray_streaming_runtime_transfer_DataWriter
* Method: writeMessageNative
* Signature: (JJJI)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative
(JNIEnv *, jobject, jlong, jlong, jlong, jint);
/*
* Class: org_ray_streaming_runtime_transfer_DataWriter
* Method: stopWriterNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative
(JNIEnv *, jobject, jlong);
/*
* Class: org_ray_streaming_runtime_transfer_DataWriter
* Method: closeWriterNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative
(JNIEnv *, jobject, jlong);
#ifdef __cplusplus
}
#endif
#endif
@@ -0,0 +1,75 @@
#include "org_ray_streaming_runtime_transfer_TransferHandler.h"
#include "queue/queue_client.h"
#include "streaming_jni_common.h"
using namespace ray::streaming;
static std::shared_ptr<ray::LocalMemoryBuffer> JByteArrayToBuffer(JNIEnv *env,
jbyteArray bytes) {
RawDataFromJByteArray buf(env, bytes);
STREAMING_CHECK(buf.data != nullptr);
return std::make_shared<ray::LocalMemoryBuffer>(buf.data, buf.data_size, true);
}
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
jobject sync_func) {
auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *writer_client =
new WriterClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
ray_async_func, ray_sync_func);
return reinterpret_cast<jlong>(writer_client);
}
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
jobject sync_func) {
ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *reader_client =
new ReaderClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
ray_async_func, ray_sync_func);
return reinterpret_cast<jlong>(reader_client);
}
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
}
JNIEXPORT jbyteArray JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
std::shared_ptr<ray::LocalMemoryBuffer> result_buffer =
writer_client->OnWriterMessageSync(JByteArrayToBuffer(env, bytes));
jbyteArray arr = env->NewByteArray(result_buffer->Size());
env->SetByteArrayRegion(arr, 0, result_buffer->Size(),
reinterpret_cast<jbyte *>(result_buffer->Data()));
return arr;
}
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *reader_client = reinterpret_cast<ReaderClient *>(ptr);
reader_client->OnReaderMessage(JByteArrayToBuffer(env, bytes));
}
JNIEXPORT jbyteArray JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *reader_client = reinterpret_cast<ReaderClient *>(ptr);
auto result_buffer = reader_client->OnReaderMessageSync(JByteArrayToBuffer(env, bytes));
jbyteArray arr = env->NewByteArray(result_buffer->Size());
env->SetByteArrayRegion(arr, 0, result_buffer->Size(),
reinterpret_cast<jbyte *>(result_buffer->Data()));
return arr;
}
@@ -0,0 +1,61 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_ray_streaming_runtime_transfer_TransferHandler */
#ifndef _Included_org_ray_streaming_runtime_transfer_TransferHandler
#define _Included_org_ray_streaming_runtime_transfer_TransferHandler
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: createWriterClientNative
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative
(JNIEnv *, jobject, jlong, jobject, jobject);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: createReaderClientNative
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative
(JNIEnv *, jobject, jlong, jobject, jobject);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleWriterMessageNative
* Signature: (J[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative
(JNIEnv *, jobject, jlong, jbyteArray);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleWriterMessageSyncNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative
(JNIEnv *, jobject, jlong, jbyteArray);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleReaderMessageNative
* Signature: (J[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative
(JNIEnv *, jobject, jlong, jbyteArray);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleReaderMessageSyncNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative
(JNIEnv *, jobject, jlong, jbyteArray);
#ifdef __cplusplus
}
#endif
#endif
@@ -0,0 +1,123 @@
#include "streaming_jni_common.h"
std::vector<ray::ObjectID>
jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr) {
int stringCount = env->GetArrayLength(jarr);
std::vector<ray::ObjectID> object_id_vec;
for (int i = 0; i < stringCount; i++) {
auto jstr = (jbyteArray) (env->GetObjectArrayElement(jarr, i));
UniqueIdFromJByteArray idFromJByteArray(env, jstr);
object_id_vec.push_back(idFromJByteArray.PID);
}
return object_id_vec;
}
std::vector<ray::ActorID>
jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr) {
int count = env->GetArrayLength(jarr);
std::vector<ray::ActorID> actor_id_vec;
for (int i = 0; i < count; i++) {
auto bytes = (jbyteArray)(env->GetObjectArrayElement(jarr, i));
std::string id_str(ray::ActorID::Size(), 0);
env->GetByteArrayRegion(bytes, 0, ray::ActorID::Size(),
reinterpret_cast<jbyte *>(&id_str.front()));
actor_id_vec.push_back(ActorID::FromBinary(id_str));
}
return actor_id_vec;
}
jint throwRuntimeException(JNIEnv *env, const char *message) {
jclass exClass;
char className[] = "java/lang/RuntimeException";
exClass = env->FindClass(className);
return env->ThrowNew(exClass, message);
}
jint throwChannelInitException(JNIEnv *env, const char *message,
const std::vector<ray::ObjectID> &abnormal_queues) {
jclass array_list_class = env->FindClass("java/util/ArrayList");
jmethodID array_list_constructor = env->GetMethodID(array_list_class, "<init>", "()V");
jmethodID array_list_add = env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");
jobject array_list = env->NewObject(array_list_class, array_list_constructor);
for (auto &q_id : abnormal_queues) {
jbyteArray jbyte_array = env->NewByteArray(kUniqueIDSize);
env->SetByteArrayRegion(jbyte_array, 0, kUniqueIDSize, const_cast<jbyte*>(reinterpret_cast<const jbyte *>(q_id.Data())));
env->CallBooleanMethod(array_list, array_list_add, jbyte_array);
}
jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInitException");
jmethodID ex_constructor = env->GetMethodID(ex_class, "<init>", "(Ljava/lang/String;Ljava/util/List;)V");
jstring message_jstr = env->NewStringUTF(message);
jobject ex_obj = env->NewObject(ex_class, ex_constructor, message_jstr, array_list);
env->DeleteLocalRef(message_jstr);
return env->Throw((jthrowable)ex_obj);
}
jint throwChannelInterruptException(JNIEnv *env, const char *message) {
jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInterruptException");
return env->ThrowNew(ex_class, message);
}
jclass LoadClass(JNIEnv *env, const char *class_name) {
jclass tempLocalClassRef = env->FindClass(class_name);
jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef);
STREAMING_CHECK(ret) << "Can't load Java class " << class_name;
env->DeleteLocalRef(tempLocalClassRef);
return ret;
}
template <typename NativeT>
void JavaListToNativeVector(
JNIEnv *env, jobject java_list, std::vector<NativeT> *native_vector,
std::function<NativeT(JNIEnv *, jobject)> element_converter) {
jclass java_list_class = LoadClass(env, "java/util/List");
jmethodID java_list_size = env->GetMethodID(java_list_class, "size", "()I");
jmethodID java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;");
int size = env->CallIntMethod(java_list, java_list_size);
native_vector->clear();
for (int i = 0; i < size; i++) {
native_vector->emplace_back(
element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i)));
}
}
/// Convert a Java String to C++ std::string.
std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
const char *c_str = env->GetStringUTFChars(jstr, nullptr);
std::string result(c_str);
env->ReleaseStringUTFChars(static_cast<jstring>(jstr), c_str);
return result;
}
/// Convert a Java List<String> to C++ std::vector<std::string>.
void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list,
std::vector<std::string> *native_vector) {
JavaListToNativeVector<std::string>(
env, java_list, native_vector, [](JNIEnv *env, jobject jstr) {
return JavaStringToNativeString(env, static_cast<jstring>(jstr));
});
}
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor) {
jclass java_language_class = LoadClass(env, "org/ray/runtime/generated/Common$Language");
jclass java_function_descriptor_class =
LoadClass(env, "org/ray/runtime/functionmanager/FunctionDescriptor");
jmethodID java_language_get_number = env->GetMethodID(java_language_class, "getNumber", "()I");
jmethodID java_function_descriptor_get_language =
env->GetMethodID(java_function_descriptor_class, "getLanguage",
"()Lorg/ray/runtime/generated/Common$Language;");
jobject java_language =
env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language);
int language = env->CallIntMethod(java_language, java_language_get_number);
std::vector<std::string> function_descriptor;
jmethodID java_function_descriptor_to_list =
env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;");
JavaStringListToNativeStringVector(
env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list),
&function_descriptor);
ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor};
return ray_function;
}
@@ -0,0 +1,111 @@
#ifndef RAY_STREAMING_JNI_COMMON_H
#define RAY_STREAMING_JNI_COMMON_H
#include <jni.h>
#include <string>
#include "ray/core_worker/common.h"
#include "util/streaming_logging.h"
class UniqueIdFromJByteArray {
private:
JNIEnv *_env;
jbyteArray _bytes;
jbyte *b;
public:
ray::ObjectID PID;
UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) {
_env = env;
_bytes = wid;
b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
PID = ray::ObjectID::FromBinary(
std::string(reinterpret_cast<const char*>(b), ray::ObjectID::Size()));
}
~UniqueIdFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, b, 0);
}
};
class RawDataFromJByteArray {
private:
JNIEnv *_env;
jbyteArray _bytes;
public:
uint8_t *data;
uint32_t data_size;
RawDataFromJByteArray(JNIEnv *env, jbyteArray bytes) {
_env = env;
_bytes = bytes;
data_size = _env->GetArrayLength(_bytes);
jbyte *b =
reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
data = reinterpret_cast<uint8_t *>(b);
}
~RawDataFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, reinterpret_cast<jbyte *>(data), 0);
}
};
class StringFromJString {
private:
JNIEnv *_env;
const char *j_str;
jstring jni_str;
public:
std::string str;
StringFromJString(JNIEnv *env, jstring jni_str_) {
jni_str = jni_str_;
_env = env;
j_str = env->GetStringUTFChars(jni_str, nullptr);
str = std::string(j_str);
}
~StringFromJString() {
_env->ReleaseStringUTFChars(jni_str, j_str);
}
};
class LongVectorFromJLongArray {
private:
JNIEnv *_env;
jlongArray long_array;
jlong *long_array_ptr = nullptr;
public:
std::vector<uint64_t> data;
LongVectorFromJLongArray(JNIEnv *env, jlongArray long_array_) {
_env = env;
long_array = long_array_;
long_array_ptr = env->GetLongArrayElements(long_array, nullptr);
jsize seq_id_size = env->GetArrayLength(long_array);
data = std::vector<uint64_t>(long_array_ptr, long_array_ptr + seq_id_size);
}
~LongVectorFromJLongArray() {
_env->ReleaseLongArrayElements(long_array, long_array_ptr, 0);
}
};
std::vector<ray::ObjectID>
jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ActorID>
jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr);
jint throwRuntimeException(JNIEnv *env, const char *message);
jint throwChannelInitException(JNIEnv *env, const char *message,
const std::vector<ray::ObjectID> &abnormal_queues);
jint throwChannelInterruptException(JNIEnv *env, const char *message);
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor);
#endif //RAY_STREAMING_JNI_COMMON_H