[Streaming] Streaming data transfer supports cross language. (#7961)

* add init parameters for java

* fix bug

* cython

* fix compile

* fix test_direct_tranfer

* comment

* ChannelCreationParameter

* fix comment

* builder

* lint and fix tests

* fix single process test

* fix checkstyle and lint

* checkstyle

* lint python

Co-authored-by: wanxing <wanxing@B-458DMD6M-1753.local>
This commit is contained in:
wanxing
2020-04-16 15:16:48 +08:00
committed by GitHub
parent 5a7882bb44
commit 9345d03ffb
36 changed files with 618 additions and 333 deletions
@@ -9,13 +9,15 @@ using namespace ray::streaming;
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids,
jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval,
jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) {
JNIEnv *env, jclass, jobject streaming_queue_initial_parameters,
jobjectArray input_channels, jlongArray seq_id_array, jlongArray msg_id_array,
jlong timer_interval, jboolean isRecreate, jbyteArray config_bytes,
jboolean is_mock) {
STREAMING_LOG(INFO) << "[JNI]: create DataReader.";
std::vector<ray::streaming::ChannelCreationParameter> parameter_vec;
ParseChannelInitParameters(env, streaming_queue_initial_parameters, parameter_vec);
std::vector<ray::ObjectID> input_channels_ids =
jarray_to_object_id_vec(env, input_channels);
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, input_actor_ids);
std::vector<uint64_t> seq_ids = LongVectorFromJLongArray(env, seq_id_array).data;
std::vector<uint64_t> msg_ids = LongVectorFromJLongArray(env, msg_id_array).data;
@@ -29,7 +31,7 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
ctx->MarkMockTest();
}
auto reader = new DataReader(ctx);
reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval);
reader->Init(input_channels_ids, parameter_vec, seq_ids, msg_ids, timer_interval);
STREAMING_LOG(INFO) << "create native DataReader succeed";
return reinterpret_cast<jlong>(reader);
}
@@ -72,17 +74,15 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBund
std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize);
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(
JNIEnv *env, jobject thisObj, jlong ptr) {
auto reader = reinterpret_cast<DataReader *>(ptr);
reader->Stop();
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
jobject thisObj,
jlong ptr) {
delete reinterpret_cast<DataReader *>(ptr);
}
@@ -10,34 +10,37 @@ extern "C" {
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: createDataReaderNative
* Signature: ([[B[[B[J[JJZ[BZ)J
* Signature: (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[J[JJZ[BZ)J
*/
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean);
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlongArray, jlong, jboolean,
jbyteArray, jboolean);
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: getBundleNative
* Signature: (JJJJ)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative
(JNIEnv *, jobject, jlong, jlong, jlong, jlong);
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative(
JNIEnv *, jobject, jlong, jlong, jlong, jlong);
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: stopReaderNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative(
JNIEnv *, jobject, jlong);
/*
* Class: io_ray_streaming_runtime_transfer_DataReader
* Method: closeReaderNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *, jobject,
jlong);
#ifdef __cplusplus
}
@@ -7,10 +7,13 @@ using namespace ray::streaming;
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids,
JNIEnv *env, jclass, jobject initial_parameters, jobjectArray output_queue_ids,
jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array,
jboolean is_mock) {
STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative.";
std::vector<ray::streaming::ChannelCreationParameter> parameter_vec;
ParseChannelInitParameters(env, initial_parameters, parameter_vec);
std::vector<ray::ObjectID> queue_id_vec =
jarray_to_object_id_vec(env, output_queue_ids);
for (auto id : queue_id_vec) {
@@ -22,9 +25,6 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
std::vector<uint64_t> msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data;
std::vector<uint64_t> queue_size_vec(long_array_obj.data.size(), channel_size);
std::vector<ray::ObjectID> remain_id_vec;
std::vector<ray::ActorID> actor_ids = jarray_to_actor_id_vec(env, output_actor_ids);
STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0];
RawDataFromJByteArray conf(env, conf_bytes_array);
STREAMING_CHECK(conf.data != nullptr);
@@ -36,7 +36,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
runtime_context->MarkMockTest();
}
auto *data_writer = new DataWriter(runtime_context);
auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec);
auto status =
data_writer->Init(queue_id_vec, parameter_vec, msg_ids_vec, queue_size_vec);
if (status != StreamingStatus::OK) {
STREAMING_LOG(WARNING) << "DataWriter init failed.";
} else {
@@ -64,10 +65,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(
return result;
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(
JNIEnv *env, jobject thisObj, jlong ptr) {
STREAMING_LOG(INFO) << "jni: stop writer.";
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
data_writer->Stop();
@@ -75,8 +74,8 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env,
jobject thisObj,
jlong ptr) {
jobject thisObj,
jlong ptr) {
auto *data_writer = reinterpret_cast<DataWriter *>(ptr);
delete data_writer;
}
@@ -10,34 +10,38 @@ extern "C" {
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: createWriterNative
* Signature: ([[B[[B[JJ[BZ)J
* Signature: (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJ[BZ)J
*/
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative
(JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: writeMessageNative
* Signature: (JJJI)J
*/
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative
(JNIEnv *, jobject, jlong, jlong, jlong, jint);
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(JNIEnv *, jobject,
jlong, jlong, jlong,
jint);
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: stopWriterNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(
JNIEnv *, jobject, jlong);
/*
* Class: io_ray_streaming_runtime_transfer_DataWriter
* Method: closeWriterNative
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative
(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *, jobject,
jlong);
#ifdef __cplusplus
}
@@ -14,29 +14,18 @@ static std::shared_ptr<ray::LocalMemoryBuffer> JByteArrayToBuffer(JNIEnv *env,
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *writer_client = new WriterClient(ray_async_func, ray_sync_func);
JNIEnv *env, jobject this_obj) {
auto *writer_client = new WriterClient();
return reinterpret_cast<jlong>(writer_client);
}
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *reader_client = new ReaderClient(ray_async_func, ray_sync_func);
JNIEnv *env, jobject this_obj) {
auto *reader_client = new ReaderClient();
return reinterpret_cast<jlong>(reader_client);
}
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
}
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
@@ -66,4 +55,4 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNa
env->SetByteArrayRegion(arr, 0, result_buffer->Size(),
reinterpret_cast<jbyte *>(result_buffer->Data()));
return arr;
}
}
@@ -12,22 +12,16 @@ extern "C" {
* Method: createWriterClientNative
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *,
jobject,
jobject,
jobject);
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative
(JNIEnv *, jobject);
/*
* Class: io_ray_streaming_runtime_transfer_TransferHandler
* Method: createReaderClientNative
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *,
jobject,
jobject,
jobject);
JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative
(JNIEnv *, jobject);
/*
* Class: io_ray_streaming_runtime_transfer_TransferHandler
@@ -44,7 +38,7 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *, jobject, jlong, jbyteArray);
/*
@@ -62,10 +56,10 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
JNIEnv *, jobject, jlong, jbyteArray);
#ifdef __cplusplus
}
#endif
#endif
#endif
+64 -5
View File
@@ -88,6 +88,15 @@ void JavaListToNativeVector(JNIEnv *env, jobject java_list,
}
}
/// Convert a Java byte array to a C++ UniqueID.
template <typename ID>
inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) {
std::string id_str(ID::Size(), 0);
env->GetByteArrayRegion(bytes, 0, ID::Size(),
reinterpret_cast<jbyte *>(&id_str.front()));
return ID::FromBinary(id_str);
}
/// Convert a Java String to C++ std::string.
std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
const char *c_str = env->GetStringUTFChars(jstr, nullptr);
@@ -105,10 +114,9 @@ void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list,
});
}
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env,
jobject functionDescriptor) {
jclass java_language_class =
LoadClass(env, "io/ray/runtime/generated/Common$Language");
std::shared_ptr<ray::RayFunction> FunctionDescriptorToRayFunction(
JNIEnv *env, jobject functionDescriptor) {
jclass java_language_class = LoadClass(env, "io/ray/runtime/generated/Common$Language");
jclass java_function_descriptor_class =
LoadClass(env, "io/ray/runtime/functionmanager/FunctionDescriptor");
jmethodID java_language_get_number =
@@ -129,5 +137,56 @@ ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env,
ray::FunctionDescriptor function_descriptor =
ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list);
ray::RayFunction ray_function{language, function_descriptor};
return ray_function;
return std::make_shared<ray::RayFunction>(ray_function);
}
void ParseChannelInitParameters(
JNIEnv *env, jobject param_obj,
std::vector<ray::streaming::ChannelCreationParameter> &parameter_vec) {
jclass java_streaming_queue_initial_parameters_class =
LoadClass(env,
"io/ray/streaming/runtime/transfer/"
"ChannelCreationParametersBuilder");
jmethodID java_streaming_queue_initial_parameters_getParameters_method =
env->GetMethodID(java_streaming_queue_initial_parameters_class, "getParameters",
"()Ljava/util/List;");
STREAMING_CHECK(java_streaming_queue_initial_parameters_getParameters_method !=
nullptr);
jclass java_streaming_queue_initial_parameters_parameter_class =
LoadClass(env,
"io/ray/streaming/runtime/transfer/"
"ChannelCreationParametersBuilder$Parameter");
jmethodID java_getActorIdBytes_method = env->GetMethodID(
java_streaming_queue_initial_parameters_parameter_class, "getActorIdBytes", "()[B");
jmethodID java_getAsyncFunctionDescriptor_method =
env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class,
"getAsyncFunctionDescriptor",
"()Lio/ray/runtime/functionmanager/FunctionDescriptor;");
jmethodID java_getSyncFunctionDescriptor_method =
env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class,
"getSyncFunctionDescriptor",
"()Lio/ray/runtime/functionmanager/FunctionDescriptor;");
// Call getParameters method
jobject parameter_list = env->CallObjectMethod(
param_obj, java_streaming_queue_initial_parameters_getParameters_method);
JavaListToNativeVector<ray::streaming::ChannelCreationParameter>(
env, parameter_list, &parameter_vec,
[java_getActorIdBytes_method, java_getAsyncFunctionDescriptor_method,
java_getSyncFunctionDescriptor_method](JNIEnv *env, jobject jobject_parameter) {
ray::streaming::ChannelCreationParameter native_parameter;
jbyteArray jobject_actor_id_bytes = (jbyteArray)env->CallObjectMethod(
jobject_parameter, java_getActorIdBytes_method);
native_parameter.actor_id =
JavaByteArrayToId<ray::ActorID>(env, jobject_actor_id_bytes);
jobject jobject_async_func = env->CallObjectMethod(
jobject_parameter, java_getAsyncFunctionDescriptor_method);
native_parameter.async_function =
FunctionDescriptorToRayFunction(env, jobject_async_func);
jobject jobject_sync_func = env->CallObjectMethod(
jobject_parameter, java_getSyncFunctionDescriptor_method);
native_parameter.sync_function =
FunctionDescriptorToRayFunction(env, jobject_sync_func);
return native_parameter;
});
}
+13 -17
View File
@@ -3,6 +3,7 @@
#include <jni.h>
#include <string>
#include "channel.h"
#include "ray/core_worker/common.h"
#include "util/streaming_logging.h"
@@ -21,12 +22,10 @@ class UniqueIdFromJByteArray {
b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
PID = ray::ObjectID::FromBinary(
std::string(reinterpret_cast<const char*>(b), ray::ObjectID::Size()));
std::string(reinterpret_cast<const char *>(b), ray::ObjectID::Size()));
}
~UniqueIdFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, b, 0);
}
~UniqueIdFromJByteArray() { _env->ReleaseByteArrayElements(_bytes, b, 0); }
};
class RawDataFromJByteArray {
@@ -42,15 +41,13 @@ class RawDataFromJByteArray {
_env = env;
_bytes = bytes;
data_size = _env->GetArrayLength(_bytes);
jbyte *b =
reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
jbyte *b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
data = reinterpret_cast<uint8_t *>(b);
}
~RawDataFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, reinterpret_cast<jbyte *>(data), 0);
}
};
class StringFromJString {
@@ -69,10 +66,7 @@ class StringFromJString {
str = std::string(j_str);
}
~StringFromJString() {
_env->ReleaseStringUTFChars(jni_str, j_str);
}
~StringFromJString() { _env->ReleaseStringUTFChars(jni_str, j_str); }
};
class LongVectorFromJLongArray {
@@ -98,14 +92,16 @@ class LongVectorFromJLongArray {
}
};
std::vector<ray::ObjectID>
jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ActorID>
jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ObjectID> jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr);
std::vector<ray::ActorID> jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr);
jint throwRuntimeException(JNIEnv *env, const char *message);
jint throwChannelInitException(JNIEnv *env, const char *message,
const std::vector<ray::ObjectID> &abnormal_queues);
jint throwChannelInterruptException(JNIEnv *env, const char *message);
ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor);
#endif //RAY_STREAMING_JNI_COMMON_H
std::shared_ptr<ray::RayFunction> FunctionDescriptorToRayFunction(
JNIEnv *env, jobject functionDescriptor);
void ParseChannelInitParameters(
JNIEnv *env, jobject param_obj,
std::vector<ray::streaming::ChannelCreationParameter> &parameter_vec);
#endif // RAY_STREAMING_JNI_COMMON_H