[Java worker] Refactor object store and worker context on top of core worker (#5079)

This commit is contained in:
Kai Yang
2019-07-16 20:58:02 +08:00
committed by Hao Chen
parent e5be5fd46d
commit 806524384b
40 changed files with 1386 additions and 571 deletions
+1 -8
View File
@@ -69,7 +69,6 @@ define_java_module(
],
deps = [
":org_ray_ray_api",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -97,7 +96,6 @@ define_java_module(
deps = [
":org_ray_ray_api",
":org_ray_ray_runtime",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_sun_xml_bind_jaxb_core",
"@maven//:com_sun_xml_bind_jaxb_impl",
@@ -176,9 +174,8 @@ filegroup(
"//:redis-server",
"//:libray_redis_module.so",
"//:raylet",
"//:raylet_library_java",
"//:core_worker_library_java",
"@plasma//:plasma_store_server",
"@plasma//:plasma_client_java",
],
)
@@ -189,7 +186,6 @@ genrule(
":all_java_proto",
":java_native_deps",
":copy_pom_file",
"@plasma//:org_apache_arrow_arrow_plasma",
],
outs = ["gen_maven_deps.out"],
cmd = """
@@ -208,9 +204,6 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
# Install plasma jar to local maven repo.
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
echo $$(date) > $@
""",
local = 1,
-5
View File
@@ -24,11 +24,6 @@
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
<version>0.13.0-SNAPSHOT</version>
</dependency>
</dependencies>
</dependencyManagement>
-4
View File
@@ -22,10 +22,6 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
-4
View File
@@ -22,10 +22,6 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
{generated_bzl_deps}
</dependencies>
@@ -1,7 +1,15 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -72,6 +80,27 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected RuntimeContext runtimeContext;
protected GcsClient gcsClient;
static {
try {
LOGGER.debug("Loading native libraries.");
// Load native libraries.
String[] libraries = new String[]{"core_worker_library_java"};
for (String library : libraries) {
String fileName = System.mapLibraryName(library);
// Copy the file from resources to a temp dir, and load the native library.
File file = File.createTempFile(fileName, "");
file.deleteOnExit();
InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
System.load(file.getAbsolutePath());
}
LOGGER.debug("Native libraries loaded.");
} catch (IOException e) {
throw new RuntimeException("Couldn't load native libraries.", e);
}
}
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.jobResourcePath);
@@ -79,6 +108,33 @@ public abstract class AbstractRayRuntime implements RayRuntime {
runtimeContext = new RuntimeContextImpl(this);
}
protected void resetLibraryPath() {
if (rayConfig.libraryPath.isEmpty()) {
return;
}
String path = System.getProperty("java.library.path");
if (Strings.isNullOrEmpty(path)) {
path = "";
} else {
path += ":";
}
path += String.join(":", rayConfig.libraryPath);
// This is a hack to reset library path at runtime,
// see https://stackoverflow.com/questions/15409223/.
System.setProperty("java.library.path", path);
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
final Field sysPathsField;
try {
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
sysPathsField.setAccessible(true);
sysPathsField.set(null, null);
} catch (NoSuchFieldException | IllegalAccessException e) {
LOGGER.error("Failed to set library path.", e);
}
}
/**
* Start runtime.
*/
@@ -330,8 +386,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
* Create the task specification.
*
* @param func The target remote function.
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
* Python task.
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python
* task.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.
@@ -3,7 +3,7 @@ package org.ray.runtime;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.objectstore.MockObjectInterface;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.MockRayletClient;
@@ -13,19 +13,22 @@ public class RayDevRuntime extends AbstractRayRuntime {
super(rayConfig);
}
private MockObjectStore store;
private MockObjectInterface objectInterface;
private AtomicInteger jobCounter = new AtomicInteger(0);
@Override
public void start() {
store = new MockObjectStore(this);
// Reset library path at runtime.
resetLibraryPath();
objectInterface = new MockObjectInterface(workerContext);
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
objectStoreProxy = new ObjectStoreProxy(this, null);
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface);
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
}
@@ -34,8 +37,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
rayletClient.destroy();
}
public MockObjectStore getObjectStore() {
return store;
public MockObjectInterface getObjectInterface() {
return objectInterface;
}
@Override
@@ -1,21 +1,13 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.objectstore.ObjectInterfaceImpl;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
@@ -31,58 +23,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private RunManager manager = null;
static {
try {
LOGGER.debug("Loading native libraries.");
// Load native libraries.
String[] libraries = new String[]{"raylet_library_java", "plasma_java"};
for (String library : libraries) {
String fileName = System.mapLibraryName(library);
// Copy the file from resources to a temp dir, and load the native library.
File file = File.createTempFile(fileName, "");
file.deleteOnExit();
InputStream in = RayNativeRuntime.class.getResourceAsStream("/" + fileName);
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
System.load(file.getAbsolutePath());
}
LOGGER.debug("Native libraries loaded.");
} catch (IOException e) {
throw new RuntimeException("Couldn't load native libraries.", e);
}
}
private ObjectInterfaceImpl objectInterfaceImpl = null;
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
private void resetLibraryPath() {
if (rayConfig.libraryPath.isEmpty()) {
return;
}
String path = System.getProperty("java.library.path");
if (Strings.isNullOrEmpty(path)) {
path = "";
} else {
path += ":";
}
path += String.join(":", rayConfig.libraryPath);
// This is a hack to reset library path at runtime,
// see https://stackoverflow.com/questions/15409223/.
System.setProperty("java.library.path", path);
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
final Field sysPathsField;
try {
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
sysPathsField.setAccessible(true);
sysPathsField.set(null, null);
} catch (NoSuchFieldException | IllegalAccessException e) {
LOGGER.error("Failed to set library path.", e);
}
}
@Override
public void start() {
// Reset library path at runtime.
@@ -101,16 +47,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
rayletClient = new RayletClientImpl(
rayConfig.rayletSocketName,
workerContext.getCurrentWorkerId(),
rayConfig.workerMode == WorkerMode.WORKER,
rayConfig.workerMode == WorkerType.WORKER,
workerContext.getCurrentJobId()
);
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient,
rayConfig.objectStoreSocketName);
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl);
// register
registerWorker();
@@ -123,6 +71,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (null != manager) {
manager.cleanup();
}
objectInterfaceImpl.destroy();
workerContext.destroy();
}
/**
@@ -132,7 +82,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
if (rayConfig.workerMode == WorkerMode.DRIVER) {
if (rayConfig.workerMode == WorkerType.DRIVER) {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
@@ -1,37 +1,24 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This is a wrapper class for worker context of core worker.
*/
public class WorkerContext {
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
private UniqueId workerId;
private ThreadLocal<TaskId> currentTaskId;
/**
* Number of objects that have been put from current task.
* The native pointer of worker context of core worker.
*/
private ThreadLocal<Integer> putIndex;
/**
* Number of tasks that have been submitted from current task.
*/
private ThreadLocal<Integer> taskIndex;
private ThreadLocal<TaskSpec> currentTask;
private JobId currentJobId;
private final long nativeWorkerContextPointer;
private ClassLoader currentClassLoader;
@@ -45,31 +32,23 @@ public class WorkerContext {
*/
private RunMode runMode;
public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) {
public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) {
this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes());
mainThreadId = Thread.currentThread().getId();
taskIndex = ThreadLocal.withInitial(() -> 0);
putIndex = ThreadLocal.withInitial(() -> 0);
currentTaskId = ThreadLocal.withInitial(TaskId::randomId);
this.runMode = runMode;
currentTask = ThreadLocal.withInitial(() -> null);
currentClassLoader = null;
if (workerMode == WorkerMode.DRIVER) {
workerId = IdUtil.computeDriverId(jobId);
currentTaskId.set(TaskId.randomId());
currentJobId = jobId;
} else {
workerId = UniqueId.randomId();
this.currentTaskId.set(TaskId.NIL);
this.currentJobId = JobId.NIL;
}
}
public long getNativeWorkerContext() {
return nativeWorkerContextPointer;
}
/**
* @return For the main thread, this method returns the ID of this worker's current running task;
* for other threads, this method returns a random ID.
* for other threads, this method returns a random ID.
*/
public TaskId getCurrentTaskId() {
return currentTaskId.get();
return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer));
}
/**
@@ -79,17 +58,14 @@ public class WorkerContext {
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
if (runMode == RunMode.CLUSTER) {
Preconditions.checkState(
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
);
}
Preconditions.checkNotNull(task);
this.currentTaskId.set(task.taskId);
this.currentJobId = task.jobId;
taskIndex.set(0);
putIndex.set(0);
this.currentTask.set(task);
byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task);
nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec);
currentClassLoader = classLoader;
}
@@ -97,30 +73,28 @@ public class WorkerContext {
* Increment the put index and return the new value.
*/
public int nextPutIndex() {
putIndex.set(putIndex.get() + 1);
return putIndex.get();
return nativeGetNextPutIndex(nativeWorkerContextPointer);
}
/**
* Increment the task index and return the new value.
*/
public int nextTaskIndex() {
taskIndex.set(taskIndex.get() + 1);
return taskIndex.get();
return nativeGetNextTaskIndex(nativeWorkerContextPointer);
}
/**
* @return The ID of the current worker.
*/
public UniqueId getCurrentWorkerId() {
return workerId;
return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer));
}
/**
* The ID of the current job.
*/
public JobId getCurrentJobId() {
return currentJobId;
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer));
}
/**
@@ -134,6 +108,32 @@ public class WorkerContext {
* Get the current task.
*/
public TaskSpec getCurrentTask() {
return this.currentTask.get();
byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer);
if (bytes == null) {
return null;
}
return RayletClientImpl.parseTaskSpecFromProtobuf(bytes);
}
public void destroy() {
nativeDestroy(nativeWorkerContextPointer);
}
private static native long nativeCreateWorkerContext(int workerType, byte[] jobId);
private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer);
private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec);
private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer);
private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer);
private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer);
private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer);
private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer);
private static native void nativeDestroy(long nativeWorkerContextPointer);
}
@@ -11,6 +11,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ray.api.id.JobId;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.util.NetworkUtil;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.StringUtil;
@@ -29,7 +30,7 @@ public class RayConfig {
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
public final String nodeIp;
public final WorkerMode workerMode;
public final WorkerType workerMode;
public final RunMode runMode;
public final Map<String, Double> resources;
private JobId jobId;
@@ -62,7 +63,7 @@ public class RayConfig {
public final int numberExecThreadsForDevRuntime;
private void validate() {
if (workerMode == WorkerMode.WORKER) {
if (workerMode == WorkerType.WORKER) {
Preconditions.checkArgument(redisAddress != null,
"Redis address must be set in worker mode.");
}
@@ -78,14 +79,14 @@ public class RayConfig {
public RayConfig(Config config) {
// Worker mode.
WorkerMode localWorkerMode;
WorkerType localWorkerMode;
try {
localWorkerMode = config.getEnum(WorkerMode.class, "ray.worker.mode");
localWorkerMode = config.getEnum(WorkerType.class, "ray.worker.mode");
} catch (ConfigException.Missing e) {
localWorkerMode = WorkerMode.DRIVER;
localWorkerMode = WorkerType.DRIVER;
}
workerMode = localWorkerMode;
boolean isDriver = workerMode == WorkerMode.DRIVER;
boolean isDriver = workerMode == WorkerType.DRIVER;
// Run mode.
runMode = config.getEnum(RunMode.class, "ray.run-mode");
// Node ip.
@@ -1,6 +0,0 @@
package org.ray.runtime.config;
public enum WorkerMode {
DRIVER,
WORKER
}
@@ -0,0 +1,98 @@
package org.ray.runtime.objectstore;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.ray.api.id.ObjectId;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MockObjectInterface implements ObjectInterface {
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class);
private static final int GET_CHECK_INTERVAL_MS = 100;
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
private final WorkerContext workerContext;
public MockObjectInterface(WorkerContext workerContext) {
this.workerContext = workerContext;
}
public void addObjectPutCallback(Consumer<ObjectId> callback) {
this.objectPutCallbacks.add(callback);
}
public boolean isObjectReady(ObjectId id) {
return pool.containsKey(id);
}
@Override
public ObjectId put(NativeRayObject obj) {
ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(),
workerContext.nextPutIndex());
put(obj, objectId);
return objectId;
}
@Override
public void put(NativeRayObject obj, ObjectId objectId) {
Preconditions.checkNotNull(obj);
Preconditions.checkNotNull(objectId);
pool.putIfAbsent(objectId, obj);
for (Consumer<ObjectId> callback : objectPutCallbacks) {
callback.accept(objectId);
}
}
@Override
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
waitInternal(objectIds, objectIds.size(), timeoutMs);
return objectIds.stream().map(pool::get).collect(Collectors.toList());
}
@Override
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
waitInternal(objectIds, numObjects, timeoutMs);
return objectIds.stream().map(pool::containsKey).collect(Collectors.toList());
}
private void waitInternal(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
int ready = 0;
long remainingTime = timeoutMs;
boolean firstCheck = true;
while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) {
if (!firstCheck) {
long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
try {
Thread.sleep(sleepTime);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while sleeping.");
}
remainingTime -= sleepTime;
}
ready = 0;
for (ObjectId objectId : objectIds) {
if (pool.containsKey(objectId)) {
ready += 1;
}
}
firstCheck = false;
}
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
for (ObjectId objectId : objectIds) {
pool.remove(objectId);
}
}
}
@@ -1,148 +0,0 @@
package org.ray.runtime.objectstore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.ray.api.id.ObjectId;
import org.ray.runtime.RayDevRuntime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data.
*/
public class MockObjectStore implements ObjectStoreLink {
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class);
private static final int GET_CHECK_INTERVAL_MS = 100;
private final RayDevRuntime runtime;
private final Map<ObjectId, byte[]> data = new ConcurrentHashMap<>();
private final Map<ObjectId, byte[]> metadata = new ConcurrentHashMap<>();
private final List<Consumer<ObjectId>> objectPutCallbacks;
public MockObjectStore(RayDevRuntime runtime) {
this.runtime = runtime;
this.objectPutCallbacks = new ArrayList<>();
}
public void addObjectPutCallback(Consumer<ObjectId> callback) {
this.objectPutCallbacks.add(callback);
}
@Override
public void put(byte[] objectId, byte[] value, byte[] metadataValue) {
if (objectId == null || objectId.length == 0 || value == null) {
LOGGER
.error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value));
System.exit(-1);
}
ObjectId id = new ObjectId(objectId);
data.put(id, value);
if (metadataValue != null) {
metadata.put(id, metadataValue);
}
for (Consumer<ObjectId> callback : objectPutCallbacks) {
callback.accept(id);
}
}
@Override
public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) {
return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0);
}
@Override
public List<byte[]> get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
return get(objectIds, timeoutMs)
.stream()
.map(data -> isMetadata ? data.metadata : data.data)
.collect(Collectors.toList());
}
@Override
public List<ObjectStoreData> get(byte[][] objectIds, int timeoutMs) {
int ready = 0;
int remainingTime = timeoutMs;
boolean firstCheck = true;
while (ready < objectIds.length && remainingTime > 0) {
if (!firstCheck) {
int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
try {
Thread.sleep(sleepTime);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while sleeping.");
}
remainingTime -= sleepTime;
}
ready = 0;
for (byte[] id : objectIds) {
if (data.containsKey(new ObjectId(id))) {
ready += 1;
}
}
firstCheck = false;
}
ArrayList<ObjectStoreData> rets = new ArrayList<>();
for (byte[] objId : objectIds) {
ObjectId objectId = new ObjectId(objId);
rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId)));
}
return rets;
}
@Override
public byte[] hash(byte[] objectId) {
return null;
}
@Override
public long evict(long numBytes) {
return 0;
}
@Override
public void release(byte[] objectId) {
return;
}
@Override
public void delete(byte[] objectId) {
return;
}
@Override
public boolean contains(byte[] objectId) {
return data.containsKey(new ObjectId(objectId));
}
private String logPrefix() {
return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
}
private String getUserTrace() {
StackTraceElement[] stes = Thread.currentThread().getStackTrace();
int k = 1;
while (stes[k].getClassName().startsWith("org.ray")
&& !stes[k].getClassName().contains("test")) {
k++;
}
return stes[k].getFileName() + ":" + stes[k].getLineNumber();
}
public boolean isObjectReady(ObjectId id) {
return data.containsKey(id);
}
public void free(ObjectId id) {
data.remove(id);
metadata.remove(id);
}
}
@@ -0,0 +1,13 @@
package org.ray.runtime.objectstore;
public class NativeRayObject {
public byte[] data;
public byte[] metadata;
public NativeRayObject(byte[] data, byte[] metadata) {
this.data = data;
this.metadata = metadata;
}
}
@@ -0,0 +1,54 @@
package org.ray.runtime.objectstore;
import java.util.List;
import org.ray.api.id.ObjectId;
/**
* The interface that contains all worker methods that are related to object store.
*/
public interface ObjectInterface {
/**
* Put an object into object store.
*
* @param obj The ray object.
* @return Generated ID of the object.
*/
ObjectId put(NativeRayObject obj);
/**
* Put an object with specified ID into object store.
*
* @param obj The ray object.
* @param objectId Object ID specified by user.
*/
void put(NativeRayObject obj, ObjectId objectId);
/**
* Get a list of objects from the object store.
*
* @param objectIds IDs of the objects to get.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
* @return Result list of objects data.
*/
List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs);
/**
* Wait for a list of objects to appear in the object store.
*
* @param objectIds IDs of the objects to wait for.
* @param numObjects Number of objects that should appear.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
* @return A bitset that indicates each object has appeared or not.
*/
List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
/**
* Delete a list of objects from the object store.
*
* @param objectIds IDs of the objects to delete.
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
*/
void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
}
@@ -0,0 +1,91 @@
package org.ray.runtime.objectstore;
import java.util.List;
import java.util.stream.Collectors;
import org.ray.api.exception.RayException;
import org.ray.api.id.BaseId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.raylet.RayletClientImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This is a wrapper class for core worker object interface.
*/
public class ObjectInterfaceImpl implements ObjectInterface {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
/**
* The native pointer of core worker object interface.
*/
private final long nativeObjectInterfacePointer;
public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient,
String storeSocketName) {
this.nativeObjectInterfacePointer =
nativeCreateObjectInterface(workerContext.getNativeWorkerContext(),
((RayletClientImpl) rayletClient).getClient(), storeSocketName);
}
@Override
public ObjectId put(NativeRayObject obj) {
return new ObjectId(nativePut(nativeObjectInterfacePointer, obj));
}
@Override
public void put(NativeRayObject obj, ObjectId objectId) {
try {
nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj);
} catch (RayException e) {
LOGGER.warn(e.getMessage());
}
}
@Override
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs);
}
@Override
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs);
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
}
public void destroy() {
nativeDestroy(nativeObjectInterfacePointer);
}
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
}
private static native long nativeCreateObjectInterface(long nativeObjectInterface,
long nativeRayletClient,
String storeSocketName);
private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj);
private static native void nativePut(long nativeObjectInterface, byte[] objectId,
NativeRayObject obj);
private static native List<NativeRayObject> nativeGet(long nativeObjectInterface,
List<byte[]> ids,
long timeoutMs);
private static native List<Boolean> nativeWait(long nativeObjectInterface, List<byte[]> objectIds,
int numObjects, long timeoutMs);
private static native void nativeDelete(long nativeObjectInterface, List<byte[]> objectIds,
boolean localOnly, boolean deleteCreatingTasks);
private static native void nativeDestroy(long nativeObjectInterface);
}
@@ -4,20 +4,14 @@ import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.IdUtil;
import org.ray.runtime.util.Serializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,21 +30,18 @@ public class ObjectStoreProxy {
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
private final AbstractRayRuntime runtime;
private final WorkerContext workerContext;
private static ThreadLocal<ObjectStoreLink> objectStore;
private final ObjectInterface objectInterface;
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
this.runtime = runtime;
objectStore = ThreadLocal.withInitial(() -> {
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
return new PlasmaClient(storeSocketName, "", 0);
} else {
return ((RayDevRuntime) runtime).getObjectStore();
}
});
public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) {
this.workerContext = workerContext;
this.objectInterface = objectInterface;
}
/**
@@ -75,46 +66,44 @@ public class ObjectStoreProxy {
* @return A list of GetResult objects.
*/
public <T> List<GetResult<T>> get(List<ObjectId> ids, int timeoutMs) {
byte[][] binaryIds = IdUtil.getIdBytes(ids);
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
List<NativeRayObject> dataAndMetaList = objectInterface.get(ids, timeoutMs);
List<GetResult<T>> results = new ArrayList<>();
for (int i = 0; i < dataAndMetaList.size(); i++) {
byte[] meta = dataAndMetaList.get(i).metadata;
byte[] data = dataAndMetaList.get(i).data;
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
GetResult<T> result;
if (meta != null) {
// If meta is not null, deserialize the object from meta.
result = deserializeFromMeta(meta, data, ids.get(i));
} else if (data != null) {
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
// execution.
result = new GetResult<>(true, null, (RayException) object);
if (dataAndMeta != null) {
byte[] meta = dataAndMeta.metadata;
byte[] data = dataAndMeta.data;
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
result = deserializeFromMeta(meta, data,
workerContext.getCurrentClassLoader(), ids.get(i));
} else {
// Otherwise, the object is valid.
result = new GetResult<>(true, (T) object, null);
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, workerContext.getCurrentClassLoader());
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
// execution.
result = new GetResult<>(true, null, (RayException) object);
} else {
// Otherwise, the object is valid.
result = new GetResult<>(true, (T) object, null);
}
}
} else {
// If both meta and data are null, the object doesn't exist in object store.
result = new GetResult<>(false, null, null);
}
if (meta != null || data != null) {
// Release the object from object store..
objectStore.get().release(binaryIds[i]);
}
results.add(result);
}
return results;
}
@SuppressWarnings("unchecked")
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) {
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data,
ClassLoader classLoader, ObjectId objectId) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
return (GetResult<T>) new GetResult<>(true, data, null);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
@@ -123,6 +112,8 @@ public class ObjectStoreProxy {
return new GetResult<>(true, null, RayActorException.INSTANCE);
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new GetResult<>(true, null, new UnreconstructableException(objectId));
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
return new GetResult<>(true, null, Serializer.decode(data, classLoader));
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}
@@ -134,16 +125,14 @@ public class ObjectStoreProxy {
* @param object The object to put.
*/
public void put(ObjectId id, Object object) {
try {
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
} else {
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
}
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id);
} else if (object instanceof RayTaskException) {
objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id);
} else {
objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id);
}
}
@@ -154,11 +143,7 @@ public class ObjectStoreProxy {
* @param serializedObject The serialized object to put.
*/
public void putSerialized(ObjectId id, byte[] serializedObject) {
try {
objectStore.get().put(id.getBytes(), serializedObject, null);
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
objectInterface.put(new NativeRayObject(serializedObject, null), id);
}
/**
@@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
@@ -23,7 +24,8 @@ import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.Worker;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.objectstore.MockObjectInterface;
import org.ray.runtime.objectstore.NativeRayObject;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
@@ -37,7 +39,7 @@ public class MockRayletClient implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
private final MockObjectStore store;
private final MockObjectInterface objectInterface;
private final RayDevRuntime runtime;
private final ExecutorService exec;
private final Deque<Worker> idleWorkers;
@@ -46,8 +48,8 @@ public class MockRayletClient implements RayletClient {
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
this.runtime = runtime;
this.store = runtime.getObjectStore();
store.addObjectPutCallback(this::onObjectPut);
this.objectInterface = runtime.getObjectInterface();
objectInterface.addObjectPutCallback(this::onObjectPut);
// The thread pool that executes tasks in parallel.
exec = Executors.newFixedThreadPool(numberThreads);
idleWorkers = new ConcurrentLinkedDeque<>();
@@ -113,8 +115,8 @@ public class MockRayletClient implements RayletClient {
// can be executed.
if (task.isActorCreationTask() || task.isActorTask()) {
ObjectId[] returnIds = task.returnIds;
store.put(returnIds[returnIds.length - 1].getBytes(),
new byte[]{}, new byte[]{});
objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}),
returnIds[returnIds.length - 1]);
}
} finally {
returnWorker(worker);
@@ -133,13 +135,13 @@ public class MockRayletClient implements RayletClient {
// Check whether task arguments are ready.
for (FunctionArg arg : spec.args) {
if (arg.id != null) {
if (!store.isObjectReady(arg.id)) {
if (!objectInterface.isObjectReady(arg.id)) {
unreadyObjects.add(arg.id);
}
}
}
if (spec.isActorTask()) {
if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) {
if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) {
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
}
}
@@ -154,7 +156,7 @@ public class MockRayletClient implements RayletClient {
@Override
public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
TaskId currentTaskId) {
TaskId currentTaskId) {
}
@@ -170,20 +172,17 @@ public class MockRayletClient implements RayletClient {
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, TaskId currentTaskId) {
timeoutMs, TaskId currentTaskId) {
if (waitFor == null || waitFor.isEmpty()) {
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
}
byte[][] ids = new byte[waitFor.size()][];
for (int i = 0; i < waitFor.size(); i++) {
ids[i] = waitFor.get(i).getId().getBytes();
}
List<ObjectId> ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList());
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
List<byte[]> result = store.get(ids, timeoutMs, false);
List<Boolean> result = objectInterface.wait(ids, ids.size(), timeoutMs);
for (int i = 0; i < waitFor.size(); i++) {
if (result.get(i) != null) {
if (result.get(i)) {
readyList.add(waitFor.get(i));
} else {
unreadyList.add(waitFor.get(i));
@@ -195,9 +194,7 @@ public class MockRayletClient implements RayletClient {
@Override
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
for (ObjectId id : objectIds) {
store.free(id);
}
objectInterface.delete(objectIds, localOnly, deleteCreatingTasks);
}
@@ -4,8 +4,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
@@ -40,11 +38,15 @@ public class RayletClientImpl implements RayletClient {
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
boolean isWorker, JobId jobId) {
boolean isWorker, JobId jobId) {
client = nativeInit(schedulerSockName, clientId.getBytes(),
isWorker, jobId.getBytes());
}
public long getClient() {
return client;
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, TaskId currentTaskId) {
@@ -133,7 +135,7 @@ public class RayletClientImpl implements RayletClient {
/**
* Parse `TaskSpec` protobuf bytes.
*/
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
Common.TaskSpec taskSpec;
try {
taskSpec = Common.TaskSpec.parseFrom(bytes);
@@ -214,7 +216,7 @@ public class RayletClientImpl implements RayletClient {
/**
* Convert a `TaskSpec` to protobuf-serialized bytes.
*/
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
public static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
// Set common fields.
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
@@ -154,18 +154,6 @@ public class IdUtil {
}
/**
* Compute the driver id from the given job.
*/
public static UniqueId computeDriverId(JobId jobId) {
byte[] bytes = new byte[UniqueId.LENGTH];
System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size());
Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte)0xFF);
ByteBuffer wbb = ByteBuffer.wrap(bytes);
wbb.order(ByteOrder.LITTLE_ENDIAN);
return new UniqueId(bytes);
}
/**
* Compute the murmur hash code of this ID.
*/
@@ -1,12 +1,18 @@
package org.ray.api.test;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.function.RayFunc0;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -23,6 +29,15 @@ public class FailureTest extends BaseTest {
return 0;
}
public static int slowFunc() {
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return 0;
}
public static class BadActor {
public BadActor(boolean failOnCreation) {
@@ -106,5 +121,26 @@ public class FailureTest extends BaseTest {
// RayActorException.
}
}
@Test
public void testGetThrowsQuicklyWhenFoundException() {
TestUtils.skipTestUnderSingleProcess();
List<RayFunc0<Integer>> badFunctions = Arrays.asList(FailureTest::badFunc,
FailureTest::badFunc2);
for (RayFunc0<Integer> badFunc : badFunctions) {
RayObject<Integer> obj1 = Ray.call(badFunc);
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
Instant start = Instant.now();
try {
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
Assert.fail("Should throw RayException.");
} catch (RayException e) {
Instant end = Instant.now();
long duration = Duration.between(start, end).toMillis();
Assert.assertTrue(duration < 5000, "Should fail quickly. " +
"Actual execution time: " + duration + " ms.");
}
}
}
}
@@ -1,12 +1,10 @@
package org.ray.api.test;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.ray.api.Ray;
import org.ray.api.TestUtils;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -15,15 +13,13 @@ public class PlasmaStoreTest extends BaseTest {
@Test
public void testPutWithDuplicateId() {
TestUtils.skipTestUnderSingleProcess();
UniqueId objectId = UniqueId.randomId();
ObjectId objectId = ObjectId.randomId();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
PlasmaClient store = new PlasmaClient(runtime.getRayConfig().objectStoreSocketName, "", 0);
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
try {
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
Assert.fail("This line shouldn't be reached.");
} catch (DuplicateObjectException e) {
// Putting 2 objects with duplicate ID should throw DuplicateObjectException.
}
ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy();
objectInterface.put(objectId, 1);
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
objectInterface.put(objectId, 2);
// Putting 2 objects with duplicate ID should fail but ignored.
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
}
}
@@ -1,7 +1,7 @@
package org.ray.api.test;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.generated.Common.WorkerType;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -12,7 +12,7 @@ public class RayConfigTest {
try {
System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path");
RayConfig rayConfig = RayConfig.create();
Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath);
} finally {
// Unset system properties.