[Java] Support direct actor call in Java worker (#5504)

This commit is contained in:
Kai Yang
2019-09-09 14:29:20 +08:00
committed by Hao Chen
parent 74abeab057
commit ed761900f6
61 changed files with 608 additions and 728 deletions
@@ -10,26 +10,33 @@ public class ActorCreationOptions extends BaseTaskOptions {
public static final int NO_RECONSTRUCTION = 0;
public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30);
// DO NOT set this environment variable. It's only used for test purposes.
// Please use `setUseDirectCall` instead.
public static final boolean DEFAULT_USE_DIRECT_CALL = "1"
.equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL"));
public final int maxReconstructions;
public final boolean useDirectCall;
public final String jvmOptions;
private ActorCreationOptions(Map<String, Double> resources,
int maxReconstructions,
String jvmOptions) {
private ActorCreationOptions(Map<String, Double> resources, int maxReconstructions,
boolean useDirectCall, String jvmOptions) {
super(resources);
this.maxReconstructions = maxReconstructions;
this.useDirectCall = useDirectCall;
this.jvmOptions = jvmOptions;
}
/**
* The inner class for building ActorCreationOptions.
* The inner class for building ActorCreationOptions.
*/
public static class Builder {
private Map<String, Double> resources = new HashMap<>();
private int maxReconstructions = NO_RECONSTRUCTION;
private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL;
private String jvmOptions = null;
public Builder setResources(Map<String, Double> resources) {
@@ -42,13 +49,21 @@ public class ActorCreationOptions extends BaseTaskOptions {
return this;
}
// Since direct call is not fully supported yet (see issue #5559),
// users are not allowed to set the option to true.
// TODO (kfstorm): uncomment when direct call is ready.
// public Builder setUseDirectCall(boolean useDirectCall) {
// this.useDirectCall = useDirectCall;
// return this;
// }
public Builder setJvmOptions(String jvmOptions) {
this.jvmOptions = jvmOptions;
return this;
}
public ActorCreationOptions createActorCreationOptions() {
return new ActorCreationOptions(resources, maxReconstructions, jvmOptions);
return new ActorCreationOptions(resources, maxReconstructions, useDirectCall, jvmOptions);
}
}
+1 -1
View File
@@ -16,7 +16,7 @@ def gen_java_deps():
"org.apache.commons:commons-lang3:3.4",
"org.ow2.asm:asm:6.0",
"org.slf4j:slf4j-log4j12:1.7.25",
"org.testng:testng:6.9.9",
"org.testng:testng:6.9.10",
"redis.clients:jedis:2.8.0",
],
repositories = [
+1 -1
View File
@@ -75,7 +75,7 @@
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>6.9.9</version>
<version>6.9.10</version>
</dependency>
<dependency>
<groupId>redis.clients</groupId>
@@ -13,11 +13,11 @@ import org.ray.api.exception.RayException;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFuncVoid;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.actor.NativeRayActor;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.RuntimeContextImpl;
import org.ray.runtime.context.WorkerContext;
@@ -28,7 +28,6 @@ import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.generated.Common.Language;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.object.RayObjectImpl;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskExecutor;
@@ -51,7 +50,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected ObjectStore objectStore;
protected TaskSubmitter taskSubmitter;
protected RayletClient rayletClient;
protected WorkerContext workerContext;
public AbstractRayRuntime(RayConfig rayConfig) {
@@ -85,15 +83,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
objectStore.delete(objectIds, localOnly, deleteCreatingTasks);
}
/**
 * Sets the capacity of a resource by forwarding the request to the raylet client.
 *
 * @param resourceName name of the resource to set.
 * @param capacity new capacity; must be a non-negative number (NaN is rejected).
 * @param nodeId node to update; {@code null} is replaced by {@link UniqueId#NIL}
 *     (presumably meaning the local node - confirm against the raylet API).
 * @throws IllegalArgumentException if {@code capacity} is negative or NaN.
 */
@Override
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
  // A plain `>=` comparison is false for NaN. Double.compare(NaN, 0) is positive and would let
  // NaN through, and it would also reject -0.0.
  Preconditions.checkArgument(capacity >= 0,
      "Resource capacity must be a non-negative number, but got %s.", capacity);
  UniqueId targetNodeId = nodeId == null ? UniqueId.NIL : nodeId;
  rayletClient.setResource(resourceName, capacity, targetNodeId);
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
return objectStore.wait(waitList, numReturns, timeoutMs);
@@ -176,7 +165,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
Object[] args, int numReturns, CallOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false);
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
@@ -189,7 +178,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayObject callActorFunction(RayActor rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, isDirectCall(rayActor));
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
functionDescriptor, functionArgs, numReturns, null);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
@@ -202,7 +191,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
Object[] args, ActorCreationOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false);
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
}
@@ -210,6 +199,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return actor;
}
/**
 * Tells whether the given actor handle is a {@link NativeRayActor} that reports itself as a
 * direct call actor. Any other kind of handle yields {@code false}.
 */
private boolean isDirectCall(RayActor rayActor) {
  return rayActor instanceof NativeRayActor
      && ((NativeRayActor) rayActor).isDirectCallActor();
}
public WorkerContext getWorkerContext() {
return workerContext;
}
@@ -218,10 +214,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return objectStore;
}
public RayletClient getRayletClient() {
return rayletClient;
}
public FunctionManager getFunctionManager() {
return functionManager;
}
@@ -2,15 +2,19 @@ package org.ray.runtime;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.raylet.LocalModeRayletClient;
import org.ray.runtime.task.LocalModeTaskExecutor;
import org.ray.runtime.task.LocalModeTaskSubmitter;
import org.ray.runtime.task.TaskExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RayDevRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class);
private AtomicInteger jobCounter = new AtomicInteger(0);
public RayDevRuntime(RayConfig rayConfig) {
@@ -18,14 +22,13 @@ public class RayDevRuntime extends AbstractRayRuntime {
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
taskExecutor = new TaskExecutor(this);
taskExecutor = new LocalModeTaskExecutor(this);
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
objectStore = new LocalModeObjectStore(workerContext);
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
rayConfig.numberExecThreadsForDevRuntime);
((LocalModeObjectStore) objectStore).addObjectPutCallback(
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
rayletClient = new LocalModeRayletClient();
}
@Override
@@ -33,6 +36,11 @@ public class RayDevRuntime extends AbstractRayRuntime {
taskExecutor = null;
}
/**
 * Setting resources is not supported by this runtime: the call only logs an error and
 * leaves all arguments unused.
 */
@Override
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
}
private JobId nextJobId() {
return JobId.fromInt(jobCounter.getAndIncrement());
}
@@ -9,6 +9,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.NativeWorkerContext;
import org.ray.runtime.gcs.GcsClient;
@@ -16,8 +17,8 @@ import org.ray.runtime.gcs.GcsClientOptions;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.object.NativeObjectStore;
import org.ray.runtime.raylet.NativeRayletClient;
import org.ray.runtime.runner.RunManager;
import org.ray.runtime.task.NativeTaskExecutor;
import org.ray.runtime.task.NativeTaskSubmitter;
import org.ray.runtime.task.TaskExecutor;
import org.ray.runtime.util.FileUtil;
@@ -112,11 +113,10 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
new GcsClientOptions(rayConfig));
Preconditions.checkState(nativeCoreWorkerPointer != 0);
taskExecutor = new TaskExecutor(this);
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
rayletClient = new NativeRayletClient(nativeCoreWorkerPointer);
// register
registerWorker();
@@ -136,6 +136,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
}
/**
 * Sets the capacity of a resource through the native core worker.
 *
 * @param resourceName name of the resource to set.
 * @param capacity new capacity; must be a non-negative number (NaN is rejected).
 * @param nodeId node to update; {@code null} is replaced by {@link UniqueId#NIL}
 *     (presumably meaning the local node - confirm against the native implementation).
 * @throws IllegalArgumentException if {@code capacity} is negative or NaN.
 */
@Override
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
  // A plain `>=` comparison is false for NaN. Double.compare(NaN, 0) is positive and would let
  // NaN through to the native layer, and it would also reject -0.0.
  Preconditions.checkArgument(capacity >= 0,
      "Resource capacity must be a non-negative number, but got %s.", capacity);
  UniqueId targetNodeId = nodeId == null ? UniqueId.NIL : nodeId;
  nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, targetNodeId.getBytes());
}
public void run() {
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
}
@@ -176,4 +185,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private static native void nativeSetup(String logDir);
private static native void nativeShutdownHook();
private static native void nativeSetResource(long conn, String resourceName, double capacity,
byte[] nodeId);
}
@@ -51,6 +51,10 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
return Language.forNumber(nativeGetLanguage(nativeActorHandle));
}
/**
 * Returns whether this actor is a direct call actor, as reported by the native actor handle.
 */
public boolean isDirectCallActor() {
return nativeIsDirectCallActor(nativeActorHandle);
}
@Override
public String getModuleName() {
Preconditions.checkState(getLanguage() == Language.PYTHON);
@@ -90,6 +94,8 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
private static native int nativeGetLanguage(long nativeActorHandle);
private static native boolean nativeIsDirectCallActor(long nativeActorHandle);
private static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
long nativeActorHandle);
@@ -1,29 +0,0 @@
package org.ray.runtime.raylet;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Raylet client for local mode.
 *
 * <p>None of the operations are actually implemented here: the checkpoint methods throw
 * {@link NotImplementedException}, and {@code setResource} only logs an error.
 */
public class LocalModeRayletClient implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class);
/**
 * Always throws {@link NotImplementedException}.
 */
@Override
public UniqueId prepareCheckpoint(ActorId actorId) {
throw new NotImplementedException("Not implemented.");
}
/**
 * Always throws {@link NotImplementedException}.
 */
@Override
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
throw new NotImplementedException("Not implemented.");
}
/**
 * Logs an error and returns without changing anything; the arguments are ignored.
 */
@Override
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
}
}
@@ -1,57 +0,0 @@
package org.ray.runtime.raylet;
import org.ray.api.exception.RayException;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
/**
 * Raylet client for cluster mode. This is a wrapper class for C++ RayletClient.
 *
 * <p>Every method forwards to a static native method, passing the core worker pointer
 * together with the raw bytes of the IDs involved.
 */
public class NativeRayletClient implements RayletClient {

  /**
   * The native pointer of core worker. Assigned once in the constructor.
   */
  private final long nativeCoreWorkerPointer;

  public NativeRayletClient(long nativeCoreWorkerPointer) {
    this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
  }

  /**
   * Calls the native checkpoint preparation for the given actor and wraps the returned bytes
   * in a {@link UniqueId}.
   */
  @Override
  public UniqueId prepareCheckpoint(ActorId actorId) {
    return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes()));
  }

  /**
   * Forwards the actor ID and the checkpoint ID it resumed from to the native layer.
   */
  @Override
  public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
    nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(),
        checkpointId.getBytes());
  }

  /**
   * Forwards the resource update to the native layer. {@code nodeId} must not be null because
   * its bytes are read here.
   */
  @Override
  public void setResource(String resourceName, double capacity, UniqueId nodeId) {
    nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
  }

  /// Native method declarations.
  ///
  /// If you change the signature of any native methods, please re-generate
  /// the C++ header file and update the C++ implementation accordingly:
  ///
  /// Suppose that $Dir is your ray root directory.
  /// 1) pushd $Dir/java/runtime/target/classes
  /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient
  /// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h
  /// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/
  /// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc
  /// 6) popd

  private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);

  private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
      byte[] checkpointId);

  private static native void nativeSetResource(long conn, String resourceName, double capacity,
      byte[] nodeId) throws RayException;
}
@@ -1,16 +0,0 @@
package org.ray.runtime.raylet;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
/**
 * Client to the Raylet backend.
 */
public interface RayletClient {
/**
 * Prepares a checkpoint for the given actor.
 *
 * @return the ID of the checkpoint (presumably newly allocated by the backend - confirm).
 */
UniqueId prepareCheckpoint(ActorId actorId);
/**
 * Notifies the backend that the given actor has resumed from the given checkpoint.
 */
void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId);
/**
 * Sets the capacity of the named resource on the given node.
 */
void setResource(String resourceName, double capacity, UniqueId nodeId);
}
@@ -25,16 +25,20 @@ public class ArgumentsBuilder {
/**
* Convert real function arguments to task spec arguments.
*/
public static List<FunctionArg> wrap(Object[] args) {
public static List<FunctionArg> wrap(Object[] args, boolean isDirectCall) {
List<FunctionArg> ret = new ArrayList<>();
for (Object arg : args) {
ObjectId id = null;
NativeRayObject value = null;
if (arg instanceof RayObject) {
if (isDirectCall) {
throw new IllegalArgumentException(
"Passing RayObject to a direct call actor is not supported.");
}
id = ((RayObject) arg).getId();
} else {
value = ObjectSerializer.serialize(arg);
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
if (!isDirectCall && value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
@@ -0,0 +1,22 @@
package org.ray.runtime.task;
import org.ray.api.id.ActorId;
import org.ray.runtime.AbstractRayRuntime;
/**
 * Task executor for local mode.
 *
 * <p>Both checkpoint hooks are intentionally empty, so no actor checkpoint is ever saved or
 * loaded by this executor.
 */
public class LocalModeTaskExecutor extends TaskExecutor {
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
super(runtime);
}
@Override
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
// Intentionally a no-op.
}
@Override
protected void maybeLoadCheckpoint(Object actor, ActorId actorId) {
// Intentionally a no-op.
}
}
@@ -95,12 +95,12 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
if (task.getType() == TaskType.ACTOR_TASK) {
taskExecutor = actorTaskExecutors.get(getActorId(task));
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
taskExecutor = new TaskExecutor(runtime);
taskExecutor = new LocalModeTaskExecutor(runtime);
actorTaskExecutors.put(getActorId(task), taskExecutor);
} else if (idleTaskExecutors.size() > 0) {
taskExecutor = idleTaskExecutors.pop();
} else {
taskExecutor = new TaskExecutor(runtime);
taskExecutor = new LocalModeTaskExecutor(runtime);
}
}
currentTaskExecutor.set(taskExecutor);
@@ -0,0 +1,102 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
/**
 * Task executor for cluster mode.
 *
 * <p>Implements the actor checkpoint hooks of {@link TaskExecutor} for actors that implement
 * {@link Checkpointable}; other actors are left untouched by both hooks.
 */
public class NativeTaskExecutor extends TaskExecutor {

  // TODO(hchen): Use the C++ config.
  private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;

  /**
   * The native pointer of core worker.
   */
  private final long nativeCoreWorkerPointer;

  /**
   * Number of tasks executed since last actor checkpoint.
   */
  private int numTasksSinceLastCheckpoint = 0;

  /**
   * IDs of this actor's previous checkpoints, oldest first. Created eagerly so that
   * {@link #maybeSaveCheckpoint} never sees a null list, even if it runs before
   * {@link #maybeLoadCheckpoint}.
   */
  private final List<UniqueId> checkpointIds = new ArrayList<>();

  /**
   * Timestamp of the last actor checkpoint.
   */
  private long lastCheckpointTimestamp = 0;

  public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
    super(runtime);
    this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
  }

  /**
   * Saves a checkpoint if the actor is {@link Checkpointable} and its
   * {@code shouldCheckpoint} returns true for the current context.
   *
   * <p>The task counter is incremented on every call for a checkpointable actor. It is reset,
   * together with the timestamp, once a checkpoint is taken. When more than
   * {@link #NUM_ACTOR_CHECKPOINTS_TO_KEEP} IDs are tracked, the oldest one is reported through
   * {@code checkpointExpired} and dropped before {@code saveCheckpoint} is called.
   */
  @Override
  protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
    if (!(actor instanceof Checkpointable)) {
      return;
    }
    Checkpointable checkpointable = (Checkpointable) actor;
    CheckpointContext checkpointContext = new CheckpointContext(actorId,
        ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
    if (!checkpointable.shouldCheckpoint(checkpointContext)) {
      return;
    }
    numTasksSinceLastCheckpoint = 0;
    lastCheckpointTimestamp = System.currentTimeMillis();
    // The checkpoint ID is obtained from the native core worker.
    UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
    checkpointIds.add(checkpointId);
    if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
      // Report the oldest tracked checkpoint as expired, then forget it.
      checkpointable.checkpointExpired(actorId, checkpointIds.get(0));
      checkpointIds.remove(0);
    }
    checkpointable.saveCheckpoint(actorId, checkpointId);
  }

  /**
   * Resets the checkpoint bookkeeping and, if the GCS lists checkpoints for this actor, lets
   * the actor choose one to resume from.
   *
   * @throws IllegalArgumentException if the actor returns a checkpoint ID that is not among
   *     the available checkpoints.
   */
  @Override
  protected void maybeLoadCheckpoint(Object actor, ActorId actorId) {
    if (!(actor instanceof Checkpointable)) {
      return;
    }
    Checkpointable checkpointable = (Checkpointable) actor;
    numTasksSinceLastCheckpoint = 0;
    lastCheckpointTimestamp = System.currentTimeMillis();
    checkpointIds.clear();
    List<Checkpoint> availableCheckpoints
        = runtime.getGcsClient().getCheckpointsForActor(actorId);
    if (availableCheckpoints.isEmpty()) {
      return;
    }
    UniqueId checkpointId = checkpointable.loadCheckpoint(actorId, availableCheckpoints);
    if (checkpointId == null) {
      // The actor chose not to resume from any checkpoint.
      return;
    }
    boolean checkpointValid = availableCheckpoints.stream()
        .anyMatch(checkpoint -> checkpoint.checkpointId.equals(checkpointId));
    Preconditions.checkArgument(checkpointValid,
        "'loadCheckpoint' must return a checkpoint ID that exists in the "
            + "'availableCheckpoints' list, or null.");
    nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
  }

  private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);

  private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
      byte[] checkpointId);
}
@@ -3,16 +3,11 @@ package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.exception.RayTaskException;
import org.ray.api.id.ActorId;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.generated.Common.TaskType;
@@ -24,13 +19,10 @@ import org.slf4j.LoggerFactory;
/**
* The task executor, which executes tasks assigned by raylet continuously.
*/
public final class TaskExecutor {
public abstract class TaskExecutor {
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
// TODO(hchen): Use the C++ config.
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
protected final AbstractRayRuntime runtime;
/**
@@ -43,22 +35,7 @@ public final class TaskExecutor {
*/
private Exception actorCreationException = null;
/**
* Number of tasks executed since last actor checkpoint.
*/
private int numTasksSinceLastCheckpoint = 0;
/**
* IDs of this actor's previous checkpoints.
*/
private List<UniqueId> checkpointIds;
/**
* Timestamp of the last actor checkpoint.
*/
private long lastCheckpointTimestamp = 0;
public TaskExecutor(AbstractRayRuntime runtime) {
protected TaskExecutor(AbstractRayRuntime runtime) {
this.runtime = runtime;
}
@@ -134,60 +111,7 @@ public final class TaskExecutor {
rayFunctionInfo.get(2));
}
private void maybeSaveCheckpoint(Object actor, ActorId actorId) {
if (!(actor instanceof Checkpointable)) {
return;
}
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
return;
}
CheckpointContext checkpointContext = new CheckpointContext(actorId,
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
Checkpointable checkpointable = (Checkpointable) actor;
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId);
checkpointIds.add(checkpointId);
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
checkpointIds.remove(0);
}
checkpointable.saveCheckpoint(actorId, checkpointId);
}
protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId);
private void maybeLoadCheckpoint(Object actor, ActorId actorId) {
if (!(actor instanceof Checkpointable)) {
return;
}
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
checkpointIds = new ArrayList<>();
List<Checkpoint> availableCheckpoints
= runtime.getGcsClient().getCheckpointsForActor(actorId);
if (availableCheckpoints.isEmpty()) {
return;
}
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
if (checkpointId != null) {
boolean checkpointValid = false;
for (Checkpoint checkpoint : availableCheckpoints) {
if (checkpoint.checkpointId.equals(checkpointId)) {
checkpointValid = true;
break;
}
}
Preconditions.checkArgument(checkpointValid,
"'loadCheckpoint' must return a checkpoint ID that exists in the "
+ "'availableCheckpoints' list, or null.");
runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId);
}
}
protected abstract void maybeLoadCheckpoint(Object actor, ActorId actorId);
}
+1 -1
View File
@@ -50,7 +50,7 @@
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>6.9.9</version>
<version>6.9.10</version>
</dependency>
</dependencies>
</project>
+3
View File
@@ -27,6 +27,9 @@ echo "Running tests under cluster mode."
# bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --test_output="errors" || cluster_exit_code=$?
ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml
echo "Running tests under cluster mode with direct actor call turned on."
ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml
echo "Running tests under single-process mode."
# bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$?
run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml
+1 -1
View File
@@ -65,7 +65,7 @@
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>6.9.9</version>
<version>6.9.10</version>
</dependency>
</dependencies>
<build>
@@ -0,0 +1,23 @@
package org.ray.api;
import java.util.List;
import org.ray.api.options.ActorCreationOptions;
import org.testng.IAlterSuiteListener;
import org.testng.xml.XmlGroups;
import org.testng.xml.XmlRun;
import org.testng.xml.XmlSuite;
/**
 * TestNG suite listener that, when {@link ActorCreationOptions#DEFAULT_USE_DIRECT_CALL} is
 * true, configures the first suite to include the {@code directCall} group.
 */
public class RayAlterSuiteListener implements IAlterSuiteListener {

  /**
   * Adds a "run" section including the {@code directCall} group to the first suite.
   * Does nothing when direct call is not the default or when there is no suite to alter.
   */
  @Override
  public void alter(List<XmlSuite> suites) {
    // Check the flag and the list before touching it: get(0) on an empty list would throw
    // even when there is nothing to alter.
    if (!ActorCreationOptions.DEFAULT_USE_DIRECT_CALL || suites == null || suites.isEmpty()) {
      return;
    }
    XmlRun run = new XmlRun();
    run.onInclude("directCall");
    XmlGroups groups = new XmlGroups();
    groups.setRun(run);
    suites.get(0).setGroups(groups);
  }
}
@@ -1,8 +1,10 @@
package org.ray.api;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.function.Supplier;
import org.ray.api.annotation.RayRemote;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
@@ -12,6 +14,11 @@ import org.testng.SkipException;
public class TestUtils {
/**
 * Serializable test payload holding a 1 MiB (1024 * 1024 bytes) array.
 */
public static class LargeObject implements Serializable {
public byte[] data = new byte[1024 * 1024];
}
private static final int WAIT_INTERVAL_MS = 5;
public static void skipTestUnderSingleProcess() {
@@ -20,6 +27,12 @@ public class TestUtils {
}
}
/**
 * Skips the calling test by throwing a {@link SkipException} when
 * {@link ActorCreationOptions#DEFAULT_USE_DIRECT_CALL} is true.
 */
public static void skipTestIfDirectActorCallEnabled() {
if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) {
throw new SkipException("This test doesn't work when direct actor call is enabled.");
}
}
/**
* Wait until the given condition is met.
*
@@ -17,6 +17,7 @@ import org.ray.api.options.ActorCreationOptions;
import org.testng.Assert;
import org.testng.annotations.Test;
@Test(groups = {"directCall"})
public class ActorReconstructionTest extends BaseTest {
@RayRemote()
@@ -44,7 +45,6 @@ public class ActorReconstructionTest extends BaseTest {
}
}
@Test
public void testActorReconstruction() throws InterruptedException, IOException {
TestUtils.skipTestUnderSingleProcess();
ActorCreationOptions options =
@@ -65,7 +65,7 @@ public class ActorReconstructionTest extends BaseTest {
// Try calling increase on this actor again and check the value is now 4.
int value = Ray.call(Counter::increase, actor).get();
Assert.assertEquals(value, 4);
Assert.assertEquals(value, options.useDirectCall ? 1 : 4);
Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get());
@@ -125,7 +125,6 @@ public class ActorReconstructionTest extends BaseTest {
}
}
@Test
public void testActorCheckpointing() throws IOException, InterruptedException {
TestUtils.skipTestUnderSingleProcess();
ActorCreationOptions options =
@@ -8,6 +8,7 @@ import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.TestUtils.LargeObject;
import org.ray.api.annotation.RayRemote;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.UniqueId;
@@ -16,6 +17,7 @@ import org.ray.runtime.object.NativeRayObject;
import org.testng.Assert;
import org.testng.annotations.Test;
@Test(groups = {"directCall"})
public class ActorTest extends BaseTest {
@RayRemote
@@ -39,9 +41,13 @@ public class ActorTest extends BaseTest {
value += delta;
return value;
}
/**
 * Adds the length of the large object's data array to the counter and returns the new value.
 */
public int accessLargeObject(LargeObject largeObject) {
value += largeObject.data.length;
return value;
}
}
@Test
public void testCreateAndCallActor() {
// Test creating an actor from a constructor
RayActor<Counter> actor = Ray.createActor(Counter::new, 1);
@@ -52,12 +58,18 @@ public class ActorTest extends BaseTest {
Assert.assertEquals(Integer.valueOf(3), Ray.call(Counter::increaseAndGet, actor, 1).get());
}
/**
 * Calls an actor method with a {@link LargeObject} argument and checks that the counter,
 * created with an initial value of 1, ends up at the data length plus 1.
 */
public void testCallActorWithLargeObject() {
RayActor<Counter> actor = Ray.createActor(Counter::new, 1);
LargeObject largeObject = new LargeObject();
Assert.assertEquals(Integer.valueOf(largeObject.data.length + 1),
Ray.call(Counter::accessLargeObject, actor, largeObject).get());
}
@RayRemote
public static Counter factory(int initValue) {
static Counter factory(int initValue) {
return new Counter(initValue);
}
@Test
public void testCreateActorFromFactory() {
// Test creating an actor from a factory method
RayActor<Counter> actor = Ray.createActor(ActorTest::factory, 1);
@@ -67,24 +79,23 @@ public class ActorTest extends BaseTest {
}
@RayRemote
public static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
return res.get();
}
@RayRemote
public static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
return res.get();
}
@RayRemote
public static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor.get(0), delta);
return res.get();
}
@Test
public void testPassActorAsParameter() {
RayActor<Counter> actor = Ray.createActor(Counter::new, 0);
Assert.assertEquals(Integer.valueOf(1),
@@ -96,7 +107,6 @@ public class ActorTest extends BaseTest {
.get());
}
@Test
public void testForkingActorHandle() {
TestUtils.skipTestUnderSingleProcess();
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
@@ -105,9 +115,11 @@ public class ActorTest extends BaseTest {
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increaseAndGet, counter2, 2).get());
}
@Test
public void testUnreconstructableActorObject() throws InterruptedException {
TestUtils.skipTestUnderSingleProcess();
// The UnreconstructableException is created by raylet.
// TODO (kfstorm): This should be supported by direct actor call.
TestUtils.skipTestIfDirectActorCallEnabled();
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
// Call an actor method.
RayObject value = Ray.call(Counter::getValue, counter);
@@ -45,7 +45,7 @@ public abstract class BaseMultiLanguageTest {
}
}
@BeforeClass
@BeforeClass(alwaysRun = true)
public void setUp() {
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
LOGGER.info("Skip Multi-language tests because environment variable "
@@ -100,7 +100,7 @@ public abstract class BaseMultiLanguageTest {
return ImmutableMap.of();
}
@AfterClass
@AfterClass(alwaysRun = true)
public void tearDown() {
// Disconnect from the cluster.
Ray.shutdown();
@@ -16,7 +16,7 @@ public class BaseTest {
private List<File> filesToDelete;
@BeforeMethod
@BeforeMethod(alwaysRun = true)
public void setUpBase(Method method) {
LOGGER.info("===== Running test: "
+ method.getDeclaringClass().getName() + "." + method.getName());
@@ -34,7 +34,7 @@ public class BaseTest {
filesToDelete.forEach(File::deleteOnExit);
}
@AfterMethod
@AfterMethod(alwaysRun = true)
public void tearDownBase() {
Ray.shutdown();
@@ -9,6 +9,7 @@ import org.apache.commons.io.FileUtils;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -45,8 +46,10 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
}
@Test
@Test(groups = {"directCall"})
public void testCallingPythonActor() {
// Python worker doesn't support direct call yet.
TestUtils.skipTestIfDirectActorCallEnabled();
RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes());
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
Assert.assertEquals(res.get(), "2".getBytes());
@@ -76,14 +76,14 @@ public class FailureTest extends BaseTest {
assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc));
}
/**
 * Creates a {@code BadActor} with the constructor argument {@code true}, submits
 * {@code BadActor::badMethod} to it, and passes the resulting object to
 * {@code assertTaskFailedWithRayTaskException}.
 *
 * <p>NOTE(review): presumably {@code true} makes the BadActor constructor fail, so the
 * follow-up call is expected to surface that failure -- confirm against BadActor.
 */
@Test(groups = {"directCall"})
public void testActorCreationFailure() {
  // Not applicable in single-process mode; the helper skips the test there.
  TestUtils.skipTestUnderSingleProcess();
  RayActor<BadActor> actor = Ray.createActor(BadActor::new, true);
  assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor));
}
@Test
@Test(groups = {"directCall"})
public void testActorTaskFailure() {
TestUtils.skipTestUnderSingleProcess();
RayActor<BadActor> actor = Ray.createActor(BadActor::new, false);
@@ -102,9 +102,12 @@ public class FailureTest extends BaseTest {
}
}
@Test
@Test(groups = {"directCall"})
public void testActorProcessDying() {
TestUtils.skipTestUnderSingleProcess();
// This test case hangs if the worker-to-worker connection is implemented with gRPC.
// TODO (kfstorm): Should be fixed.
TestUtils.skipTestIfDirectActorCallEnabled();
RayActor<BadActor> actor = Ray.createActor(BadActor::new, false);
try {
Ray.call(BadActor::badMethod2, actor).get();
@@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.Test;
@Test(groups = {"directCall"})
public class MultiThreadingTest extends BaseTest {
private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class);
@@ -30,7 +30,7 @@ public class MultiThreadingTest extends BaseTest {
private static final int NUM_THREADS = 20;
@RayRemote
public static Integer echo(int num) {
static Integer echo(int num) {
return num;
}
@@ -73,7 +73,7 @@ public class MultiThreadingTest extends BaseTest {
}
}
public static String testMultiThreading() {
static String testMultiThreading() {
Random random = new Random();
// Test calling normal functions.
runTestCaseInMultipleThreads(() -> {
@@ -123,12 +123,10 @@ public class MultiThreadingTest extends BaseTest {
return "ok";
}
/**
 * Invokes {@link #testMultiThreading()} directly from this test method.
 * The String it returns is not checked here.
 */
@Test
public void testInDriver() {
testMultiThreading();
}
@Test
public void testInWorker() {
// Single-process mode doesn't have real workers.
TestUtils.skipTestUnderSingleProcess();
@@ -136,7 +134,6 @@ public class MultiThreadingTest extends BaseTest {
Assert.assertEquals("ok", obj.get());
}
@Test
public void testGetCurrentActorId() {
TestUtils.skipTestUnderSingleProcess();
RayActor<ActorIdTester> actorIdTester = Ray.createActor(ActorIdTester::new);
@@ -2,11 +2,11 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.TestUtils;
import org.ray.api.TestUtils.LargeObject;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.ObjectId;
import org.testng.Assert;
@@ -67,11 +67,6 @@ public class RayCallTest extends BaseTest {
return val;
}
/**
 * Serializable holder for a 1 MiB byte array, allocated when the instance is created.
 */
public static class LargeObject implements Serializable {

  // 1 << 20 == 1024 * 1024 bytes (1 MiB).
  private byte[] data = new byte[1 << 20];
}
@RayRemote
private static LargeObject testLargeObject(LargeObject largeObject) {
return largeObject;
@@ -72,7 +72,7 @@ public class StressTest extends BaseTest {
}
}
@Test
@Test(groups = {"directCall"})
public void testSubmittingManyTasksToOneActor() {
TestUtils.skipTestUnderSingleProcess();
RayActor<Actor> actor = Ray.createActor(Actor::new);
+9 -6
View File
@@ -1,10 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE suite SYSTEM "http://testng.org/testng-1.0.dtd">
<suite name="RAY suite" verbose="2">
  <!-- Test names must be unique within a suite, so "RAY test" is declared exactly once. -->
  <test name = "RAY test">
    <packages>
      <package name = "org.ray.api.test.*" />
      <package name = "org.ray.runtime.*" />
    </packages>
  </test>
  <!-- NOTE(review): listener behavior is defined in org.ray.api.RayAlterSuiteListener,
       which is not visible here; confirm what it changes in the suite. -->
  <listeners>
    <listener class-name="org.ray.api.RayAlterSuiteListener" />
  </listeners>
</suite>