mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[Java worker] Refactor object store and worker context on top of core worker (#5079)
This commit is contained in:
+1
-8
@@ -69,7 +69,6 @@ define_java_module(
|
||||
],
|
||||
deps = [
|
||||
":org_ray_ray_api",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:com_google_protobuf_protobuf_java",
|
||||
"@maven//:com_typesafe_config",
|
||||
@@ -97,7 +96,6 @@ define_java_module(
|
||||
deps = [
|
||||
":org_ray_ray_api",
|
||||
":org_ray_ray_runtime",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:com_sun_xml_bind_jaxb_core",
|
||||
"@maven//:com_sun_xml_bind_jaxb_impl",
|
||||
@@ -176,9 +174,8 @@ filegroup(
|
||||
"//:redis-server",
|
||||
"//:libray_redis_module.so",
|
||||
"//:raylet",
|
||||
"//:raylet_library_java",
|
||||
"//:core_worker_library_java",
|
||||
"@plasma//:plasma_store_server",
|
||||
"@plasma//:plasma_client_java",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -189,7 +186,6 @@ genrule(
|
||||
":all_java_proto",
|
||||
":java_native_deps",
|
||||
":copy_pom_file",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
],
|
||||
outs = ["gen_maven_deps.out"],
|
||||
cmd = """
|
||||
@@ -208,9 +204,6 @@ genrule(
|
||||
chmod +w $$f
|
||||
cp $$f $$NATIVE_DEPS_DIR
|
||||
done
|
||||
# Install plasma jar to local maven repo.
|
||||
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
|
||||
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
|
||||
echo $$(date) > $@
|
||||
""",
|
||||
local = 1,
|
||||
|
||||
@@ -24,11 +24,6 @@
|
||||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
<version>0.13.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
@@ -22,10 +22,6 @@
|
||||
<artifactId>ray-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.beust</groupId>
|
||||
<artifactId>jcommander</artifactId>
|
||||
|
||||
@@ -22,10 +22,6 @@
|
||||
<artifactId>ray-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
{generated_bzl_deps}
|
||||
</dependencies>
|
||||
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
@@ -72,6 +80,27 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
protected RuntimeContext runtimeContext;
|
||||
protected GcsClient gcsClient;
|
||||
|
||||
static {
|
||||
try {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
// Load native libraries.
|
||||
String[] libraries = new String[]{"core_worker_library_java"};
|
||||
for (String library : libraries) {
|
||||
String fileName = System.mapLibraryName(library);
|
||||
// Copy the file from resources to a temp dir, and load the native library.
|
||||
File file = File.createTempFile(fileName, "");
|
||||
file.deleteOnExit();
|
||||
InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
|
||||
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
|
||||
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
|
||||
System.load(file.getAbsolutePath());
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Couldn't load native libraries.", e);
|
||||
}
|
||||
}
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
@@ -79,6 +108,33 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
protected void resetLibraryPath() {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String path = System.getProperty("java.library.path");
|
||||
if (Strings.isNullOrEmpty(path)) {
|
||||
path = "";
|
||||
} else {
|
||||
path += ":";
|
||||
}
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
System.setProperty("java.library.path", path);
|
||||
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
|
||||
final Field sysPathsField;
|
||||
try {
|
||||
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||
sysPathsField.setAccessible(true);
|
||||
sysPathsField.set(null, null);
|
||||
} catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
LOGGER.error("Failed to set library path.", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start runtime.
|
||||
*/
|
||||
@@ -330,8 +386,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
* Create the task specification.
|
||||
*
|
||||
* @param func The target remote function.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
|
||||
* Python task.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python
|
||||
* task.
|
||||
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
|
||||
* @param args The arguments for the remote function.
|
||||
* @param isActorCreationTask Whether this task is an actor creation task.
|
||||
|
||||
@@ -3,7 +3,7 @@ package org.ray.runtime;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.objectstore.MockObjectStore;
|
||||
import org.ray.runtime.objectstore.MockObjectInterface;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.MockRayletClient;
|
||||
|
||||
@@ -13,19 +13,22 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
private MockObjectStore store;
|
||||
private MockObjectInterface objectInterface;
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
store = new MockObjectStore(this);
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath();
|
||||
|
||||
objectInterface = new MockObjectInterface(workerContext);
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.getJobId(), rayConfig.runMode);
|
||||
objectStoreProxy = new ObjectStoreProxy(this, null);
|
||||
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface);
|
||||
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
|
||||
}
|
||||
|
||||
@@ -34,8 +37,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
rayletClient.destroy();
|
||||
}
|
||||
|
||||
public MockObjectStore getObjectStore() {
|
||||
return store;
|
||||
public MockObjectInterface getObjectInterface() {
|
||||
return objectInterface;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,21 +1,13 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.objectstore.ObjectInterfaceImpl;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
@@ -31,58 +23,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private RunManager manager = null;
|
||||
|
||||
static {
|
||||
try {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
// Load native libraries.
|
||||
String[] libraries = new String[]{"raylet_library_java", "plasma_java"};
|
||||
for (String library : libraries) {
|
||||
String fileName = System.mapLibraryName(library);
|
||||
// Copy the file from resources to a temp dir, and load the native library.
|
||||
File file = File.createTempFile(fileName, "");
|
||||
file.deleteOnExit();
|
||||
InputStream in = RayNativeRuntime.class.getResourceAsStream("/" + fileName);
|
||||
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
|
||||
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
|
||||
System.load(file.getAbsolutePath());
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Couldn't load native libraries.", e);
|
||||
}
|
||||
}
|
||||
private ObjectInterfaceImpl objectInterfaceImpl = null;
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
private void resetLibraryPath() {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String path = System.getProperty("java.library.path");
|
||||
if (Strings.isNullOrEmpty(path)) {
|
||||
path = "";
|
||||
} else {
|
||||
path += ":";
|
||||
}
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
System.setProperty("java.library.path", path);
|
||||
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
|
||||
final Field sysPathsField;
|
||||
try {
|
||||
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||
sysPathsField.setAccessible(true);
|
||||
sysPathsField.set(null, null);
|
||||
} catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
LOGGER.error("Failed to set library path.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
// Reset library path at runtime.
|
||||
@@ -101,16 +47,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.getJobId(), rayConfig.runMode);
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
|
||||
|
||||
rayletClient = new RayletClientImpl(
|
||||
rayConfig.rayletSocketName,
|
||||
workerContext.getCurrentWorkerId(),
|
||||
rayConfig.workerMode == WorkerMode.WORKER,
|
||||
rayConfig.workerMode == WorkerType.WORKER,
|
||||
workerContext.getCurrentJobId()
|
||||
);
|
||||
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient,
|
||||
rayConfig.objectStoreSocketName);
|
||||
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
|
||||
@@ -123,6 +71,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
}
|
||||
objectInterfaceImpl.destroy();
|
||||
workerContext.destroy();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -132,7 +82,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
Map<String, String> workerInfo = new HashMap<>();
|
||||
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
|
||||
if (rayConfig.workerMode == WorkerMode.DRIVER) {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
workerInfo.put("node_ip_address", rayConfig.nodeIp);
|
||||
workerInfo.put("driver_id", workerId);
|
||||
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
|
||||
|
||||
@@ -1,37 +1,24 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a wrapper class for worker context of core worker.
|
||||
*/
|
||||
public class WorkerContext {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
|
||||
|
||||
private UniqueId workerId;
|
||||
|
||||
private ThreadLocal<TaskId> currentTaskId;
|
||||
|
||||
/**
|
||||
* Number of objects that have been put from current task.
|
||||
* The native pointer of worker context of core worker.
|
||||
*/
|
||||
private ThreadLocal<Integer> putIndex;
|
||||
|
||||
/**
|
||||
* Number of tasks that have been submitted from current task.
|
||||
*/
|
||||
private ThreadLocal<Integer> taskIndex;
|
||||
|
||||
private ThreadLocal<TaskSpec> currentTask;
|
||||
|
||||
private JobId currentJobId;
|
||||
private final long nativeWorkerContextPointer;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
@@ -45,31 +32,23 @@ public class WorkerContext {
|
||||
*/
|
||||
private RunMode runMode;
|
||||
|
||||
public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) {
|
||||
public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) {
|
||||
this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes());
|
||||
mainThreadId = Thread.currentThread().getId();
|
||||
taskIndex = ThreadLocal.withInitial(() -> 0);
|
||||
putIndex = ThreadLocal.withInitial(() -> 0);
|
||||
currentTaskId = ThreadLocal.withInitial(TaskId::randomId);
|
||||
this.runMode = runMode;
|
||||
currentTask = ThreadLocal.withInitial(() -> null);
|
||||
currentClassLoader = null;
|
||||
if (workerMode == WorkerMode.DRIVER) {
|
||||
workerId = IdUtil.computeDriverId(jobId);
|
||||
currentTaskId.set(TaskId.randomId());
|
||||
currentJobId = jobId;
|
||||
} else {
|
||||
workerId = UniqueId.randomId();
|
||||
this.currentTaskId.set(TaskId.NIL);
|
||||
this.currentJobId = JobId.NIL;
|
||||
}
|
||||
}
|
||||
|
||||
public long getNativeWorkerContext() {
|
||||
return nativeWorkerContextPointer;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return For the main thread, this method returns the ID of this worker's current running task;
|
||||
* for other threads, this method returns a random ID.
|
||||
* for other threads, this method returns a random ID.
|
||||
*/
|
||||
public TaskId getCurrentTaskId() {
|
||||
return currentTaskId.get();
|
||||
return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -79,17 +58,14 @@ public class WorkerContext {
|
||||
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
|
||||
if (runMode == RunMode.CLUSTER) {
|
||||
Preconditions.checkState(
|
||||
Thread.currentThread().getId() == mainThreadId,
|
||||
"This method should only be called from the main thread."
|
||||
Thread.currentThread().getId() == mainThreadId,
|
||||
"This method should only be called from the main thread."
|
||||
);
|
||||
}
|
||||
|
||||
Preconditions.checkNotNull(task);
|
||||
this.currentTaskId.set(task.taskId);
|
||||
this.currentJobId = task.jobId;
|
||||
taskIndex.set(0);
|
||||
putIndex.set(0);
|
||||
this.currentTask.set(task);
|
||||
byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task);
|
||||
nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec);
|
||||
currentClassLoader = classLoader;
|
||||
}
|
||||
|
||||
@@ -97,30 +73,28 @@ public class WorkerContext {
|
||||
* Increment the put index and return the new value.
|
||||
*/
|
||||
public int nextPutIndex() {
|
||||
putIndex.set(putIndex.get() + 1);
|
||||
return putIndex.get();
|
||||
return nativeGetNextPutIndex(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the task index and return the new value.
|
||||
*/
|
||||
public int nextTaskIndex() {
|
||||
taskIndex.set(taskIndex.get() + 1);
|
||||
return taskIndex.get();
|
||||
return nativeGetNextTaskIndex(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The ID of the current worker.
|
||||
*/
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return workerId;
|
||||
return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
* The ID of the current job.
|
||||
*/
|
||||
public JobId getCurrentJobId() {
|
||||
return currentJobId;
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -134,6 +108,32 @@ public class WorkerContext {
|
||||
* Get the current task.
|
||||
*/
|
||||
public TaskSpec getCurrentTask() {
|
||||
return this.currentTask.get();
|
||||
byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer);
|
||||
if (bytes == null) {
|
||||
return null;
|
||||
}
|
||||
return RayletClientImpl.parseTaskSpecFromProtobuf(bytes);
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
private static native long nativeCreateWorkerContext(int workerType, byte[] jobId);
|
||||
|
||||
private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec);
|
||||
|
||||
private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer);
|
||||
|
||||
private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer);
|
||||
|
||||
private static native void nativeDestroy(long nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.util.NetworkUtil;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.util.StringUtil;
|
||||
@@ -29,7 +30,7 @@ public class RayConfig {
|
||||
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
|
||||
|
||||
public final String nodeIp;
|
||||
public final WorkerMode workerMode;
|
||||
public final WorkerType workerMode;
|
||||
public final RunMode runMode;
|
||||
public final Map<String, Double> resources;
|
||||
private JobId jobId;
|
||||
@@ -62,7 +63,7 @@ public class RayConfig {
|
||||
public final int numberExecThreadsForDevRuntime;
|
||||
|
||||
private void validate() {
|
||||
if (workerMode == WorkerMode.WORKER) {
|
||||
if (workerMode == WorkerType.WORKER) {
|
||||
Preconditions.checkArgument(redisAddress != null,
|
||||
"Redis address must be set in worker mode.");
|
||||
}
|
||||
@@ -78,14 +79,14 @@ public class RayConfig {
|
||||
|
||||
public RayConfig(Config config) {
|
||||
// Worker mode.
|
||||
WorkerMode localWorkerMode;
|
||||
WorkerType localWorkerMode;
|
||||
try {
|
||||
localWorkerMode = config.getEnum(WorkerMode.class, "ray.worker.mode");
|
||||
localWorkerMode = config.getEnum(WorkerType.class, "ray.worker.mode");
|
||||
} catch (ConfigException.Missing e) {
|
||||
localWorkerMode = WorkerMode.DRIVER;
|
||||
localWorkerMode = WorkerType.DRIVER;
|
||||
}
|
||||
workerMode = localWorkerMode;
|
||||
boolean isDriver = workerMode == WorkerMode.DRIVER;
|
||||
boolean isDriver = workerMode == WorkerType.DRIVER;
|
||||
// Run mode.
|
||||
runMode = config.getEnum(RunMode.class, "ray.run-mode");
|
||||
// Node ip.
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
package org.ray.runtime.config;
|
||||
|
||||
public enum WorkerMode {
|
||||
DRIVER,
|
||||
WORKER
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class MockObjectInterface implements ObjectInterface {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class);
|
||||
|
||||
private static final int GET_CHECK_INTERVAL_MS = 100;
|
||||
|
||||
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
|
||||
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
public MockObjectInterface(WorkerContext workerContext) {
|
||||
this.workerContext = workerContext;
|
||||
}
|
||||
|
||||
public void addObjectPutCallback(Consumer<ObjectId> callback) {
|
||||
this.objectPutCallbacks.add(callback);
|
||||
}
|
||||
|
||||
public boolean isObjectReady(ObjectId id) {
|
||||
return pool.containsKey(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId put(NativeRayObject obj) {
|
||||
ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(),
|
||||
workerContext.nextPutIndex());
|
||||
put(obj, objectId);
|
||||
return objectId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(NativeRayObject obj, ObjectId objectId) {
|
||||
Preconditions.checkNotNull(obj);
|
||||
Preconditions.checkNotNull(objectId);
|
||||
pool.putIfAbsent(objectId, obj);
|
||||
for (Consumer<ObjectId> callback : objectPutCallbacks) {
|
||||
callback.accept(objectId);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
|
||||
waitInternal(objectIds, objectIds.size(), timeoutMs);
|
||||
return objectIds.stream().map(pool::get).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
waitInternal(objectIds, numObjects, timeoutMs);
|
||||
return objectIds.stream().map(pool::containsKey).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private void waitInternal(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
int ready = 0;
|
||||
long remainingTime = timeoutMs;
|
||||
boolean firstCheck = true;
|
||||
while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) {
|
||||
if (!firstCheck) {
|
||||
long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
|
||||
try {
|
||||
Thread.sleep(sleepTime);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.warn("Got InterruptedException while sleeping.");
|
||||
}
|
||||
remainingTime -= sleepTime;
|
||||
}
|
||||
ready = 0;
|
||||
for (ObjectId objectId : objectIds) {
|
||||
if (pool.containsKey(objectId)) {
|
||||
ready += 1;
|
||||
}
|
||||
}
|
||||
firstCheck = false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
for (ObjectId objectId : objectIds) {
|
||||
pool.remove(objectId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.arrow.plasma.ObjectStoreLink;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data.
|
||||
*/
|
||||
public class MockObjectStore implements ObjectStoreLink {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class);
|
||||
|
||||
private static final int GET_CHECK_INTERVAL_MS = 100;
|
||||
|
||||
private final RayDevRuntime runtime;
|
||||
private final Map<ObjectId, byte[]> data = new ConcurrentHashMap<>();
|
||||
private final Map<ObjectId, byte[]> metadata = new ConcurrentHashMap<>();
|
||||
private final List<Consumer<ObjectId>> objectPutCallbacks;
|
||||
|
||||
public MockObjectStore(RayDevRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
this.objectPutCallbacks = new ArrayList<>();
|
||||
}
|
||||
|
||||
public void addObjectPutCallback(Consumer<ObjectId> callback) {
|
||||
this.objectPutCallbacks.add(callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(byte[] objectId, byte[] value, byte[] metadataValue) {
|
||||
if (objectId == null || objectId.length == 0 || value == null) {
|
||||
LOGGER
|
||||
.error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value));
|
||||
System.exit(-1);
|
||||
}
|
||||
ObjectId id = new ObjectId(objectId);
|
||||
data.put(id, value);
|
||||
if (metadataValue != null) {
|
||||
metadata.put(id, metadataValue);
|
||||
}
|
||||
for (Consumer<ObjectId> callback : objectPutCallbacks) {
|
||||
callback.accept(id);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) {
|
||||
return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<byte[]> get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
|
||||
return get(objectIds, timeoutMs)
|
||||
.stream()
|
||||
.map(data -> isMetadata ? data.metadata : data.data)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectStoreData> get(byte[][] objectIds, int timeoutMs) {
|
||||
int ready = 0;
|
||||
int remainingTime = timeoutMs;
|
||||
boolean firstCheck = true;
|
||||
while (ready < objectIds.length && remainingTime > 0) {
|
||||
if (!firstCheck) {
|
||||
int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
|
||||
try {
|
||||
Thread.sleep(sleepTime);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.warn("Got InterruptedException while sleeping.");
|
||||
}
|
||||
remainingTime -= sleepTime;
|
||||
}
|
||||
ready = 0;
|
||||
for (byte[] id : objectIds) {
|
||||
if (data.containsKey(new ObjectId(id))) {
|
||||
ready += 1;
|
||||
}
|
||||
}
|
||||
firstCheck = false;
|
||||
}
|
||||
ArrayList<ObjectStoreData> rets = new ArrayList<>();
|
||||
for (byte[] objId : objectIds) {
|
||||
ObjectId objectId = new ObjectId(objId);
|
||||
rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId)));
|
||||
}
|
||||
return rets;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] hash(byte[] objectId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long evict(long numBytes) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release(byte[] objectId) {
|
||||
return;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(byte[] objectId) {
|
||||
return;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean contains(byte[] objectId) {
|
||||
return data.containsKey(new ObjectId(objectId));
|
||||
}
|
||||
|
||||
private String logPrefix() {
|
||||
return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
|
||||
}
|
||||
|
||||
private String getUserTrace() {
|
||||
StackTraceElement[] stes = Thread.currentThread().getStackTrace();
|
||||
int k = 1;
|
||||
while (stes[k].getClassName().startsWith("org.ray")
|
||||
&& !stes[k].getClassName().contains("test")) {
|
||||
k++;
|
||||
}
|
||||
return stes[k].getFileName() + ":" + stes[k].getLineNumber();
|
||||
}
|
||||
|
||||
public boolean isObjectReady(ObjectId id) {
|
||||
return data.containsKey(id);
|
||||
}
|
||||
|
||||
public void free(ObjectId id) {
|
||||
data.remove(id);
|
||||
metadata.remove(id);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
public class NativeRayObject {
|
||||
|
||||
public byte[] data;
|
||||
public byte[] metadata;
|
||||
|
||||
public NativeRayObject(byte[] data, byte[] metadata) {
|
||||
this.data = data;
|
||||
this.metadata = metadata;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
||||
/**
|
||||
* The interface that contains all worker methods that are related to object store.
|
||||
*/
|
||||
public interface ObjectInterface {
|
||||
|
||||
/**
|
||||
* Put an object into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @return Generated ID of the object.
|
||||
*/
|
||||
ObjectId put(NativeRayObject obj);
|
||||
|
||||
/**
|
||||
* Put an object with specified ID into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @param objectId Object ID specified by user.
|
||||
*/
|
||||
void put(NativeRayObject obj, ObjectId objectId);
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to get.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return Result list of objects data.
|
||||
*/
|
||||
List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Wait for a list of objects to appear in the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to wait for.
|
||||
* @param numObjects Number of objects that should appear.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return A bitset that indicates each object has appeared or not.
|
||||
*/
|
||||
List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Delete a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to delete.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
|
||||
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
|
||||
*/
|
||||
void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.BaseId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a wrapper class for core worker object interface.
|
||||
*/
|
||||
public class ObjectInterfaceImpl implements ObjectInterface {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
|
||||
/**
|
||||
* The native pointer of core worker object interface.
|
||||
*/
|
||||
private final long nativeObjectInterfacePointer;
|
||||
|
||||
public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient,
|
||||
String storeSocketName) {
|
||||
this.nativeObjectInterfacePointer =
|
||||
nativeCreateObjectInterface(workerContext.getNativeWorkerContext(),
|
||||
((RayletClientImpl) rayletClient).getClient(), storeSocketName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId put(NativeRayObject obj) {
|
||||
return new ObjectId(nativePut(nativeObjectInterfacePointer, obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(NativeRayObject obj, ObjectId objectId) {
|
||||
try {
|
||||
nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj);
|
||||
} catch (RayException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
|
||||
return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(nativeObjectInterfacePointer);
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native long nativeCreateObjectInterface(long nativeObjectInterface,
|
||||
long nativeRayletClient,
|
||||
String storeSocketName);
|
||||
|
||||
private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj);
|
||||
|
||||
private static native void nativePut(long nativeObjectInterface, byte[] objectId,
|
||||
NativeRayObject obj);
|
||||
|
||||
private static native List<NativeRayObject> nativeGet(long nativeObjectInterface,
|
||||
List<byte[]> ids,
|
||||
long timeoutMs);
|
||||
|
||||
private static native List<Boolean> nativeWait(long nativeObjectInterface, List<byte[]> objectIds,
|
||||
int numObjects, long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(long nativeObjectInterface, List<byte[]> objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks);
|
||||
|
||||
private static native void nativeDestroy(long nativeObjectInterface);
|
||||
}
|
||||
@@ -4,20 +4,14 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
|
||||
import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -36,21 +30,18 @@ public class ObjectStoreProxy {
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
private static ThreadLocal<ObjectStoreLink> objectStore;
|
||||
private final ObjectInterface objectInterface;
|
||||
|
||||
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
|
||||
this.runtime = runtime;
|
||||
objectStore = ThreadLocal.withInitial(() -> {
|
||||
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
|
||||
return new PlasmaClient(storeSocketName, "", 0);
|
||||
} else {
|
||||
return ((RayDevRuntime) runtime).getObjectStore();
|
||||
}
|
||||
});
|
||||
public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) {
|
||||
this.workerContext = workerContext;
|
||||
this.objectInterface = objectInterface;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -75,46 +66,44 @@ public class ObjectStoreProxy {
|
||||
* @return A list of GetResult objects.
|
||||
*/
|
||||
public <T> List<GetResult<T>> get(List<ObjectId> ids, int timeoutMs) {
|
||||
byte[][] binaryIds = IdUtil.getIdBytes(ids);
|
||||
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
|
||||
List<NativeRayObject> dataAndMetaList = objectInterface.get(ids, timeoutMs);
|
||||
|
||||
List<GetResult<T>> results = new ArrayList<>();
|
||||
for (int i = 0; i < dataAndMetaList.size(); i++) {
|
||||
byte[] meta = dataAndMetaList.get(i).metadata;
|
||||
byte[] data = dataAndMetaList.get(i).data;
|
||||
|
||||
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
|
||||
GetResult<T> result;
|
||||
if (meta != null) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
result = deserializeFromMeta(meta, data, ids.get(i));
|
||||
} else if (data != null) {
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
result = new GetResult<>(true, null, (RayException) object);
|
||||
if (dataAndMeta != null) {
|
||||
byte[] meta = dataAndMeta.metadata;
|
||||
byte[] data = dataAndMeta.data;
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
result = deserializeFromMeta(meta, data,
|
||||
workerContext.getCurrentClassLoader(), ids.get(i));
|
||||
} else {
|
||||
// Otherwise, the object is valid.
|
||||
result = new GetResult<>(true, (T) object, null);
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
result = new GetResult<>(true, null, (RayException) object);
|
||||
} else {
|
||||
// Otherwise, the object is valid.
|
||||
result = new GetResult<>(true, (T) object, null);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If both meta and data are null, the object doesn't exist in object store.
|
||||
result = new GetResult<>(false, null, null);
|
||||
}
|
||||
|
||||
if (meta != null || data != null) {
|
||||
// Release the object from object store..
|
||||
objectStore.get().release(binaryIds[i]);
|
||||
}
|
||||
|
||||
results.add(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) {
|
||||
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data,
|
||||
ClassLoader classLoader, ObjectId objectId) {
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return (GetResult<T>) new GetResult<>(true, data, null);
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
@@ -123,6 +112,8 @@ public class ObjectStoreProxy {
|
||||
return new GetResult<>(true, null, RayActorException.INSTANCE);
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new GetResult<>(true, null, new UnreconstructableException(objectId));
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return new GetResult<>(true, null, Serializer.decode(data, classLoader));
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
}
|
||||
@@ -134,16 +125,14 @@ public class ObjectStoreProxy {
|
||||
* @param object The object to put.
|
||||
*/
|
||||
public void put(ObjectId id, Object object) {
|
||||
try {
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
|
||||
} else {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
}
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id);
|
||||
} else {
|
||||
objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,11 +143,7 @@ public class ObjectStoreProxy {
|
||||
* @param serializedObject The serialized object to put.
|
||||
*/
|
||||
public void putSerialized(ObjectId id, byte[] serializedObject) {
|
||||
try {
|
||||
objectStore.get().put(id.getBytes(), serializedObject, null);
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
objectInterface.put(new NativeRayObject(serializedObject, null), id);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
@@ -23,7 +24,8 @@ import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.Worker;
|
||||
import org.ray.runtime.objectstore.MockObjectStore;
|
||||
import org.ray.runtime.objectstore.MockObjectInterface;
|
||||
import org.ray.runtime.objectstore.NativeRayObject;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.slf4j.Logger;
|
||||
@@ -37,7 +39,7 @@ public class MockRayletClient implements RayletClient {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
|
||||
|
||||
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
|
||||
private final MockObjectStore store;
|
||||
private final MockObjectInterface objectInterface;
|
||||
private final RayDevRuntime runtime;
|
||||
private final ExecutorService exec;
|
||||
private final Deque<Worker> idleWorkers;
|
||||
@@ -46,8 +48,8 @@ public class MockRayletClient implements RayletClient {
|
||||
|
||||
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
|
||||
this.runtime = runtime;
|
||||
this.store = runtime.getObjectStore();
|
||||
store.addObjectPutCallback(this::onObjectPut);
|
||||
this.objectInterface = runtime.getObjectInterface();
|
||||
objectInterface.addObjectPutCallback(this::onObjectPut);
|
||||
// The thread pool that executes tasks in parallel.
|
||||
exec = Executors.newFixedThreadPool(numberThreads);
|
||||
idleWorkers = new ConcurrentLinkedDeque<>();
|
||||
@@ -113,8 +115,8 @@ public class MockRayletClient implements RayletClient {
|
||||
// can be executed.
|
||||
if (task.isActorCreationTask() || task.isActorTask()) {
|
||||
ObjectId[] returnIds = task.returnIds;
|
||||
store.put(returnIds[returnIds.length - 1].getBytes(),
|
||||
new byte[]{}, new byte[]{});
|
||||
objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}),
|
||||
returnIds[returnIds.length - 1]);
|
||||
}
|
||||
} finally {
|
||||
returnWorker(worker);
|
||||
@@ -133,13 +135,13 @@ public class MockRayletClient implements RayletClient {
|
||||
// Check whether task arguments are ready.
|
||||
for (FunctionArg arg : spec.args) {
|
||||
if (arg.id != null) {
|
||||
if (!store.isObjectReady(arg.id)) {
|
||||
if (!objectInterface.isObjectReady(arg.id)) {
|
||||
unreadyObjects.add(arg.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (spec.isActorTask()) {
|
||||
if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) {
|
||||
if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) {
|
||||
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
|
||||
}
|
||||
}
|
||||
@@ -154,7 +156,7 @@ public class MockRayletClient implements RayletClient {
|
||||
|
||||
@Override
|
||||
public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
|
||||
TaskId currentTaskId) {
|
||||
TaskId currentTaskId) {
|
||||
|
||||
}
|
||||
|
||||
@@ -170,20 +172,17 @@ public class MockRayletClient implements RayletClient {
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
if (waitFor == null || waitFor.isEmpty()) {
|
||||
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
|
||||
}
|
||||
|
||||
byte[][] ids = new byte[waitFor.size()][];
|
||||
for (int i = 0; i < waitFor.size(); i++) {
|
||||
ids[i] = waitFor.get(i).getId().getBytes();
|
||||
}
|
||||
List<ObjectId> ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList());
|
||||
List<RayObject<T>> readyList = new ArrayList<>();
|
||||
List<RayObject<T>> unreadyList = new ArrayList<>();
|
||||
List<byte[]> result = store.get(ids, timeoutMs, false);
|
||||
List<Boolean> result = objectInterface.wait(ids, ids.size(), timeoutMs);
|
||||
for (int i = 0; i < waitFor.size(); i++) {
|
||||
if (result.get(i) != null) {
|
||||
if (result.get(i)) {
|
||||
readyList.add(waitFor.get(i));
|
||||
} else {
|
||||
unreadyList.add(waitFor.get(i));
|
||||
@@ -195,9 +194,7 @@ public class MockRayletClient implements RayletClient {
|
||||
@Override
|
||||
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks) {
|
||||
for (ObjectId id : objectIds) {
|
||||
store.free(id);
|
||||
}
|
||||
objectInterface.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -40,11 +38,15 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
|
||||
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
|
||||
boolean isWorker, JobId jobId) {
|
||||
boolean isWorker, JobId jobId) {
|
||||
client = nativeInit(schedulerSockName, clientId.getBytes(),
|
||||
isWorker, jobId.getBytes());
|
||||
}
|
||||
|
||||
public long getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
@@ -133,7 +135,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
/**
|
||||
* Parse `TaskSpec` protobuf bytes.
|
||||
*/
|
||||
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
|
||||
public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
|
||||
Common.TaskSpec taskSpec;
|
||||
try {
|
||||
taskSpec = Common.TaskSpec.parseFrom(bytes);
|
||||
@@ -214,7 +216,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
/**
|
||||
* Convert a `TaskSpec` to protobuf-serialized bytes.
|
||||
*/
|
||||
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
|
||||
public static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
|
||||
// Set common fields.
|
||||
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
|
||||
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
|
||||
|
||||
@@ -154,18 +154,6 @@ public class IdUtil {
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the driver id from the given job.
|
||||
*/
|
||||
public static UniqueId computeDriverId(JobId jobId) {
|
||||
byte[] bytes = new byte[UniqueId.LENGTH];
|
||||
System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size());
|
||||
Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte)0xFF);
|
||||
ByteBuffer wbb = ByteBuffer.wrap(bytes);
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return new UniqueId(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the murmur hash code of this ID.
|
||||
*/
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -23,6 +29,15 @@ public class FailureTest extends BaseTest {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static int slowFunc() {
|
||||
try {
|
||||
Thread.sleep(10000);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static class BadActor {
|
||||
|
||||
public BadActor(boolean failOnCreation) {
|
||||
@@ -106,5 +121,26 @@ public class FailureTest extends BaseTest {
|
||||
// RayActorException.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetThrowsQuicklyWhenFoundException() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
List<RayFunc0<Integer>> badFunctions = Arrays.asList(FailureTest::badFunc,
|
||||
FailureTest::badFunc2);
|
||||
for (RayFunc0<Integer> badFunc : badFunctions) {
|
||||
RayObject<Integer> obj1 = Ray.call(badFunc);
|
||||
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
|
||||
Instant start = Instant.now();
|
||||
try {
|
||||
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
|
||||
Assert.fail("Should throw RayException.");
|
||||
} catch (RayException e) {
|
||||
Instant end = Instant.now();
|
||||
long duration = Duration.between(start, end).toMillis();
|
||||
Assert.assertTrue(duration < 5000, "Should fail quickly. " +
|
||||
"Actual execution time: " + duration + " ms.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -15,15 +13,13 @@ public class PlasmaStoreTest extends BaseTest {
|
||||
@Test
|
||||
public void testPutWithDuplicateId() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
UniqueId objectId = UniqueId.randomId();
|
||||
ObjectId objectId = ObjectId.randomId();
|
||||
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
|
||||
PlasmaClient store = new PlasmaClient(runtime.getRayConfig().objectStoreSocketName, "", 0);
|
||||
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
|
||||
try {
|
||||
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
|
||||
Assert.fail("This line shouldn't be reached.");
|
||||
} catch (DuplicateObjectException e) {
|
||||
// Putting 2 objects with duplicate ID should throw DuplicateObjectException.
|
||||
}
|
||||
ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy();
|
||||
objectInterface.put(objectId, 1);
|
||||
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
|
||||
objectInterface.put(objectId, 2);
|
||||
// Putting 2 objects with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -12,7 +12,7 @@ public class RayConfigTest {
|
||||
try {
|
||||
System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path");
|
||||
RayConfig rayConfig = RayConfig.create();
|
||||
Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
|
||||
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
|
||||
Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath);
|
||||
} finally {
|
||||
// Unset system properties.
|
||||
|
||||
Reference in New Issue
Block a user