diff --git a/BUILD.bazel b/BUILD.bazel
index 95fd33a52..cc3da5139 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -647,13 +647,13 @@ pyx_library(
)
cc_binary(
- name = "libraylet_library_java.so",
- srcs = [
- "src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h",
- "src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc",
- "src/ray/common/id.h",
- "src/ray/raylet/raylet_client.h",
- "src/ray/util/logging.h",
+ name = "libcore_worker_library_java.so",
+ srcs = glob([
+ "src/ray/core_worker/lib/java/*.h",
+ "src/ray/core_worker/lib/java/*.cc",
+ "src/ray/raylet/lib/java/*.h",
+ "src/ray/raylet/lib/java/*.cc",
+ ]) + [
"@bazel_tools//tools/jdk:jni_header",
] + select({
"@bazel_tools//src/conditions:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"],
@@ -671,24 +671,23 @@ cc_binary(
linkshared = 1,
linkstatic = 1,
deps = [
- "//:raylet_lib",
- "@plasma//:plasma_client",
+ "//:core_worker_lib",
],
)
genrule(
- name = "raylet-jni-darwin-compat",
- srcs = [":libraylet_library_java.so"],
- outs = ["libraylet_library_java.dylib"],
+ name = "core_worker-jni-darwin-compat",
+ srcs = [":libcore_worker_library_java.so"],
+ outs = ["libcore_worker_library_java.dylib"],
cmd = "cp $< $@",
output_to_bindir = 1,
)
filegroup(
- name = "raylet_library_java",
+ name = "core_worker_library_java",
srcs = select({
- "@bazel_tools//src/conditions:darwin": [":libraylet_library_java.dylib"],
- "//conditions:default": [":libraylet_library_java.so"],
+ "@bazel_tools//src/conditions:darwin": [":libcore_worker_library_java.dylib"],
+ "//conditions:default": [":libcore_worker_library_java.so"],
}),
visibility = ["//java:__subpackages__"],
)
diff --git a/bazel/BUILD.plasma b/bazel/BUILD.plasma
index ff0fe3e14..5a8b57afc 100644
--- a/bazel/BUILD.plasma
+++ b/bazel/BUILD.plasma
@@ -2,27 +2,6 @@ load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
COPTS = ["-DARROW_USE_GLOG"]
-java_library(
- name = "org_apache_arrow_arrow_plasma",
- srcs = glob(["java/plasma/src/main/java/**/*.java"]),
- data = [":plasma_client_java"],
- visibility = ["//visibility:public"],
- deps = [
- "@maven//:org_slf4j_slf4j_api",
- ],
-)
-
-java_binary(
- name = "org_apache_arrow_arrow_plasma_test",
- srcs = ["java/plasma/src/test/java/org/apache/arrow/plasma/PlasmaClientTest.java"],
- main_class = "org.apache.arrow.plasma.PlasmaClientTest",
- visibility = ["//visibility:public"],
- deps = [
- ":org_apache_arrow_arrow_plasma",
- "@maven//:junit_junit",
- ],
-)
-
cc_library(
name = "arrow",
srcs = [
@@ -145,15 +124,6 @@ genrule(
output_to_bindir = 1,
)
-filegroup(
- name = "plasma_client_java",
- srcs = select({
- "@bazel_tools//src/conditions:darwin": [":libplasma_java.dylib"],
- "//conditions:default": [":libplasma_java.so"],
- }),
- visibility = ["//visibility:public"],
-)
-
cc_library(
name = "plasma_lib",
srcs = [
diff --git a/java/BUILD.bazel b/java/BUILD.bazel
index b9c43424f..37ef5b93b 100644
--- a/java/BUILD.bazel
+++ b/java/BUILD.bazel
@@ -69,7 +69,6 @@ define_java_module(
],
deps = [
":org_ray_ray_api",
- "@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -97,7 +96,6 @@ define_java_module(
deps = [
":org_ray_ray_api",
":org_ray_ray_runtime",
- "@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_sun_xml_bind_jaxb_core",
"@maven//:com_sun_xml_bind_jaxb_impl",
@@ -176,9 +174,8 @@ filegroup(
"//:redis-server",
"//:libray_redis_module.so",
"//:raylet",
- "//:raylet_library_java",
+ "//:core_worker_library_java",
"@plasma//:plasma_store_server",
- "@plasma//:plasma_client_java",
],
)
@@ -189,7 +186,6 @@ genrule(
":all_java_proto",
":java_native_deps",
":copy_pom_file",
- "@plasma//:org_apache_arrow_arrow_plasma",
],
outs = ["gen_maven_deps.out"],
cmd = """
@@ -208,9 +204,6 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
- # Install plasma jar to local maven repo.
- mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
- -DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
echo $$(date) > $@
""",
local = 1,
diff --git a/java/pom.xml b/java/pom.xml
index bf7a41229..912b803de 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -24,11 +24,6 @@
-
- org.apache.arrow
- arrow-plasma
- 0.13.0-SNAPSHOT
-
diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml
index aba612b36..3c40f7ffc 100644
--- a/java/runtime/pom.xml
+++ b/java/runtime/pom.xml
@@ -22,10 +22,6 @@
ray-api
${project.version}
-
- org.apache.arrow
- arrow-plasma
-
com.beust
jcommander
diff --git a/java/runtime/pom_template.xml b/java/runtime/pom_template.xml
index 9200bd6c6..10a36bfce 100644
--- a/java/runtime/pom_template.xml
+++ b/java/runtime/pom_template.xml
@@ -22,10 +22,6 @@
ray-api
${project.version}
-
- org.apache.arrow
- arrow-plasma
-
{generated_bzl_deps}
diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
index 2d51f113a..831de1acf 100644
--- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
@@ -1,7 +1,15 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
+import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.reflect.Field;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -72,6 +80,27 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected RuntimeContext runtimeContext;
protected GcsClient gcsClient;
+ static {
+ try {
+ LOGGER.debug("Loading native libraries.");
+ // Load native libraries.
+ String[] libraries = new String[]{"core_worker_library_java"};
+ for (String library : libraries) {
+ String fileName = System.mapLibraryName(library);
+ // Copy the file from resources to a temp dir, and load the native library.
+ File file = File.createTempFile(fileName, "");
+ file.deleteOnExit();
+ InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
+ Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
+ Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
+ System.load(file.getAbsolutePath());
+ }
+ LOGGER.debug("Native libraries loaded.");
+ } catch (IOException e) {
+ throw new RuntimeException("Couldn't load native libraries.", e);
+ }
+ }
+
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.jobResourcePath);
@@ -79,6 +108,33 @@ public abstract class AbstractRayRuntime implements RayRuntime {
runtimeContext = new RuntimeContextImpl(this);
}
+ protected void resetLibraryPath() {
+ if (rayConfig.libraryPath.isEmpty()) {
+ return;
+ }
+
+ String path = System.getProperty("java.library.path");
+ if (Strings.isNullOrEmpty(path)) {
+ path = "";
+ } else {
+ path += ":";
+ }
+ path += String.join(":", rayConfig.libraryPath);
+
+ // This is a hack to reset library path at runtime,
+ // see https://stackoverflow.com/questions/15409223/.
+ System.setProperty("java.library.path", path);
+ // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
+ final Field sysPathsField;
+ try {
+ sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
+ sysPathsField.setAccessible(true);
+ sysPathsField.set(null, null);
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ LOGGER.error("Failed to set library path.", e);
+ }
+ }
+
/**
* Start runtime.
*/
@@ -330,8 +386,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
* Create the task specification.
*
* @param func The target remote function.
- * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
- * Python task.
+ * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python
+ * task.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.
diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
index a53d59bc8..a491d89e5 100644
--- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
@@ -3,7 +3,7 @@ package org.ray.runtime;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
-import org.ray.runtime.objectstore.MockObjectStore;
+import org.ray.runtime.objectstore.MockObjectInterface;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.MockRayletClient;
@@ -13,19 +13,22 @@ public class RayDevRuntime extends AbstractRayRuntime {
super(rayConfig);
}
- private MockObjectStore store;
+ private MockObjectInterface objectInterface;
private AtomicInteger jobCounter = new AtomicInteger(0);
@Override
public void start() {
- store = new MockObjectStore(this);
+ // Reset library path at runtime.
+ resetLibraryPath();
+
+ objectInterface = new MockObjectInterface(workerContext);
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
- objectStoreProxy = new ObjectStoreProxy(this, null);
+ objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface);
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
}
@@ -34,8 +37,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
rayletClient.destroy();
}
- public MockObjectStore getObjectStore() {
- return store;
+ public MockObjectInterface getObjectInterface() {
+ return objectInterface;
}
@Override
diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
index 8d98b18f4..cf804ee02 100644
--- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
@@ -1,21 +1,13 @@
package org.ray.runtime;
-import com.google.common.base.Preconditions;
-import com.google.common.base.Strings;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.lang.reflect.Field;
-import java.nio.file.Files;
-import java.nio.file.Paths;
-import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
-import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.gcs.RedisClient;
+import org.ray.runtime.generated.Common.WorkerType;
+import org.ray.runtime.objectstore.ObjectInterfaceImpl;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
@@ -31,58 +23,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private RunManager manager = null;
- static {
- try {
- LOGGER.debug("Loading native libraries.");
- // Load native libraries.
- String[] libraries = new String[]{"raylet_library_java", "plasma_java"};
- for (String library : libraries) {
- String fileName = System.mapLibraryName(library);
- // Copy the file from resources to a temp dir, and load the native library.
- File file = File.createTempFile(fileName, "");
- file.deleteOnExit();
- InputStream in = RayNativeRuntime.class.getResourceAsStream("/" + fileName);
- Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
- Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
- System.load(file.getAbsolutePath());
- }
- LOGGER.debug("Native libraries loaded.");
- } catch (IOException e) {
- throw new RuntimeException("Couldn't load native libraries.", e);
- }
- }
+ private ObjectInterfaceImpl objectInterfaceImpl = null;
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
- private void resetLibraryPath() {
- if (rayConfig.libraryPath.isEmpty()) {
- return;
- }
-
- String path = System.getProperty("java.library.path");
- if (Strings.isNullOrEmpty(path)) {
- path = "";
- } else {
- path += ":";
- }
- path += String.join(":", rayConfig.libraryPath);
-
- // This is a hack to reset library path at runtime,
- // see https://stackoverflow.com/questions/15409223/.
- System.setProperty("java.library.path", path);
- // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
- final Field sysPathsField;
- try {
- sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
- sysPathsField.setAccessible(true);
- sysPathsField.set(null, null);
- } catch (NoSuchFieldException | IllegalAccessException e) {
- LOGGER.error("Failed to set library path.", e);
- }
- }
-
@Override
public void start() {
// Reset library path at runtime.
@@ -101,16 +47,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
- // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
- objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
-
rayletClient = new RayletClientImpl(
rayConfig.rayletSocketName,
workerContext.getCurrentWorkerId(),
- rayConfig.workerMode == WorkerMode.WORKER,
+ rayConfig.workerMode == WorkerType.WORKER,
workerContext.getCurrentJobId()
);
+ // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
+ objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient,
+ rayConfig.objectStoreSocketName);
+ objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl);
+
// register
registerWorker();
@@ -123,6 +71,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (null != manager) {
manager.cleanup();
}
+ objectInterfaceImpl.destroy();
+ workerContext.destroy();
}
/**
@@ -132,7 +82,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
Map workerInfo = new HashMap<>();
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
- if (rayConfig.workerMode == WorkerMode.DRIVER) {
+ if (rayConfig.workerMode == WorkerType.DRIVER) {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
index 828d39cb5..4153e732a 100644
--- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
+++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
@@ -1,37 +1,24 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
+import java.nio.ByteBuffer;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
-import org.ray.runtime.config.WorkerMode;
+import org.ray.runtime.generated.Common.WorkerType;
+import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.task.TaskSpec;
-import org.ray.runtime.util.IdUtil;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+/**
+ * This is a wrapper class for worker context of core worker.
+ */
public class WorkerContext {
- private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
-
- private UniqueId workerId;
-
- private ThreadLocal currentTaskId;
-
/**
- * Number of objects that have been put from current task.
+ * The native pointer of worker context of core worker.
*/
- private ThreadLocal putIndex;
-
- /**
- * Number of tasks that have been submitted from current task.
- */
- private ThreadLocal taskIndex;
-
- private ThreadLocal currentTask;
-
- private JobId currentJobId;
+ private final long nativeWorkerContextPointer;
private ClassLoader currentClassLoader;
@@ -45,31 +32,23 @@ public class WorkerContext {
*/
private RunMode runMode;
- public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) {
+ public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) {
+ this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes());
mainThreadId = Thread.currentThread().getId();
- taskIndex = ThreadLocal.withInitial(() -> 0);
- putIndex = ThreadLocal.withInitial(() -> 0);
- currentTaskId = ThreadLocal.withInitial(TaskId::randomId);
this.runMode = runMode;
- currentTask = ThreadLocal.withInitial(() -> null);
currentClassLoader = null;
- if (workerMode == WorkerMode.DRIVER) {
- workerId = IdUtil.computeDriverId(jobId);
- currentTaskId.set(TaskId.randomId());
- currentJobId = jobId;
- } else {
- workerId = UniqueId.randomId();
- this.currentTaskId.set(TaskId.NIL);
- this.currentJobId = JobId.NIL;
- }
+ }
+
+ public long getNativeWorkerContext() {
+ return nativeWorkerContextPointer;
}
/**
* @return For the main thread, this method returns the ID of this worker's current running task;
- * for other threads, this method returns a random ID.
+ * for other threads, this method returns a random ID.
*/
public TaskId getCurrentTaskId() {
- return currentTaskId.get();
+ return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer));
}
/**
@@ -79,17 +58,14 @@ public class WorkerContext {
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
if (runMode == RunMode.CLUSTER) {
Preconditions.checkState(
- Thread.currentThread().getId() == mainThreadId,
- "This method should only be called from the main thread."
+ Thread.currentThread().getId() == mainThreadId,
+ "This method should only be called from the main thread."
);
}
Preconditions.checkNotNull(task);
- this.currentTaskId.set(task.taskId);
- this.currentJobId = task.jobId;
- taskIndex.set(0);
- putIndex.set(0);
- this.currentTask.set(task);
+ byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task);
+ nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec);
currentClassLoader = classLoader;
}
@@ -97,30 +73,28 @@ public class WorkerContext {
* Increment the put index and return the new value.
*/
public int nextPutIndex() {
- putIndex.set(putIndex.get() + 1);
- return putIndex.get();
+ return nativeGetNextPutIndex(nativeWorkerContextPointer);
}
/**
* Increment the task index and return the new value.
*/
public int nextTaskIndex() {
- taskIndex.set(taskIndex.get() + 1);
- return taskIndex.get();
+ return nativeGetNextTaskIndex(nativeWorkerContextPointer);
}
/**
* @return The ID of the current worker.
*/
public UniqueId getCurrentWorkerId() {
- return workerId;
+ return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer));
}
/**
* The ID of the current job.
*/
public JobId getCurrentJobId() {
- return currentJobId;
+ return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer));
}
/**
@@ -134,6 +108,32 @@ public class WorkerContext {
* Get the current task.
*/
public TaskSpec getCurrentTask() {
- return this.currentTask.get();
+ byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer);
+ if (bytes == null) {
+ return null;
+ }
+ return RayletClientImpl.parseTaskSpecFromProtobuf(bytes);
}
+
+ public void destroy() {
+ nativeDestroy(nativeWorkerContextPointer);
+ }
+
+ private static native long nativeCreateWorkerContext(int workerType, byte[] jobId);
+
+ private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer);
+
+ private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec);
+
+ private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer);
+
+ private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer);
+
+ private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer);
+
+ private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer);
+
+ private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer);
+
+ private static native void nativeDestroy(long nativeWorkerContextPointer);
}
diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
index e67c88d59..1e90d68f4 100644
--- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
+++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
@@ -11,6 +11,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ray.api.id.JobId;
+import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.util.NetworkUtil;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.StringUtil;
@@ -29,7 +30,7 @@ public class RayConfig {
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
public final String nodeIp;
- public final WorkerMode workerMode;
+ public final WorkerType workerMode;
public final RunMode runMode;
public final Map resources;
private JobId jobId;
@@ -62,7 +63,7 @@ public class RayConfig {
public final int numberExecThreadsForDevRuntime;
private void validate() {
- if (workerMode == WorkerMode.WORKER) {
+ if (workerMode == WorkerType.WORKER) {
Preconditions.checkArgument(redisAddress != null,
"Redis address must be set in worker mode.");
}
@@ -78,14 +79,14 @@ public class RayConfig {
public RayConfig(Config config) {
// Worker mode.
- WorkerMode localWorkerMode;
+ WorkerType localWorkerMode;
try {
- localWorkerMode = config.getEnum(WorkerMode.class, "ray.worker.mode");
+ localWorkerMode = config.getEnum(WorkerType.class, "ray.worker.mode");
} catch (ConfigException.Missing e) {
- localWorkerMode = WorkerMode.DRIVER;
+ localWorkerMode = WorkerType.DRIVER;
}
workerMode = localWorkerMode;
- boolean isDriver = workerMode == WorkerMode.DRIVER;
+ boolean isDriver = workerMode == WorkerType.DRIVER;
// Run mode.
runMode = config.getEnum(RunMode.class, "ray.run-mode");
// Node ip.
diff --git a/java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java b/java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java
deleted file mode 100644
index 947159c3b..000000000
--- a/java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java
+++ /dev/null
@@ -1,6 +0,0 @@
-package org.ray.runtime.config;
-
-public enum WorkerMode {
- DRIVER,
- WORKER
-}
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java
new file mode 100644
index 000000000..8ec855bca
--- /dev/null
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java
@@ -0,0 +1,98 @@
+package org.ray.runtime.objectstore;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import org.ray.api.id.ObjectId;
+import org.ray.runtime.WorkerContext;
+import org.ray.runtime.util.IdUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MockObjectInterface implements ObjectInterface {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class);
+
+ private static final int GET_CHECK_INTERVAL_MS = 100;
+
+ private final Map pool = new ConcurrentHashMap<>();
+ private final List> objectPutCallbacks = new ArrayList<>();
+ private final WorkerContext workerContext;
+
+ public MockObjectInterface(WorkerContext workerContext) {
+ this.workerContext = workerContext;
+ }
+
+ public void addObjectPutCallback(Consumer callback) {
+ this.objectPutCallbacks.add(callback);
+ }
+
+ public boolean isObjectReady(ObjectId id) {
+ return pool.containsKey(id);
+ }
+
+ @Override
+ public ObjectId put(NativeRayObject obj) {
+ ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(),
+ workerContext.nextPutIndex());
+ put(obj, objectId);
+ return objectId;
+ }
+
+ @Override
+ public void put(NativeRayObject obj, ObjectId objectId) {
+ Preconditions.checkNotNull(obj);
+ Preconditions.checkNotNull(objectId);
+ pool.putIfAbsent(objectId, obj);
+ for (Consumer callback : objectPutCallbacks) {
+ callback.accept(objectId);
+ }
+ }
+
+ @Override
+ public List get(List objectIds, long timeoutMs) {
+ waitInternal(objectIds, objectIds.size(), timeoutMs);
+ return objectIds.stream().map(pool::get).collect(Collectors.toList());
+ }
+
+ @Override
+ public List wait(List objectIds, int numObjects, long timeoutMs) {
+ waitInternal(objectIds, numObjects, timeoutMs);
+ return objectIds.stream().map(pool::containsKey).collect(Collectors.toList());
+ }
+
+ private void waitInternal(List objectIds, int numObjects, long timeoutMs) {
+ int ready = 0;
+ long remainingTime = timeoutMs;
+ boolean firstCheck = true;
+ while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) {
+ if (!firstCheck) {
+ long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
+ try {
+ Thread.sleep(sleepTime);
+ } catch (InterruptedException e) {
+ LOGGER.warn("Got InterruptedException while sleeping.");
+ }
+ remainingTime -= sleepTime;
+ }
+ ready = 0;
+ for (ObjectId objectId : objectIds) {
+ if (pool.containsKey(objectId)) {
+ ready += 1;
+ }
+ }
+ firstCheck = false;
+ }
+ }
+
+ @Override
+ public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) {
+ for (ObjectId objectId : objectIds) {
+ pool.remove(objectId);
+ }
+ }
+}
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java
deleted file mode 100644
index f3d64c834..000000000
--- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java
+++ /dev/null
@@ -1,148 +0,0 @@
-package org.ray.runtime.objectstore;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Consumer;
-import java.util.stream.Collectors;
-
-import org.apache.arrow.plasma.ObjectStoreLink;
-import org.ray.api.id.ObjectId;
-import org.ray.runtime.RayDevRuntime;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data.
- */
-public class MockObjectStore implements ObjectStoreLink {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class);
-
- private static final int GET_CHECK_INTERVAL_MS = 100;
-
- private final RayDevRuntime runtime;
- private final Map data = new ConcurrentHashMap<>();
- private final Map metadata = new ConcurrentHashMap<>();
- private final List> objectPutCallbacks;
-
- public MockObjectStore(RayDevRuntime runtime) {
- this.runtime = runtime;
- this.objectPutCallbacks = new ArrayList<>();
- }
-
- public void addObjectPutCallback(Consumer callback) {
- this.objectPutCallbacks.add(callback);
- }
-
- @Override
- public void put(byte[] objectId, byte[] value, byte[] metadataValue) {
- if (objectId == null || objectId.length == 0 || value == null) {
- LOGGER
- .error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value));
- System.exit(-1);
- }
- ObjectId id = new ObjectId(objectId);
- data.put(id, value);
- if (metadataValue != null) {
- metadata.put(id, metadataValue);
- }
- for (Consumer callback : objectPutCallbacks) {
- callback.accept(id);
- }
- }
-
- @Override
- public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) {
- return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0);
- }
-
- @Override
- public List get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
- return get(objectIds, timeoutMs)
- .stream()
- .map(data -> isMetadata ? data.metadata : data.data)
- .collect(Collectors.toList());
- }
-
- @Override
- public List get(byte[][] objectIds, int timeoutMs) {
- int ready = 0;
- int remainingTime = timeoutMs;
- boolean firstCheck = true;
- while (ready < objectIds.length && remainingTime > 0) {
- if (!firstCheck) {
- int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
- try {
- Thread.sleep(sleepTime);
- } catch (InterruptedException e) {
- LOGGER.warn("Got InterruptedException while sleeping.");
- }
- remainingTime -= sleepTime;
- }
- ready = 0;
- for (byte[] id : objectIds) {
- if (data.containsKey(new ObjectId(id))) {
- ready += 1;
- }
- }
- firstCheck = false;
- }
- ArrayList rets = new ArrayList<>();
- for (byte[] objId : objectIds) {
- ObjectId objectId = new ObjectId(objId);
- rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId)));
- }
- return rets;
- }
-
- @Override
- public byte[] hash(byte[] objectId) {
- return null;
- }
-
- @Override
- public long evict(long numBytes) {
- return 0;
- }
-
- @Override
- public void release(byte[] objectId) {
- return;
- }
-
- @Override
- public void delete(byte[] objectId) {
- return;
- }
-
- @Override
- public boolean contains(byte[] objectId) {
- return data.containsKey(new ObjectId(objectId));
- }
-
- private String logPrefix() {
- return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
- }
-
- private String getUserTrace() {
- StackTraceElement[] stes = Thread.currentThread().getStackTrace();
- int k = 1;
- while (stes[k].getClassName().startsWith("org.ray")
- && !stes[k].getClassName().contains("test")) {
- k++;
- }
- return stes[k].getFileName() + ":" + stes[k].getLineNumber();
- }
-
- public boolean isObjectReady(ObjectId id) {
- return data.containsKey(id);
- }
-
- public void free(ObjectId id) {
- data.remove(id);
- metadata.remove(id);
- }
-}
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java
new file mode 100644
index 000000000..7146765c2
--- /dev/null
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java
@@ -0,0 +1,13 @@
+package org.ray.runtime.objectstore;
+
+public class NativeRayObject {
+
+ public byte[] data;
+ public byte[] metadata;
+
+ public NativeRayObject(byte[] data, byte[] metadata) {
+ this.data = data;
+ this.metadata = metadata;
+ }
+}
+
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java
new file mode 100644
index 000000000..5780dbd6c
--- /dev/null
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java
@@ -0,0 +1,54 @@
+package org.ray.runtime.objectstore;
+
+import java.util.List;
+import org.ray.api.id.ObjectId;
+
+/**
+ * The interface that contains all worker methods that are related to object store.
+ */
+public interface ObjectInterface {
+
+ /**
+ * Put an object into object store.
+ *
+ * @param obj The ray object.
+ * @return Generated ID of the object.
+ */
+ ObjectId put(NativeRayObject obj);
+
+ /**
+ * Put an object with specified ID into object store.
+ *
+ * @param obj The ray object.
+ * @param objectId Object ID specified by user.
+ */
+ void put(NativeRayObject obj, ObjectId objectId);
+
+ /**
+ * Get a list of objects from the object store.
+ *
+ * @param objectIds IDs of the objects to get.
+ * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
+ * @return Result list of objects data.
+ */
+ List get(List objectIds, long timeoutMs);
+
+ /**
+ * Wait for a list of objects to appear in the object store.
+ *
+ * @param objectIds IDs of the objects to wait for.
+ * @param numObjects Number of objects that should appear.
+ * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
+ * @return A bitset that indicates each object has appeared or not.
+ */
+ List wait(List objectIds, int numObjects, long timeoutMs);
+
+ /**
+ * Delete a list of objects from the object store.
+ *
+ * @param objectIds IDs of the objects to delete.
+ * @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
+ * @param deleteCreatingTasks Whether also delete the tasks that created these objects.
+ */
+ void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks);
+}
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java
new file mode 100644
index 000000000..5e1774808
--- /dev/null
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java
@@ -0,0 +1,91 @@
+package org.ray.runtime.objectstore;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import org.ray.api.exception.RayException;
+import org.ray.api.id.BaseId;
+import org.ray.api.id.ObjectId;
+import org.ray.runtime.AbstractRayRuntime;
+import org.ray.runtime.WorkerContext;
+import org.ray.runtime.raylet.RayletClient;
+import org.ray.runtime.raylet.RayletClientImpl;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This is a wrapper class for core worker object interface.
+ */
+public class ObjectInterfaceImpl implements ObjectInterface {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
+
+ /**
+ * The native pointer of core worker object interface.
+ */
+ private final long nativeObjectInterfacePointer;
+
+ public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient,
+ String storeSocketName) {
+ this.nativeObjectInterfacePointer =
+ nativeCreateObjectInterface(workerContext.getNativeWorkerContext(),
+ ((RayletClientImpl) rayletClient).getClient(), storeSocketName);
+ }
+
+ @Override
+ public ObjectId put(NativeRayObject obj) {
+ return new ObjectId(nativePut(nativeObjectInterfacePointer, obj));
+ }
+
+ @Override
+ public void put(NativeRayObject obj, ObjectId objectId) {
+ try {
+ nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj);
+ } catch (RayException e) {
+ LOGGER.warn(e.getMessage());
+ }
+ }
+
+ @Override
+ public List get(List objectIds, long timeoutMs) {
+ return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs);
+ }
+
+ @Override
+ public List wait(List objectIds, int numObjects, long timeoutMs) {
+ return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs);
+ }
+
+ @Override
+ public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) {
+ nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
+ }
+
+ public void destroy() {
+ nativeDestroy(nativeObjectInterfacePointer);
+ }
+
+ private static List toBinaryList(List ids) {
+ return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
+ }
+
+ private static native long nativeCreateObjectInterface(long nativeObjectInterface,
+ long nativeRayletClient,
+ String storeSocketName);
+
+ private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj);
+
+ private static native void nativePut(long nativeObjectInterface, byte[] objectId,
+ NativeRayObject obj);
+
+ private static native List nativeGet(long nativeObjectInterface,
+ List ids,
+ long timeoutMs);
+
+ private static native List nativeWait(long nativeObjectInterface, List objectIds,
+ int numObjects, long timeoutMs);
+
+ private static native void nativeDelete(long nativeObjectInterface, List objectIds,
+ boolean localOnly, boolean deleteCreatingTasks);
+
+ private static native void nativeDestroy(long nativeObjectInterface);
+}
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
index 1a7e4701c..5470d719b 100644
--- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
@@ -4,20 +4,14 @@ import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import org.apache.arrow.plasma.ObjectStoreLink;
-import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
-import org.apache.arrow.plasma.PlasmaClient;
-import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
+import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
-import org.ray.runtime.AbstractRayRuntime;
-import org.ray.runtime.RayDevRuntime;
-import org.ray.runtime.config.RunMode;
+import org.ray.runtime.WorkerContext;
import org.ray.runtime.generated.Gcs.ErrorType;
-import org.ray.runtime.util.IdUtil;
import org.ray.runtime.util.Serializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,21 +30,18 @@ public class ObjectStoreProxy {
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
+ private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
+ .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
+
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
- private final AbstractRayRuntime runtime;
+ private final WorkerContext workerContext;
- private static ThreadLocal objectStore;
+ private final ObjectInterface objectInterface;
- public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
- this.runtime = runtime;
- objectStore = ThreadLocal.withInitial(() -> {
- if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
- return new PlasmaClient(storeSocketName, "", 0);
- } else {
- return ((RayDevRuntime) runtime).getObjectStore();
- }
- });
+ public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) {
+ this.workerContext = workerContext;
+ this.objectInterface = objectInterface;
}
/**
@@ -75,46 +66,44 @@ public class ObjectStoreProxy {
* @return A list of GetResult objects.
*/
public List> get(List ids, int timeoutMs) {
- byte[][] binaryIds = IdUtil.getIdBytes(ids);
- List dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
+ List dataAndMetaList = objectInterface.get(ids, timeoutMs);
List> results = new ArrayList<>();
for (int i = 0; i < dataAndMetaList.size(); i++) {
- byte[] meta = dataAndMetaList.get(i).metadata;
- byte[] data = dataAndMetaList.get(i).data;
-
+ NativeRayObject dataAndMeta = dataAndMetaList.get(i);
GetResult result;
- if (meta != null) {
- // If meta is not null, deserialize the object from meta.
- result = deserializeFromMeta(meta, data, ids.get(i));
- } else if (data != null) {
- // If data is not null, deserialize the Java object.
- Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
- if (object instanceof RayException) {
- // If the object is a `RayException`, it means that an error occurred during task
- // execution.
- result = new GetResult<>(true, null, (RayException) object);
+ if (dataAndMeta != null) {
+ byte[] meta = dataAndMeta.metadata;
+ byte[] data = dataAndMeta.data;
+ if (meta != null && meta.length > 0) {
+ // If meta is not null, deserialize the object from meta.
+ result = deserializeFromMeta(meta, data,
+ workerContext.getCurrentClassLoader(), ids.get(i));
} else {
- // Otherwise, the object is valid.
- result = new GetResult<>(true, (T) object, null);
+ // If data is not null, deserialize the Java object.
+ Object object = Serializer.decode(data, workerContext.getCurrentClassLoader());
+ if (object instanceof RayException) {
+ // If the object is a `RayException`, it means that an error occurred during task
+ // execution.
+ result = new GetResult<>(true, null, (RayException) object);
+ } else {
+ // Otherwise, the object is valid.
+ result = new GetResult<>(true, (T) object, null);
+ }
}
} else {
// If both meta and data are null, the object doesn't exist in object store.
result = new GetResult<>(false, null, null);
}
- if (meta != null || data != null) {
- // Release the object from object store..
- objectStore.get().release(binaryIds[i]);
- }
-
results.add(result);
}
return results;
}
@SuppressWarnings("unchecked")
- private GetResult deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) {
+ private GetResult deserializeFromMeta(byte[] meta, byte[] data,
+ ClassLoader classLoader, ObjectId objectId) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
return (GetResult) new GetResult<>(true, data, null);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
@@ -123,6 +112,8 @@ public class ObjectStoreProxy {
return new GetResult<>(true, null, RayActorException.INSTANCE);
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new GetResult<>(true, null, new UnreconstructableException(objectId));
+ } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
+ return new GetResult<>(true, null, Serializer.decode(data, classLoader));
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}
@@ -134,16 +125,14 @@ public class ObjectStoreProxy {
* @param object The object to put.
*/
public void put(ObjectId id, Object object) {
- try {
- if (object instanceof byte[]) {
- // If the object is a byte array, skip serializing it and use a special metadata to
- // indicate it's raw binary. So that this object can also be read by Python.
- objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
- } else {
- objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
- }
- } catch (DuplicateObjectException e) {
- LOGGER.warn(e.getMessage());
+ if (object instanceof byte[]) {
+ // If the object is a byte array, skip serializing it and use a special metadata to
+ // indicate it's raw binary. So that this object can also be read by Python.
+ objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id);
+ } else if (object instanceof RayTaskException) {
+ objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id);
+ } else {
+ objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id);
}
}
@@ -154,11 +143,7 @@ public class ObjectStoreProxy {
* @param serializedObject The serialized object to put.
*/
public void putSerialized(ObjectId id, byte[] serializedObject) {
- try {
- objectStore.get().put(id.getBytes(), serializedObject, null);
- } catch (DuplicateObjectException e) {
- LOGGER.warn(e.getMessage());
- }
+ objectInterface.put(new NativeRayObject(serializedObject, null), id);
}
/**
diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java
index 0dc8f4c9e..38995bf9b 100644
--- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java
+++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java
@@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
@@ -23,7 +24,8 @@ import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.Worker;
-import org.ray.runtime.objectstore.MockObjectStore;
+import org.ray.runtime.objectstore.MockObjectInterface;
+import org.ray.runtime.objectstore.NativeRayObject;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
@@ -37,7 +39,7 @@ public class MockRayletClient implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
private final Map> waitingTasks = new ConcurrentHashMap<>();
- private final MockObjectStore store;
+ private final MockObjectInterface objectInterface;
private final RayDevRuntime runtime;
private final ExecutorService exec;
private final Deque idleWorkers;
@@ -46,8 +48,8 @@ public class MockRayletClient implements RayletClient {
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
this.runtime = runtime;
- this.store = runtime.getObjectStore();
- store.addObjectPutCallback(this::onObjectPut);
+ this.objectInterface = runtime.getObjectInterface();
+ objectInterface.addObjectPutCallback(this::onObjectPut);
// The thread pool that executes tasks in parallel.
exec = Executors.newFixedThreadPool(numberThreads);
idleWorkers = new ConcurrentLinkedDeque<>();
@@ -113,8 +115,8 @@ public class MockRayletClient implements RayletClient {
// can be executed.
if (task.isActorCreationTask() || task.isActorTask()) {
ObjectId[] returnIds = task.returnIds;
- store.put(returnIds[returnIds.length - 1].getBytes(),
- new byte[]{}, new byte[]{});
+ objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}),
+ returnIds[returnIds.length - 1]);
}
} finally {
returnWorker(worker);
@@ -133,13 +135,13 @@ public class MockRayletClient implements RayletClient {
// Check whether task arguments are ready.
for (FunctionArg arg : spec.args) {
if (arg.id != null) {
- if (!store.isObjectReady(arg.id)) {
+ if (!objectInterface.isObjectReady(arg.id)) {
unreadyObjects.add(arg.id);
}
}
}
if (spec.isActorTask()) {
- if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) {
+ if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) {
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
}
}
@@ -154,7 +156,7 @@ public class MockRayletClient implements RayletClient {
@Override
public void fetchOrReconstruct(List objectIds, boolean fetchOnly,
- TaskId currentTaskId) {
+ TaskId currentTaskId) {
}
@@ -170,20 +172,17 @@ public class MockRayletClient implements RayletClient {
@Override
public WaitResult wait(List> waitFor, int numReturns, int
- timeoutMs, TaskId currentTaskId) {
+ timeoutMs, TaskId currentTaskId) {
if (waitFor == null || waitFor.isEmpty()) {
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
}
- byte[][] ids = new byte[waitFor.size()][];
- for (int i = 0; i < waitFor.size(); i++) {
- ids[i] = waitFor.get(i).getId().getBytes();
- }
+ List ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList());
List> readyList = new ArrayList<>();
List> unreadyList = new ArrayList<>();
- List result = store.get(ids, timeoutMs, false);
+ List result = objectInterface.wait(ids, ids.size(), timeoutMs);
for (int i = 0; i < waitFor.size(); i++) {
- if (result.get(i) != null) {
+ if (result.get(i)) {
readyList.add(waitFor.get(i));
} else {
unreadyList.add(waitFor.get(i));
@@ -195,9 +194,7 @@ public class MockRayletClient implements RayletClient {
@Override
public void freePlasmaObjects(List objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
- for (ObjectId id : objectIds) {
- store.free(id);
- }
+ objectInterface.delete(objectIds, localOnly, deleteCreatingTasks);
}
diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
index 059edbe67..a1e11141e 100644
--- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
+++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
@@ -4,8 +4,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
@@ -40,11 +38,15 @@ public class RayletClientImpl implements RayletClient {
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
- boolean isWorker, JobId jobId) {
+ boolean isWorker, JobId jobId) {
client = nativeInit(schedulerSockName, clientId.getBytes(),
isWorker, jobId.getBytes());
}
+ public long getClient() {
+ return client;
+ }
+
@Override
public WaitResult wait(List> waitFor, int numReturns, int
timeoutMs, TaskId currentTaskId) {
@@ -133,7 +135,7 @@ public class RayletClientImpl implements RayletClient {
/**
* Parse `TaskSpec` protobuf bytes.
*/
- private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
+ public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
Common.TaskSpec taskSpec;
try {
taskSpec = Common.TaskSpec.parseFrom(bytes);
@@ -214,7 +216,7 @@ public class RayletClientImpl implements RayletClient {
/**
* Convert a `TaskSpec` to protobuf-serialized bytes.
*/
- private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
+ public static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
// Set common fields.
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java
index 8a96bc57a..6f9c95ea4 100644
--- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java
+++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java
@@ -154,18 +154,6 @@ public class IdUtil {
}
- /**
- * Compute the driver id from the given job.
- */
- public static UniqueId computeDriverId(JobId jobId) {
- byte[] bytes = new byte[UniqueId.LENGTH];
- System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size());
- Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte)0xFF);
- ByteBuffer wbb = ByteBuffer.wrap(bytes);
- wbb.order(ByteOrder.LITTLE_ENDIAN);
- return new UniqueId(bytes);
- }
-
/**
* Compute the murmur hash code of this ID.
*/
diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java
index 6d47a2fc9..b47b010ae 100644
--- a/java/test/src/main/java/org/ray/api/test/FailureTest.java
+++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java
@@ -1,12 +1,18 @@
package org.ray.api.test;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.exception.RayActorException;
+import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
+import org.ray.api.function.RayFunc0;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -23,6 +29,15 @@ public class FailureTest extends BaseTest {
return 0;
}
+ public static int slowFunc() {
+ try {
+ Thread.sleep(10000);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return 0;
+ }
+
public static class BadActor {
public BadActor(boolean failOnCreation) {
@@ -106,5 +121,26 @@ public class FailureTest extends BaseTest {
// RayActorException.
}
}
+
+ @Test
+ public void testGetThrowsQuicklyWhenFoundException() {
+ TestUtils.skipTestUnderSingleProcess();
+ List> badFunctions = Arrays.asList(FailureTest::badFunc,
+ FailureTest::badFunc2);
+ for (RayFunc0 badFunc : badFunctions) {
+ RayObject obj1 = Ray.call(badFunc);
+ RayObject obj2 = Ray.call(FailureTest::slowFunc);
+ Instant start = Instant.now();
+ try {
+ Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
+ Assert.fail("Should throw RayException.");
+ } catch (RayException e) {
+ Instant end = Instant.now();
+ long duration = Duration.between(start, end).toMillis();
+ Assert.assertTrue(duration < 5000, "Should fail quickly. " +
+ "Actual execution time: " + duration + " ms.");
+ }
+ }
+ }
}
diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java
index 7abc3f421..84adba6d7 100644
--- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java
+++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java
@@ -1,12 +1,10 @@
package org.ray.api.test;
-import org.apache.arrow.plasma.PlasmaClient;
-import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
-
import org.ray.api.Ray;
import org.ray.api.TestUtils;
-import org.ray.api.id.UniqueId;
+import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
+import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -15,15 +13,13 @@ public class PlasmaStoreTest extends BaseTest {
@Test
public void testPutWithDuplicateId() {
TestUtils.skipTestUnderSingleProcess();
- UniqueId objectId = UniqueId.randomId();
+ ObjectId objectId = ObjectId.randomId();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
- PlasmaClient store = new PlasmaClient(runtime.getRayConfig().objectStoreSocketName, "", 0);
- store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
- try {
- store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
- Assert.fail("This line shouldn't be reached.");
- } catch (DuplicateObjectException e) {
- // Putting 2 objects with duplicate ID should throw DuplicateObjectException.
- }
+ ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy();
+ objectInterface.put(objectId, 1);
+ Assert.assertEquals(objectInterface.get(objectId, -1).object, (Integer) 1);
+ objectInterface.put(objectId, 2);
+ // Putting 2 objects with duplicate ID should fail but ignored.
+ Assert.assertEquals(objectInterface.get(objectId, -1).object, (Integer) 1);
}
}
diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java
index 5b6834e5e..ebc342722 100644
--- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java
+++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java
@@ -1,7 +1,7 @@
package org.ray.api.test;
import org.ray.runtime.config.RayConfig;
-import org.ray.runtime.config.WorkerMode;
+import org.ray.runtime.generated.Common.WorkerType;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -12,7 +12,7 @@ public class RayConfigTest {
try {
System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path");
RayConfig rayConfig = RayConfig.create();
- Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
+ Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath);
} finally {
// Unset system properties.
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index 3888e3ba4..462699002 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -151,3 +151,10 @@ RAY_CONFIG(uint32_t, num_actor_checkpoints_to_keep, 20)
/// Maximum number of ids in one batch to send to GCS to delete keys.
RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000)
+
+/// When getting objects from object store, print a warning every this number of attempts.
+RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50)
+
+/// When getting objects from object store, max number of ids to print in the warning
+/// message.
+RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20)
diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h
index aabb3fa83..d265aa536 100644
--- a/src/ray/core_worker/common.h
+++ b/src/ray/core_worker/common.h
@@ -9,9 +9,7 @@
#include "ray/raylet/raylet_client.h"
namespace ray {
-
-/// Type of this worker.
-enum class WorkerType { WORKER, DRIVER };
+using WorkerType = rpc::WorkerType;
/// Information about a remote function.
struct RayFunction {
diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc
index c5d7e7857..b655e4588 100644
--- a/src/ray/core_worker/context.cc
+++ b/src/ray/core_worker/context.cc
@@ -6,69 +6,81 @@ namespace ray {
/// per-thread context for core worker.
struct WorkerThreadContext {
WorkerThreadContext()
- : current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {}
+ : current_task_id_(TaskID::FromRandom()), task_index_(0), put_index_(0) {}
- int GetNextTaskIndex() { return ++task_index; }
+ int GetNextTaskIndex() { return ++task_index_; }
- int GetNextPutIndex() { return ++put_index; }
+ int GetNextPutIndex() { return ++put_index_; }
- const TaskID &GetCurrentTaskID() const { return current_task_id; }
+ const TaskID &GetCurrentTaskID() const { return current_task_id_; }
- void SetCurrentTask(const TaskID &task_id) {
- current_task_id = task_id;
- task_index = 0;
- put_index = 0;
+ std::shared_ptr GetCurrentTask() const {
+ return current_task_;
+ }
+
+ void SetCurrentTaskId(const TaskID &task_id) {
+ current_task_id_ = task_id;
+ task_index_ = 0;
+ put_index_ = 0;
}
void SetCurrentTask(const TaskSpecification &task_spec) {
- SetCurrentTask(task_spec.TaskId());
+ SetCurrentTaskId(task_spec.TaskId());
+ current_task_ = std::make_shared(task_spec);
}
private:
/// The task ID for current task.
- TaskID current_task_id;
+ TaskID current_task_id_;
+
+ /// The current task.
+ std::shared_ptr current_task_;
/// Number of tasks that have been submitted from current task.
- int task_index;
+ int task_index_;
/// Number of objects that have been put from current task.
- int put_index;
+ int put_index_;
};
thread_local std::unique_ptr WorkerContext::thread_context_ =
nullptr;
WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
- : worker_type(worker_type),
- worker_id(worker_type == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
- : WorkerID::FromRandom()),
- current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) {
+ : worker_type_(worker_type),
+ worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
+ : WorkerID::FromRandom()),
+ current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()) {
// For worker main thread which initializes the WorkerContext,
// set task_id according to whether current worker is a driver.
// (For other threads it's set to random ID via GetThreadContext).
- GetThreadContext().SetCurrentTask(
- (worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
+ GetThreadContext().SetCurrentTaskId(
+ (worker_type_ == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
}
-const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }
+const WorkerType WorkerContext::GetWorkerType() const { return worker_type_; }
-const WorkerID &WorkerContext::GetWorkerID() const { return worker_id; }
+const WorkerID &WorkerContext::GetWorkerID() const { return worker_id_; }
int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }
int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }
-const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id; }
+const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id_; }
const TaskID &WorkerContext::GetCurrentTaskID() const {
return GetThreadContext().GetCurrentTaskID();
}
void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
- current_job_id = task_spec.JobId();
+ current_job_id_ = task_spec.JobId();
GetThreadContext().SetCurrentTask(task_spec);
}
+std::shared_ptr WorkerContext::GetCurrentTask() const {
+ return GetThreadContext().GetCurrentTask();
+}
+
WorkerThreadContext &WorkerContext::GetThreadContext() {
if (thread_context_ == nullptr) {
thread_context_ = std::unique_ptr(new WorkerThreadContext());
diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h
index 629249103..8405501d3 100644
--- a/src/ray/core_worker/context.h
+++ b/src/ray/core_worker/context.h
@@ -22,19 +22,21 @@ class WorkerContext {
void SetCurrentTask(const TaskSpecification &task_spec);
+ std::shared_ptr GetCurrentTask() const;
+
int GetNextTaskIndex();
int GetNextPutIndex();
private:
/// Type of the worker.
- const WorkerType worker_type;
+ const WorkerType worker_type_;
/// ID for this worker.
- const WorkerID worker_id;
+ const WorkerID worker_id_;
/// Job ID for this worker.
- JobID current_job_id;
+ JobID current_job_id_;
private:
static WorkerThreadContext &GetThreadContext();
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index 6fa560f27..e49ca9972 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -15,7 +15,7 @@ CoreWorker::CoreWorker(
task_interface_(worker_context_, raylet_client_),
object_interface_(worker_context_, raylet_client_, store_socket) {
int rpc_server_port = 0;
- if (worker_type_ == ray::WorkerType::WORKER) {
+ if (worker_type_ == WorkerType::WORKER) {
RAY_CHECK(execution_callback != nullptr);
task_execution_interface_ = std::unique_ptr(
new CoreWorkerTaskExecutionInterface(worker_context_, raylet_client_,
@@ -28,8 +28,8 @@ CoreWorker::CoreWorker(
// instead of crashing.
raylet_client_ = std::unique_ptr(new RayletClient(
raylet_socket_, ClientID::FromBinary(worker_context_.GetWorkerID().Binary()),
- (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
- language_, rpc_server_port));
+ (worker_type_ == WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_,
+ rpc_server_port));
}
} // namespace ray
diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc
new file mode 100644
index 000000000..6c66f8f2f
--- /dev/null
+++ b/src/ray/core_worker/lib/java/jni_init.cc
@@ -0,0 +1,75 @@
+#include "ray/core_worker/lib/java/jni_utils.h"
+
+jclass java_boolean_class;
+jmethodID java_boolean_init;
+
+jclass java_list_class;
+jmethodID java_list_size;
+jmethodID java_list_get;
+jmethodID java_list_add;
+
+jclass java_array_list_class;
+jmethodID java_array_list_init;
+jmethodID java_array_list_init_with_capacity;
+
+jclass java_ray_exception_class;
+
+jclass java_native_ray_object_class;
+jmethodID java_native_ray_object_init;
+jfieldID java_native_ray_object_data;
+jfieldID java_native_ray_object_metadata;
+
+jint JNI_VERSION = JNI_VERSION_1_8;
+
+inline jclass LoadClass(JNIEnv *env, const char *class_name) {
+ jclass tempLocalClassRef = env->FindClass(class_name);
+ jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef);
+ env->DeleteLocalRef(tempLocalClassRef);
+ return ret;
+}
+
+/// Load and cache frequently-used Java classes and methods
+jint JNI_OnLoad(JavaVM *vm, void *reserved) {
+ JNIEnv *env;
+ if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) {
+ return JNI_ERR;
+ }
+
+ java_boolean_class = LoadClass(env, "java/lang/Boolean");
+ java_boolean_init = env->GetMethodID(java_boolean_class, "", "(Z)V");
+
+ java_list_class = LoadClass(env, "java/util/List");
+ java_list_size = env->GetMethodID(java_list_class, "size", "()I");
+ java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;");
+ java_list_add = env->GetMethodID(java_list_class, "add", "(Ljava/lang/Object;)Z");
+
+ java_array_list_class = LoadClass(env, "java/util/ArrayList");
+ java_array_list_init = env->GetMethodID(java_array_list_class, "", "()V");
+ java_array_list_init_with_capacity =
+ env->GetMethodID(java_array_list_class, "", "(I)V");
+
+ java_ray_exception_class = LoadClass(env, "org/ray/api/exception/RayException");
+
+ java_native_ray_object_class =
+ LoadClass(env, "org/ray/runtime/objectstore/NativeRayObject");
+ java_native_ray_object_init =
+ env->GetMethodID(java_native_ray_object_class, "", "([B[B)V");
+ java_native_ray_object_data =
+ env->GetFieldID(java_native_ray_object_class, "data", "[B");
+ java_native_ray_object_metadata =
+ env->GetFieldID(java_native_ray_object_class, "metadata", "[B");
+
+ return JNI_VERSION;
+}
+
+/// Unload java classes
+void JNI_OnUnload(JavaVM *vm, void *reserved) {
+ JNIEnv *env;
+ vm->GetEnv(reinterpret_cast(&env), JNI_VERSION);
+
+ env->DeleteGlobalRef(java_boolean_class);
+ env->DeleteGlobalRef(java_list_class);
+ env->DeleteGlobalRef(java_array_list_class);
+ env->DeleteGlobalRef(java_ray_exception_class);
+ env->DeleteGlobalRef(java_native_ray_object_class);
+}
diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h
new file mode 100644
index 000000000..d0f4ca8a5
--- /dev/null
+++ b/src/ray/core_worker/lib/java/jni_utils.h
@@ -0,0 +1,180 @@
+#ifndef RAY_COMMON_JAVA_JNI_HELPER_H
+#define RAY_COMMON_JAVA_JNI_HELPER_H
+
+#include
+#include "ray/common/buffer.h"
+#include "ray/common/id.h"
+#include "ray/common/status.h"
+#include "ray/core_worker/store_provider/store_provider.h"
+
+/// Boolean class
+extern jclass java_boolean_class;
+/// Constructor of Boolean class
+extern jmethodID java_boolean_init;
+
+/// List class
+extern jclass java_list_class;
+/// size method of List class
+extern jmethodID java_list_size;
+/// get method of List class
+extern jmethodID java_list_get;
+/// add method of List class
+extern jmethodID java_list_add;
+
+/// ArrayList class
+extern jclass java_array_list_class;
+/// Constructor of ArrayList class
+extern jmethodID java_array_list_init;
+/// Constructor of ArrayList class with single parameter capacity
+extern jmethodID java_array_list_init_with_capacity;
+
+/// RayException class
+extern jclass java_ray_exception_class;
+
+/// NativeRayObject class
+extern jclass java_native_ray_object_class;
+/// Constructor of NativeRayObject class
+extern jmethodID java_native_ray_object_init;
+/// data field of NativeRayObject class
+extern jfieldID java_native_ray_object_data;
+/// metadata field of NativeRayObject class
+extern jfieldID java_native_ray_object_metadata;
+
+/// Throws a Java RayException if the status is not OK.
+#define THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, ret) \
+ { \
+ if (!(status).ok()) { \
+ (env)->ThrowNew(java_ray_exception_class, (status).message().c_str()); \
+ return (ret); \
+ } \
+ }
+
+/// Convert a Java byte array to a C++ UniqueID.
+template
+inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) {
+ std::string id_str(ID::Size(), 0);
+ env->GetByteArrayRegion(bytes, 0, ID::Size(),
+ reinterpret_cast(&id_str.front()));
+ return ID::FromBinary(id_str);
+}
+
+/// Convert C++ UniqueID to a Java byte array.
+template
+inline jbyteArray IdToJavaByteArray(JNIEnv *env, const ID &id) {
+ jbyteArray array = env->NewByteArray(ID::Size());
+ env->SetByteArrayRegion(array, 0, ID::Size(),
+ reinterpret_cast(id.Data()));
+ return array;
+}
+
+/// Convert C++ UniqueID to a Java ByteBuffer.
+template
+inline jobject IdToJavaByteBuffer(JNIEnv *env, const ID &id) {
+ return env->NewDirectByteBuffer(
+ reinterpret_cast(const_cast(id.Data())), id.Size());
+}
+
+/// Convert a Java String to C++ std::string.
+inline std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
+ const char *c_str = env->GetStringUTFChars(jstr, nullptr);
+ std::string result(c_str);
+ env->ReleaseStringUTFChars(static_cast(jstr), c_str);
+ return result;
+}
+
+/// Convert a Java List to C++ std::vector.
+template
+inline void JavaListToNativeVector(
+ JNIEnv *env, jobject java_list, std::vector *native_vector,
+ std::function element_converter) {
+ int size = env->CallIntMethod(java_list, java_list_size);
+ native_vector->clear();
+ for (int i = 0; i < size; i++) {
+ native_vector->emplace_back(
+ element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i)));
+ }
+}
+
+/// Convert a C++ std::vector to a Java List.
+template
+inline jobject NativeVectorToJavaList(
+ JNIEnv *env, const std::vector &native_vector,
+ std::function element_converter) {
+ jobject java_list =
+ env->NewObject(java_array_list_class, java_array_list_init_with_capacity,
+ (jint)native_vector.size());
+ for (const auto &item : native_vector) {
+ env->CallVoidMethod(java_list, java_list_add, element_converter(env, item));
+ }
+ return java_list;
+}
+
+/// Convert a C++ ray::Buffer to a Java byte array.
+inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env,
+ const std::shared_ptr buffer) {
+ if (!buffer) {
+ return nullptr;
+ }
+ jbyteArray java_byte_array = env->NewByteArray(buffer->Size());
+ if (buffer->Size() > 0) {
+ env->SetByteArrayRegion(java_byte_array, 0, buffer->Size(),
+ reinterpret_cast(buffer->Data()));
+ }
+ return java_byte_array;
+}
+
+/// A helper method to help access a Java NativeRayObject instance and ensure memory
+/// safety.
+///
+/// \param[in] java_obj The Java NativeRayObject object.
+/// \param[in] reader The callback function to access a C++ ray::RayObject instance.
+/// \return The return value of callback function.
+template
+inline ReturnT ReadJavaNativeRayObject(
+ JNIEnv *env, const jobject &java_obj,
+ std::function &)> reader) {
+ if (!java_obj) {
+ return reader(nullptr);
+ }
+ auto java_data = (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_data);
+ auto java_metadata =
+ (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_metadata);
+ auto data_size = env->GetArrayLength(java_data);
+ jbyte *data = data_size > 0 ? env->GetByteArrayElements(java_data, nullptr) : nullptr;
+ auto metadata_size = java_metadata ? env->GetArrayLength(java_metadata) : 0;
+ jbyte *metadata =
+ metadata_size > 0 ? env->GetByteArrayElements(java_metadata, nullptr) : nullptr;
+ auto data_buffer = std::make_shared(
+ reinterpret_cast(data), data_size);
+ auto metadata_buffer = java_metadata
+ ? std::make_shared(
+ reinterpret_cast(metadata), metadata_size)
+ : nullptr;
+
+ auto native_obj = std::make_shared(data_buffer, metadata_buffer);
+ auto result = reader(native_obj);
+
+ if (data) {
+ env->ReleaseByteArrayElements(java_data, data, JNI_ABORT);
+ }
+ if (metadata) {
+ env->ReleaseByteArrayElements(java_metadata, metadata, JNI_ABORT);
+ }
+
+ return result;
+}
+
+/// Convert a C++ ray::RayObject to a Java NativeRayObject.
+inline jobject ToJavaNativeRayObject(JNIEnv *env,
+ const std::shared_ptr &rayObject) {
+ if (!rayObject) {
+ return nullptr;
+ }
+ auto java_data = NativeBufferToJavaByteArray(env, rayObject->GetData());
+ auto java_metadata = NativeBufferToJavaByteArray(env, rayObject->GetMetadata());
+ auto java_obj = env->NewObject(java_native_ray_object_class,
+ java_native_ray_object_init, java_data, java_metadata);
+ return java_obj;
+}
+
+#endif // RAY_COMMON_JAVA_JNI_HELPER_H
diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc
new file mode 100644
index 000000000..2c91dcdaa
--- /dev/null
+++ b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc
@@ -0,0 +1,134 @@
+#include "ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h"
+#include
+#include "ray/common/id.h"
+#include "ray/core_worker/context.h"
+#include "ray/core_worker/lib/java/jni_utils.h"
+
+inline ray::WorkerContext *GetWorkerContextFromPointer(
+ jlong nativeWorkerContextFromPointer) {
+ return reinterpret_cast(nativeWorkerContextFromPointer);
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeCreateWorkerContext
+ * Signature: (I[B)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext(
+ JNIEnv *env, jclass, jint workerType, jbyteArray jobId) {
+ return reinterpret_cast(
+ new ray::WorkerContext(static_cast(workerType),
+ JavaByteArrayToId(env, jobId)));
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentTaskId
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ auto task_id =
+ GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTaskID();
+ return IdToJavaByteArray(env, task_id);
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeSetCurrentTask
+ * Signature: (J[B)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer, jbyteArray taskSpec) {
+ jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
+ jsize size = env->GetArrayLength(taskSpec);
+ ray::rpc::TaskSpec task_spec_message;
+ task_spec_message.ParseFromArray(data, size);
+ env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
+
+ ray::TaskSpecification spec(task_spec_message);
+ GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->SetCurrentTask(spec);
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentTask
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ auto spec =
+ GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTask();
+ if (!spec) {
+ return nullptr;
+ }
+
+ auto task_message = spec->Serialize();
+ jbyteArray result = env->NewByteArray(task_message.size());
+ env->SetByteArrayRegion(
+ result, 0, task_message.size(),
+ reinterpret_cast(const_cast(task_message.data())));
+ return result;
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentJobId
+ * Signature: (J)Ljava/nio/ByteBuffer;
+ */
+JNIEXPORT jobject JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ const auto &job_id =
+ GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentJobID();
+ return IdToJavaByteBuffer(env, job_id);
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentWorkerId
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ auto worker_id =
+ GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetWorkerID();
+ return IdToJavaByteArray(env, worker_id);
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetNextTaskIndex
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextTaskIndex();
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetNextPutIndex
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextPutIndex();
+}
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeDestroy
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(
+ JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
+ delete GetWorkerContextFromPointer(nativeWorkerContextFromPointer);
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h
new file mode 100644
index 000000000..df9c60a56
--- /dev/null
+++ b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h
@@ -0,0 +1,87 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include
+/* Header for class org_ray_runtime_WorkerContext */
+
+#ifndef _Included_org_ray_runtime_WorkerContext
+#define _Included_org_ray_runtime_WorkerContext
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeCreateWorkerContext
+ * Signature: (I[B)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext(
+ JNIEnv *, jclass, jint, jbyteArray);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentTaskId
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, jlong);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeSetCurrentTask
+ * Signature: (J[B)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask(
+ JNIEnv *, jclass, jlong, jbyteArray);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentTask
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(JNIEnv *, jclass, jlong);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentJobId
+ * Signature: (J)Ljava/nio/ByteBuffer;
+ */
+JNIEXPORT jobject JNICALL
+Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, jlong);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetCurrentWorkerId
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(JNIEnv *, jclass, jlong);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetNextTaskIndex
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(JNIEnv *,
+ jclass,
+ jlong);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeGetNextPutIndex
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(JNIEnv *,
+ jclass,
+ jlong);
+
+/*
+ * Class: org_ray_runtime_WorkerContext
+ * Method: nativeDestroy
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(JNIEnv *, jclass,
+ jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc
new file mode 100644
index 000000000..3c7bb43a0
--- /dev/null
+++ b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc
@@ -0,0 +1,149 @@
+#include "ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h"
+#include
+#include "ray/common/id.h"
+#include "ray/core_worker/common.h"
+#include "ray/core_worker/lib/java/jni_utils.h"
+#include "ray/core_worker/object_interface.h"
+
+inline ray::CoreWorkerObjectInterface *GetObjectInterfaceFromPointer(
+ jlong nativeObjectInterfacePointer) {
+ return reinterpret_cast(nativeObjectInterfacePointer);
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeCreateObjectInterface
+ * Signature: (JJLjava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface(
+ JNIEnv *env, jclass, jlong nativeWorkerContext, jlong nativeRayletClient,
+ jstring storeSocketName) {
+ return reinterpret_cast(new ray::CoreWorkerObjectInterface(
+ *reinterpret_cast(nativeWorkerContext),
+ *reinterpret_cast *>(nativeRayletClient),
+ JavaStringToNativeString(env, storeSocketName)));
+}
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativePut
+ * Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2(
+ JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject obj) {
+ ray::Status status;
+ ray::ObjectID object_id = ReadJavaNativeRayObject(
+ env, obj,
+ [nativeObjectInterfacePointer,
+ &status](const std::shared_ptr &rayObject) {
+ RAY_CHECK(rayObject != nullptr);
+ ray::ObjectID object_id;
+ status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
+ ->Put(*rayObject, &object_id);
+ return object_id;
+ });
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
+ return IdToJavaByteArray(env, object_id);
+}
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativePut
+ * Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V
+ */
+JNIEXPORT void JNICALL
+Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2(
+ JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jbyteArray objectId,
+ jobject obj) {
+ auto object_id = JavaByteArrayToId(env, objectId);
+ auto status = ReadJavaNativeRayObject(
+ env, obj,
+ [nativeObjectInterfacePointer,
+ &object_id](const std::shared_ptr &rayObject) {
+ RAY_CHECK(rayObject != nullptr);
+ return GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
+ ->Put(*rayObject, object_id);
+ });
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
+}
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeGet
+ * Signature: (JLjava/util/List;J)Ljava/util/List;
+ */
+JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet(
+ JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject ids,
+ jlong timeoutMs) {
+ std::vector object_ids;
+ JavaListToNativeVector(
+ env, ids, &object_ids, [](JNIEnv *env, jobject id) {
+ return JavaByteArrayToId(env, static_cast(id));
+ });
+ std::vector> results;
+ auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
+ ->Get(object_ids, (int64_t)timeoutMs, &results);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
+ return NativeVectorToJavaList>(env, results,
+ ToJavaNativeRayObject);
+}
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeWait
+ * Signature: (JLjava/util/List;IJ)Ljava/util/List;
+ */
+JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait(
+ JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds,
+ jint numObjects, jlong timeoutMs) {
+ std::vector object_ids;
+ JavaListToNativeVector(
+ env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
+ return JavaByteArrayToId(env, static_cast(id));
+ });
+ std::vector results;
+ auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
+ ->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
+ return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) {
+ return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item);
+ });
+}
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeDelete
+ * Signature: (JLjava/util/List;ZZ)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete(
+ JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds,
+ jboolean localOnly, jboolean deleteCreatingTasks) {
+ std::vector object_ids;
+ JavaListToNativeVector(
+ env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
+ return JavaByteArrayToId(env, static_cast(id));
+ });
+ auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
+ ->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
+}
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeDestroy
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy(
+ JNIEnv *env, jclass, jlong nativeObjectInterfacePointer) {
+ delete GetObjectInterfaceFromPointer(nativeObjectInterfacePointer);
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h
new file mode 100644
index 000000000..0ea41535e
--- /dev/null
+++ b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h
@@ -0,0 +1,72 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include
+/* Header for class org_ray_runtime_objectstore_ObjectInterfaceImpl */
+
+#ifndef _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl
+#define _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeCreateObjectInterface
+ * Signature: (JJLjava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface(
+ JNIEnv *, jclass, jlong, jlong, jstring);
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativePut
+ * Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2(
+ JNIEnv *, jclass, jlong, jobject);
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativePut
+ * Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V
+ */
+JNIEXPORT void JNICALL
+Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2(
+ JNIEnv *, jclass, jlong, jbyteArray, jobject);
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeGet
+ * Signature: (JLjava/util/List;J)Ljava/util/List;
+ */
+JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet(
+ JNIEnv *, jclass, jlong, jobject, jlong);
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeWait
+ * Signature: (JLjava/util/List;IJ)Ljava/util/List;
+ */
+JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait(
+ JNIEnv *, jclass, jlong, jobject, jint, jlong);
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeDelete
+ * Signature: (JLjava/util/List;ZZ)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete(
+ JNIEnv *, jclass, jlong, jobject, jboolean, jboolean);
+
+/*
+ * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
+ * Method: nativeDestroy
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy(
+ JNIEnv *, jclass, jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc
index 53c330dc0..5a59420a5 100644
--- a/src/ray/core_worker/store_provider/plasma_store_provider.cc
+++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc
@@ -3,6 +3,7 @@
#include "ray/core_worker/context.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/object_interface.h"
+#include "ray/protobuf/gcs.pb.h"
namespace ray {
@@ -101,11 +102,14 @@ Status CoreWorkerPlasmaStoreProvider::Get(
std::make_shared(object_buffers[i].data),
std::make_shared(object_buffers[i].metadata));
unready.erase(object_id);
+ if (IsException(object_buffers[i])) {
+ should_break = true;
+ }
}
}
num_attempts += 1;
- // TODO(zhijunfu): log a message if attempted too many times.
+ WarnIfAttemptedTooManyTimes(num_attempts, unready);
}
if (was_blocked) {
@@ -144,4 +148,45 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector &object
return raylet_client_->FreeObjects(object_ids, local_only, delete_creating_tasks);
}
+bool CoreWorkerPlasmaStoreProvider::IsException(const plasma::ObjectBuffer &buffer) {
+ // TODO (kfstorm): metadata should be structured.
+ const std::string metadata = buffer.metadata->ToString();
+ const auto error_type_descriptor = ray::rpc::ErrorType_descriptor();
+ for (int i = 0; i < error_type_descriptor->value_count(); i++) {
+ const auto error_type_number = error_type_descriptor->value(i)->number();
+ if (metadata == std::to_string(error_type_number)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes(
+ int num_attempts, const std::unordered_map &unready) {
+ if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() ==
+ 0) {
+ std::ostringstream oss;
+ size_t printed = 0;
+ for (auto &entry : unready) {
+ if (printed >=
+ RayConfig::instance().object_store_get_max_ids_to_print_in_warning()) {
+ break;
+ }
+ if (printed > 0) {
+ oss << ", ";
+ }
+ oss << entry.first.Hex();
+ }
+ if (printed < unready.size()) {
+ oss << ", etc";
+ }
+ RAY_LOG(WARNING)
+ << "Attempted " << num_attempts << " times to reconstruct objects, but "
+ << "some objects are still unavailable. If this message continues to print,"
+ << " it may indicate that object's creating task is hanging, or something wrong"
+ << " happened in raylet backend. " << unready.size()
+ << " object(s) pending: " << oss.str() << ".";
+ }
+}
+
} // namespace ray
diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h
index 9aa2f914a..89ecb2ea2 100644
--- a/src/ray/core_worker/store_provider/plasma_store_provider.h
+++ b/src/ray/core_worker/store_provider/plasma_store_provider.h
@@ -60,6 +60,20 @@ class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {
bool delete_creating_tasks) override;
private:
+ /// Whether the buffer represents an exception object.
+ ///
+ /// \param[in] buffer the object buffer.
+ /// \return Whether it represents an exception object.
+ static bool IsException(const plasma::ObjectBuffer &buffer);
+
+ /// Print a warning if we've attempted too many times, but some objects are still
+ /// unavailable.
+ ///
+ /// \param[in] num_attemps The number of attempted times.
+ /// \param[in] unready The unready objects.
+ static void WarnIfAttemptedTooManyTimes(
+ int num_attempts, const std::unordered_map &unready);
+
/// Plasma store client.
plasma::PlasmaClient store_client_;
diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto
index 8d14c004a..48a3f1cf7 100644
--- a/src/ray/protobuf/common.proto
+++ b/src/ray/protobuf/common.proto
@@ -11,6 +11,12 @@ enum Language {
CPP = 2;
}
+// Type of a worker.
+enum WorkerType {
+ WORKER = 0;
+ DRIVER = 1;
+}
+
// Type of a task.
enum TaskType {
// Normal task.
diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto
index 5dcc36662..05d97750f 100644
--- a/src/ray/protobuf/gcs.proto
+++ b/src/ray/protobuf/gcs.proto
@@ -267,4 +267,6 @@ enum ErrorType {
// 2) The object's creating task is already cleaned up from GCS (this currently
// crashes raylet).
OBJECT_UNRECONSTRUCTABLE = 2;
+ // Indicates that a task failed due to user code failure.
+ TASK_EXECUTION_EXCEPTION = 3;
}
diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
index ac6d33b9d..fb4390fb5 100644
--- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
+++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
@@ -3,39 +3,14 @@
#include
#include "ray/common/id.h"
+#include "ray/core_worker/lib/java/jni_utils.h"
#include "ray/raylet/raylet_client.h"
#include "ray/util/logging.h"
-template
-class UniqueIdFromJByteArray {
- public:
- const ID &GetId() const { return id; }
-
- UniqueIdFromJByteArray(JNIEnv *env, const jbyteArray &bytes) {
- std::string id_str(ID::Size(), 0);
- env->GetByteArrayRegion(bytes, 0, ID::Size(),
- reinterpret_cast(&id_str.front()));
- id = ID::FromBinary(id_str);
- }
-
- private:
- ID id;
-};
-
#ifdef __cplusplus
extern "C" {
#endif
-inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
- if (!status.ok()) {
- jclass exception_class = env->FindClass("org/ray/api/exception/RayException");
- env->ThrowNew(exception_class, status.message().c_str());
- return true;
- } else {
- return false;
- }
-}
-
/*
* Class: org_ray_runtime_raylet_RayletClientImpl
* Method: nativeInit
@@ -44,11 +19,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
jbyteArray jobId) {
- UniqueIdFromJByteArray worker_id(env, workerId);
- UniqueIdFromJByteArray job_id(env, jobId);
+ const auto worker_id = JavaByteArrayToId(env, workerId);
+ const auto job_id = JavaByteArrayToId(env, jobId);
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
- auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker,
- job_id.GetId(), Language::JAVA);
+ auto raylet_client = new std::unique_ptr(
+ new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA));
env->ReleaseStringUTFChars(sockName, nativeString);
return reinterpret_cast(raylet_client);
}
@@ -60,7 +35,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) {
- auto raylet_client = reinterpret_cast(client);
+ auto &raylet_client = *reinterpret_cast *>(client);
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
jsize size = env->GetArrayLength(taskSpec);
@@ -70,7 +45,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
ray::TaskSpecification task_spec(task_spec_message);
auto status = raylet_client->SubmitTask(task_spec);
- ThrowRayExceptionIfNotOK(env, status);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@@ -80,13 +55,11 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(
JNIEnv *env, jclass, jlong client) {
- auto raylet_client = reinterpret_cast(client);
+ auto &raylet_client = *reinterpret_cast *>(client);
std::unique_ptr spec;
auto status = raylet_client->GetTask(&spec);
- if (ThrowRayExceptionIfNotOK(env, status)) {
- return nullptr;
- }
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
// Serialize the task spec and copy to Java byte array.
auto task_data = spec->Serialize();
@@ -109,8 +82,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
JNIEnv *env, jclass, jlong client) {
- auto raylet_client = reinterpret_cast(client);
- ThrowRayExceptionIfNotOK(env, raylet_client->Disconnect());
+ auto raylet_client = reinterpret_cast *>(client);
+ auto status = (*raylet_client)->Disconnect();
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
delete raylet_client;
}
@@ -128,15 +102,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast(env->GetObjectArrayElement(objectIds, i));
- UniqueIdFromJByteArray object_id(env, object_id_bytes);
- object_ids.push_back(object_id.GetId());
+ const auto object_id = JavaByteArrayToId(env, object_id_bytes);
+ object_ids.push_back(object_id);
env->DeleteLocalRef(object_id_bytes);
}
- UniqueIdFromJByteArray current_task_id(env, currentTaskId);
- auto raylet_client = reinterpret_cast(client);
- auto status =
- raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId());
- ThrowRayExceptionIfNotOK(env, status);
+ const auto current_task_id = JavaByteArrayToId(env, currentTaskId);
+ auto &raylet_client = *reinterpret_cast *>(client);
+ auto status = raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@@ -146,10 +119,10 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(
JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) {
- UniqueIdFromJByteArray current_task_id(env, currentTaskId);
- auto raylet_client = reinterpret_cast(client);
- auto status = raylet_client->NotifyUnblocked(current_task_id.GetId());
- ThrowRayExceptionIfNotOK(env, status);
+ const auto current_task_id = JavaByteArrayToId(env, currentTaskId);
+ auto &raylet_client = *reinterpret_cast *>(client);
+ auto status = raylet_client->NotifyUnblocked(current_task_id);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@@ -166,22 +139,20 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast(env->GetObjectArrayElement(objectIds, i));
- UniqueIdFromJByteArray object_id(env, object_id_bytes);
- object_ids.push_back(object_id.GetId());
+ const auto object_id = JavaByteArrayToId(env, object_id_bytes);
+ object_ids.push_back(object_id);
env->DeleteLocalRef(object_id_bytes);
}
- UniqueIdFromJByteArray current_task_id(env, currentTaskId);
+ const auto current_task_id = JavaByteArrayToId(env, currentTaskId);
- auto raylet_client = reinterpret_cast(client);
+ auto &raylet_client = *reinterpret_cast *>(client);
// Invoke wait.
WaitResultPair result;
- auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis,
- static_cast(isWaitLocal),
- current_task_id.GetId(), &result);
- if (ThrowRayExceptionIfNotOK(env, status)) {
- return nullptr;
- }
+ auto status =
+ raylet_client->Wait(object_ids, numReturns, timeoutMillis,
+ static_cast(isWaitLocal), current_task_id, &result);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
// Convert result to java object.
jboolean put_value = true;
@@ -216,11 +187,10 @@ JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
jint parent_task_counter) {
- UniqueIdFromJByteArray job_id(env, jobId);
- UniqueIdFromJByteArray parent_task_id(env, parentTaskId);
+ const auto job_id = JavaByteArrayToId(env, jobId);
+ const auto parent_task_id = JavaByteArrayToId(env, parentTaskId);
- TaskID task_id =
- ray::GenerateTaskId(job_id.GetId(), parent_task_id.GetId(), parent_task_counter);
+ TaskID task_id = ray::GenerateTaskId(job_id, parent_task_id, parent_task_counter);
jbyteArray result = env->NewByteArray(task_id.Size());
if (nullptr == result) {
return nullptr;
@@ -245,13 +215,13 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast(env->GetObjectArrayElement(objectIds, i));
- UniqueIdFromJByteArray object_id(env, object_id_bytes);
- object_ids.push_back(object_id.GetId());
+ const auto object_id = JavaByteArrayToId(env, object_id_bytes);
+ object_ids.push_back(object_id);
env->DeleteLocalRef(object_id_bytes);
}
- auto raylet_client = reinterpret_cast(client);
+ auto &raylet_client = *reinterpret_cast *>(client);
auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks);
- ThrowRayExceptionIfNotOK(env, status);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@@ -263,13 +233,11 @@ JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass,
jlong client,
jbyteArray actorId) {
- auto raylet_client = reinterpret_cast(client);
- UniqueIdFromJByteArray actor_id(env, actorId);
+ auto &raylet_client = *reinterpret_cast *>(client);
+ const auto actor_id = JavaByteArrayToId(env, actorId);
ActorCheckpointID checkpoint_id;
- auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), checkpoint_id);
- if (ThrowRayExceptionIfNotOK(env, status)) {
- return nullptr;
- }
+ auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
reinterpret_cast(checkpoint_id.Data()));
@@ -284,12 +252,11 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env
JNIEXPORT void JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
- auto raylet_client = reinterpret_cast(client);
- UniqueIdFromJByteArray actor_id(env, actorId);
- UniqueIdFromJByteArray checkpoint_id(env, checkpointId);
- auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(),
- checkpoint_id.GetId());
- ThrowRayExceptionIfNotOK(env, status);
+ auto &raylet_client = *reinterpret_cast *>(client);
+ const auto actor_id = JavaByteArrayToId(env, actorId);
+ const auto checkpoint_id = JavaByteArrayToId(env, checkpointId);
+ auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@@ -300,14 +267,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(
JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity,
jbyteArray nodeId) {
- auto raylet_client = reinterpret_cast(client);
- UniqueIdFromJByteArray node_id(env, nodeId);
+ auto &raylet_client = *reinterpret_cast *>(client);
+ const auto node_id = JavaByteArrayToId(env, nodeId);
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
- auto status = raylet_client->SetResource(
- native_resource_name, static_cast(capacity), node_id.GetId());
+ auto status = raylet_client->SetResource(native_resource_name,
+ static_cast(capacity), node_id);
env->ReleaseStringUTFChars(resourceName, native_resource_name);
- ThrowRayExceptionIfNotOK(env, status);
+ THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
#ifdef __cplusplus