mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:02:16 +08:00
[Java] Support direct actor call in Java worker (#5504)
This commit is contained in:
@@ -13,11 +13,11 @@ import org.ray.api.exception.RayException;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.RuntimeContextImpl;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
@@ -28,7 +28,6 @@ import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.object.RayObjectImpl;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
@@ -51,7 +50,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
protected ObjectStore objectStore;
|
||||
protected TaskSubmitter taskSubmitter;
|
||||
protected RayletClient rayletClient;
|
||||
protected WorkerContext workerContext;
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
@@ -85,15 +83,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
objectStore.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
Preconditions.checkArgument(Double.compare(capacity, 0) >= 0);
|
||||
if (nodeId == null) {
|
||||
nodeId = UniqueId.NIL;
|
||||
}
|
||||
rayletClient.setResource(resourceName, capacity, nodeId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return objectStore.wait(waitList, numReturns, timeoutMs);
|
||||
@@ -176,7 +165,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
@@ -189,7 +178,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayObject callActorFunction(RayActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, isDirectCall(rayActor));
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, numReturns, null);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
@@ -202,7 +191,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false);
|
||||
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
|
||||
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
|
||||
}
|
||||
@@ -210,6 +199,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return actor;
|
||||
}
|
||||
|
||||
private boolean isDirectCall(RayActor rayActor) {
|
||||
if (rayActor instanceof NativeRayActor) {
|
||||
return ((NativeRayActor) rayActor).isDirectCallActor();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public WorkerContext getWorkerContext() {
|
||||
return workerContext;
|
||||
}
|
||||
@@ -218,10 +214,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return objectStore;
|
||||
}
|
||||
|
||||
public RayletClient getRayletClient() {
|
||||
return rayletClient;
|
||||
}
|
||||
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
|
||||
@@ -2,15 +2,19 @@ package org.ray.runtime;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.raylet.LocalModeRayletClient;
|
||||
import org.ray.runtime.task.LocalModeTaskExecutor;
|
||||
import org.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class);
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
@@ -18,14 +22,13 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
taskExecutor = new TaskExecutor(this);
|
||||
taskExecutor = new LocalModeTaskExecutor(this);
|
||||
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
|
||||
objectStore = new LocalModeObjectStore(workerContext);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
|
||||
rayConfig.numberExecThreadsForDevRuntime);
|
||||
((LocalModeObjectStore) objectStore).addObjectPutCallback(
|
||||
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
|
||||
rayletClient = new LocalModeRayletClient();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -33,6 +36,11 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
taskExecutor = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
|
||||
private JobId nextJobId() {
|
||||
return JobId.fromInt(jobCounter.getAndIncrement());
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.NativeWorkerContext;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
@@ -16,8 +17,8 @@ import org.ray.runtime.gcs.GcsClientOptions;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.object.NativeObjectStore;
|
||||
import org.ray.runtime.raylet.NativeRayletClient;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
import org.ray.runtime.task.NativeTaskExecutor;
|
||||
import org.ray.runtime.task.NativeTaskSubmitter;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
import org.ray.runtime.util.FileUtil;
|
||||
@@ -112,11 +113,10 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
new GcsClientOptions(rayConfig));
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
|
||||
taskExecutor = new TaskExecutor(this);
|
||||
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
|
||||
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
|
||||
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
|
||||
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
|
||||
rayletClient = new NativeRayletClient(nativeCoreWorkerPointer);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
@@ -136,6 +136,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
Preconditions.checkArgument(Double.compare(capacity, 0) >= 0);
|
||||
if (nodeId == null) {
|
||||
nodeId = UniqueId.NIL;
|
||||
}
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
|
||||
}
|
||||
@@ -176,4 +185,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
private static native void nativeSetup(String logDir);
|
||||
|
||||
private static native void nativeShutdownHook();
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId);
|
||||
}
|
||||
|
||||
@@ -51,6 +51,10 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
|
||||
return Language.forNumber(nativeGetLanguage(nativeActorHandle));
|
||||
}
|
||||
|
||||
public boolean isDirectCallActor() {
|
||||
return nativeIsDirectCallActor(nativeActorHandle);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
Preconditions.checkState(getLanguage() == Language.PYTHON);
|
||||
@@ -90,6 +94,8 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
|
||||
|
||||
private static native int nativeGetLanguage(long nativeActorHandle);
|
||||
|
||||
private static native boolean nativeIsDirectCallActor(long nativeActorHandle);
|
||||
|
||||
private static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
|
||||
long nativeActorHandle);
|
||||
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Raylet client for local mode.
|
||||
*/
|
||||
public class LocalModeRayletClient implements RayletClient {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class);
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* Raylet client for cluster mode. This is a wrapper class for C++ RayletClient.
|
||||
*/
|
||||
public class NativeRayletClient implements RayletClient {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private long nativeCoreWorkerPointer = 0;
|
||||
|
||||
public NativeRayletClient(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(),
|
||||
checkpointId.getBytes());
|
||||
}
|
||||
|
||||
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
/// Native method declarations.
|
||||
///
|
||||
/// If you change the signature of any native methods, please re-generate
|
||||
/// the C++ header file and update the C++ implementation accordingly:
|
||||
///
|
||||
/// Suppose that $Dir is your ray root directory.
|
||||
/// 1) pushd $Dir/java/runtime/target/classes
|
||||
/// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient
|
||||
/// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h
|
||||
/// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/
|
||||
/// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc
|
||||
/// 6) popd
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
|
||||
byte[] checkpointId);
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId) throws RayException;
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* Client to the Raylet backend.
|
||||
*/
|
||||
public interface RayletClient {
|
||||
|
||||
UniqueId prepareCheckpoint(ActorId actorId);
|
||||
|
||||
void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId);
|
||||
|
||||
void setResource(String resourceName, double capacity, UniqueId nodeId);
|
||||
}
|
||||
@@ -25,16 +25,20 @@ public class ArgumentsBuilder {
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
*/
|
||||
public static List<FunctionArg> wrap(Object[] args) {
|
||||
public static List<FunctionArg> wrap(Object[] args, boolean isDirectCall) {
|
||||
List<FunctionArg> ret = new ArrayList<>();
|
||||
for (Object arg : args) {
|
||||
ObjectId id = null;
|
||||
NativeRayObject value = null;
|
||||
if (arg instanceof RayObject) {
|
||||
if (isDirectCall) {
|
||||
throw new IllegalArgumentException(
|
||||
"Passing RayObject to a direct call actor is not supported.");
|
||||
}
|
||||
id = ((RayObject) arg).getId();
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
if (!isDirectCall && value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
|
||||
/**
|
||||
* Task executor for local mode.
|
||||
*/
|
||||
public class LocalModeTaskExecutor extends TaskExecutor {
|
||||
|
||||
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
|
||||
super(runtime);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeLoadCheckpoint(Object actor, ActorId actorId) {
|
||||
}
|
||||
}
|
||||
@@ -95,12 +95,12 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
if (task.getType() == TaskType.ACTOR_TASK) {
|
||||
taskExecutor = actorTaskExecutors.get(getActorId(task));
|
||||
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
taskExecutor = new TaskExecutor(runtime);
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
actorTaskExecutors.put(getActorId(task), taskExecutor);
|
||||
} else if (idleTaskExecutors.size() > 0) {
|
||||
taskExecutor = idleTaskExecutors.pop();
|
||||
} else {
|
||||
taskExecutor = new TaskExecutor(runtime);
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
}
|
||||
}
|
||||
currentTaskExecutor.set(taskExecutor);
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
|
||||
/**
|
||||
* Task executor for cluster mode.
|
||||
*/
|
||||
public class NativeTaskExecutor extends TaskExecutor {
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
|
||||
public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
|
||||
super(runtime);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
checkpointIds.remove(0);
|
||||
}
|
||||
checkpointable.saveCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeLoadCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= runtime.getGcsClient().getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
|
||||
if (checkpointId != null) {
|
||||
boolean checkpointValid = false;
|
||||
for (Checkpoint checkpoint : availableCheckpoints) {
|
||||
if (checkpoint.checkpointId.equals(checkpointId)) {
|
||||
checkpointValid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
|
||||
}
|
||||
}
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
|
||||
byte[] checkpointId);
|
||||
}
|
||||
@@ -3,16 +3,11 @@ package org.ray.runtime.task;
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
@@ -24,13 +19,10 @@ import org.slf4j.LoggerFactory;
|
||||
/**
|
||||
* The task executor, which executes tasks assigned by raylet continuously.
|
||||
*/
|
||||
public final class TaskExecutor {
|
||||
public abstract class TaskExecutor {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
protected final AbstractRayRuntime runtime;
|
||||
|
||||
/**
|
||||
@@ -43,22 +35,7 @@ public final class TaskExecutor {
|
||||
*/
|
||||
private Exception actorCreationException = null;
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
|
||||
public TaskExecutor(AbstractRayRuntime runtime) {
|
||||
protected TaskExecutor(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
@@ -134,60 +111,7 @@ public final class TaskExecutor {
|
||||
rayFunctionInfo.get(2));
|
||||
}
|
||||
|
||||
private void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
|
||||
return;
|
||||
}
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId);
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
checkpointIds.remove(0);
|
||||
}
|
||||
checkpointable.saveCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId);
|
||||
|
||||
private void maybeLoadCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= runtime.getGcsClient().getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
|
||||
if (checkpointId != null) {
|
||||
boolean checkpointValid = false;
|
||||
for (Checkpoint checkpoint : availableCheckpoints) {
|
||||
if (checkpoint.checkpointId.equals(checkpointId)) {
|
||||
checkpointValid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
}
|
||||
protected abstract void maybeLoadCheckpoint(Object actor, ActorId actorId);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user