mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:44:07 +08:00
[Java] Support direct actor call in Java worker (#5504)
This commit is contained in:
@@ -10,26 +10,33 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
||||
|
||||
public static final int NO_RECONSTRUCTION = 0;
|
||||
public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30);
|
||||
// DO NOT set this environment variable. It's only used for test purposes.
|
||||
// Please use `setUseDirectCall` instead.
|
||||
public static final boolean DEFAULT_USE_DIRECT_CALL = "1"
|
||||
.equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL"));
|
||||
|
||||
public final int maxReconstructions;
|
||||
|
||||
public final boolean useDirectCall;
|
||||
|
||||
public final String jvmOptions;
|
||||
|
||||
private ActorCreationOptions(Map<String, Double> resources,
|
||||
int maxReconstructions,
|
||||
String jvmOptions) {
|
||||
private ActorCreationOptions(Map<String, Double> resources, int maxReconstructions,
|
||||
boolean useDirectCall, String jvmOptions) {
|
||||
super(resources);
|
||||
this.maxReconstructions = maxReconstructions;
|
||||
this.useDirectCall = useDirectCall;
|
||||
this.jvmOptions = jvmOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* The inner class for building ActorCreationOptions.
|
||||
* The inner class for building ActorCreationOptions.
|
||||
*/
|
||||
public static class Builder {
|
||||
|
||||
private Map<String, Double> resources = new HashMap<>();
|
||||
private int maxReconstructions = NO_RECONSTRUCTION;
|
||||
private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL;
|
||||
private String jvmOptions = null;
|
||||
|
||||
public Builder setResources(Map<String, Double> resources) {
|
||||
@@ -42,13 +49,21 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
||||
return this;
|
||||
}
|
||||
|
||||
// Since direct call is not fully supported yet (see issue #5559),
|
||||
// users are not allowed to set the option to true.
|
||||
// TODO (kfstorm): uncomment when direct call is ready.
|
||||
// public Builder setUseDirectCall(boolean useDirectCall) {
|
||||
// this.useDirectCall = useDirectCall;
|
||||
// return this;
|
||||
// }
|
||||
|
||||
public Builder setJvmOptions(String jvmOptions) {
|
||||
this.jvmOptions = jvmOptions;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ActorCreationOptions createActorCreationOptions() {
|
||||
return new ActorCreationOptions(resources, maxReconstructions, jvmOptions);
|
||||
return new ActorCreationOptions(resources, maxReconstructions, useDirectCall, jvmOptions);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ def gen_java_deps():
|
||||
"org.apache.commons:commons-lang3:3.4",
|
||||
"org.ow2.asm:asm:6.0",
|
||||
"org.slf4j:slf4j-log4j12:1.7.25",
|
||||
"org.testng:testng:6.9.9",
|
||||
"org.testng:testng:6.9.10",
|
||||
"redis.clients:jedis:2.8.0",
|
||||
],
|
||||
repositories = [
|
||||
|
||||
@@ -75,7 +75,7 @@
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.9.9</version>
|
||||
<version>6.9.10</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>redis.clients</groupId>
|
||||
|
||||
@@ -13,11 +13,11 @@ import org.ray.api.exception.RayException;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.RuntimeContextImpl;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
@@ -28,7 +28,6 @@ import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.object.RayObjectImpl;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
@@ -51,7 +50,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
protected ObjectStore objectStore;
|
||||
protected TaskSubmitter taskSubmitter;
|
||||
protected RayletClient rayletClient;
|
||||
protected WorkerContext workerContext;
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
@@ -85,15 +83,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
objectStore.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
Preconditions.checkArgument(Double.compare(capacity, 0) >= 0);
|
||||
if (nodeId == null) {
|
||||
nodeId = UniqueId.NIL;
|
||||
}
|
||||
rayletClient.setResource(resourceName, capacity, nodeId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return objectStore.wait(waitList, numReturns, timeoutMs);
|
||||
@@ -176,7 +165,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
@@ -189,7 +178,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayObject callActorFunction(RayActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, isDirectCall(rayActor));
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, numReturns, null);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
@@ -202,7 +191,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false);
|
||||
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
|
||||
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
|
||||
}
|
||||
@@ -210,6 +199,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return actor;
|
||||
}
|
||||
|
||||
private boolean isDirectCall(RayActor rayActor) {
|
||||
if (rayActor instanceof NativeRayActor) {
|
||||
return ((NativeRayActor) rayActor).isDirectCallActor();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public WorkerContext getWorkerContext() {
|
||||
return workerContext;
|
||||
}
|
||||
@@ -218,10 +214,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return objectStore;
|
||||
}
|
||||
|
||||
public RayletClient getRayletClient() {
|
||||
return rayletClient;
|
||||
}
|
||||
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
|
||||
@@ -2,15 +2,19 @@ package org.ray.runtime;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.raylet.LocalModeRayletClient;
|
||||
import org.ray.runtime.task.LocalModeTaskExecutor;
|
||||
import org.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class);
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
@@ -18,14 +22,13 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
taskExecutor = new TaskExecutor(this);
|
||||
taskExecutor = new LocalModeTaskExecutor(this);
|
||||
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
|
||||
objectStore = new LocalModeObjectStore(workerContext);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
|
||||
rayConfig.numberExecThreadsForDevRuntime);
|
||||
((LocalModeObjectStore) objectStore).addObjectPutCallback(
|
||||
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
|
||||
rayletClient = new LocalModeRayletClient();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -33,6 +36,11 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
taskExecutor = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
|
||||
private JobId nextJobId() {
|
||||
return JobId.fromInt(jobCounter.getAndIncrement());
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.NativeWorkerContext;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
@@ -16,8 +17,8 @@ import org.ray.runtime.gcs.GcsClientOptions;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.object.NativeObjectStore;
|
||||
import org.ray.runtime.raylet.NativeRayletClient;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
import org.ray.runtime.task.NativeTaskExecutor;
|
||||
import org.ray.runtime.task.NativeTaskSubmitter;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
import org.ray.runtime.util.FileUtil;
|
||||
@@ -112,11 +113,10 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
new GcsClientOptions(rayConfig));
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
|
||||
taskExecutor = new TaskExecutor(this);
|
||||
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
|
||||
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
|
||||
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
|
||||
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
|
||||
rayletClient = new NativeRayletClient(nativeCoreWorkerPointer);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
@@ -136,6 +136,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
Preconditions.checkArgument(Double.compare(capacity, 0) >= 0);
|
||||
if (nodeId == null) {
|
||||
nodeId = UniqueId.NIL;
|
||||
}
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
|
||||
}
|
||||
@@ -176,4 +185,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
private static native void nativeSetup(String logDir);
|
||||
|
||||
private static native void nativeShutdownHook();
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId);
|
||||
}
|
||||
|
||||
@@ -51,6 +51,10 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
|
||||
return Language.forNumber(nativeGetLanguage(nativeActorHandle));
|
||||
}
|
||||
|
||||
public boolean isDirectCallActor() {
|
||||
return nativeIsDirectCallActor(nativeActorHandle);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
Preconditions.checkState(getLanguage() == Language.PYTHON);
|
||||
@@ -90,6 +94,8 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
|
||||
|
||||
private static native int nativeGetLanguage(long nativeActorHandle);
|
||||
|
||||
private static native boolean nativeIsDirectCallActor(long nativeActorHandle);
|
||||
|
||||
private static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
|
||||
long nativeActorHandle);
|
||||
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Raylet client for local mode.
|
||||
*/
|
||||
public class LocalModeRayletClient implements RayletClient {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class);
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* Raylet client for cluster mode. This is a wrapper class for C++ RayletClient.
|
||||
*/
|
||||
public class NativeRayletClient implements RayletClient {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private long nativeCoreWorkerPointer = 0;
|
||||
|
||||
public NativeRayletClient(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(),
|
||||
checkpointId.getBytes());
|
||||
}
|
||||
|
||||
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
/// Native method declarations.
|
||||
///
|
||||
/// If you change the signature of any native methods, please re-generate
|
||||
/// the C++ header file and update the C++ implementation accordingly:
|
||||
///
|
||||
/// Suppose that $Dir is your ray root directory.
|
||||
/// 1) pushd $Dir/java/runtime/target/classes
|
||||
/// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient
|
||||
/// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h
|
||||
/// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/
|
||||
/// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc
|
||||
/// 6) popd
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
|
||||
byte[] checkpointId);
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId) throws RayException;
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* Client to the Raylet backend.
|
||||
*/
|
||||
public interface RayletClient {
|
||||
|
||||
UniqueId prepareCheckpoint(ActorId actorId);
|
||||
|
||||
void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId);
|
||||
|
||||
void setResource(String resourceName, double capacity, UniqueId nodeId);
|
||||
}
|
||||
@@ -25,16 +25,20 @@ public class ArgumentsBuilder {
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
*/
|
||||
public static List<FunctionArg> wrap(Object[] args) {
|
||||
public static List<FunctionArg> wrap(Object[] args, boolean isDirectCall) {
|
||||
List<FunctionArg> ret = new ArrayList<>();
|
||||
for (Object arg : args) {
|
||||
ObjectId id = null;
|
||||
NativeRayObject value = null;
|
||||
if (arg instanceof RayObject) {
|
||||
if (isDirectCall) {
|
||||
throw new IllegalArgumentException(
|
||||
"Passing RayObject to a direct call actor is not supported.");
|
||||
}
|
||||
id = ((RayObject) arg).getId();
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
if (!isDirectCall && value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
|
||||
/**
|
||||
* Task executor for local mode.
|
||||
*/
|
||||
public class LocalModeTaskExecutor extends TaskExecutor {
|
||||
|
||||
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
|
||||
super(runtime);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeLoadCheckpoint(Object actor, ActorId actorId) {
|
||||
}
|
||||
}
|
||||
@@ -95,12 +95,12 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
if (task.getType() == TaskType.ACTOR_TASK) {
|
||||
taskExecutor = actorTaskExecutors.get(getActorId(task));
|
||||
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
taskExecutor = new TaskExecutor(runtime);
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
actorTaskExecutors.put(getActorId(task), taskExecutor);
|
||||
} else if (idleTaskExecutors.size() > 0) {
|
||||
taskExecutor = idleTaskExecutors.pop();
|
||||
} else {
|
||||
taskExecutor = new TaskExecutor(runtime);
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
}
|
||||
}
|
||||
currentTaskExecutor.set(taskExecutor);
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
|
||||
/**
|
||||
* Task executor for cluster mode.
|
||||
*/
|
||||
public class NativeTaskExecutor extends TaskExecutor {
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
|
||||
public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
|
||||
super(runtime);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
checkpointIds.remove(0);
|
||||
}
|
||||
checkpointable.saveCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeLoadCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= runtime.getGcsClient().getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
|
||||
if (checkpointId != null) {
|
||||
boolean checkpointValid = false;
|
||||
for (Checkpoint checkpoint : availableCheckpoints) {
|
||||
if (checkpoint.checkpointId.equals(checkpointId)) {
|
||||
checkpointValid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
|
||||
}
|
||||
}
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
|
||||
byte[] checkpointId);
|
||||
}
|
||||
@@ -3,16 +3,11 @@ package org.ray.runtime.task;
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
@@ -24,13 +19,10 @@ import org.slf4j.LoggerFactory;
|
||||
/**
|
||||
* The task executor, which executes tasks assigned by raylet continuously.
|
||||
*/
|
||||
public final class TaskExecutor {
|
||||
public abstract class TaskExecutor {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
protected final AbstractRayRuntime runtime;
|
||||
|
||||
/**
|
||||
@@ -43,22 +35,7 @@ public final class TaskExecutor {
|
||||
*/
|
||||
private Exception actorCreationException = null;
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
|
||||
public TaskExecutor(AbstractRayRuntime runtime) {
|
||||
protected TaskExecutor(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
@@ -134,60 +111,7 @@ public final class TaskExecutor {
|
||||
rayFunctionInfo.get(2));
|
||||
}
|
||||
|
||||
private void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
|
||||
return;
|
||||
}
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId);
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
checkpointIds.remove(0);
|
||||
}
|
||||
checkpointable.saveCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId);
|
||||
|
||||
private void maybeLoadCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= runtime.getGcsClient().getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
|
||||
if (checkpointId != null) {
|
||||
boolean checkpointValid = false;
|
||||
for (Checkpoint checkpoint : availableCheckpoints) {
|
||||
if (checkpoint.checkpointId.equals(checkpointId)) {
|
||||
checkpointValid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
}
|
||||
protected abstract void maybeLoadCheckpoint(Object actor, ActorId actorId);
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.9.9</version>
|
||||
<version>6.9.10</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -27,6 +27,9 @@ echo "Running tests under cluster mode."
|
||||
# bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --test_output="errors" || cluster_exit_code=$?
|
||||
ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml
|
||||
|
||||
echo "Running tests under cluster mode with direct actor call turned on."
|
||||
ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml
|
||||
|
||||
echo "Running tests under single-process mode."
|
||||
# bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$?
|
||||
run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml
|
||||
|
||||
+1
-1
@@ -65,7 +65,7 @@
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.9.9</version>
|
||||
<version>6.9.10</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package org.ray.api;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.testng.IAlterSuiteListener;
|
||||
import org.testng.xml.XmlGroups;
|
||||
import org.testng.xml.XmlRun;
|
||||
import org.testng.xml.XmlSuite;
|
||||
|
||||
public class RayAlterSuiteListener implements IAlterSuiteListener {
|
||||
|
||||
@Override
|
||||
public void alter(List<XmlSuite> suites) {
|
||||
XmlSuite suite = suites.get(0);
|
||||
if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) {
|
||||
XmlGroups groups = new XmlGroups();
|
||||
XmlRun run = new XmlRun();
|
||||
run.onInclude("directCall");
|
||||
groups.setRun(run);
|
||||
suite.setGroups(groups);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package org.ray.api;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.util.function.Supplier;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
@@ -12,6 +14,11 @@ import org.testng.SkipException;
|
||||
|
||||
public class TestUtils {
|
||||
|
||||
public static class LargeObject implements Serializable {
|
||||
|
||||
public byte[] data = new byte[1024 * 1024];
|
||||
}
|
||||
|
||||
private static final int WAIT_INTERVAL_MS = 5;
|
||||
|
||||
public static void skipTestUnderSingleProcess() {
|
||||
@@ -20,6 +27,12 @@ public class TestUtils {
|
||||
}
|
||||
}
|
||||
|
||||
public static void skipTestIfDirectActorCallEnabled() {
|
||||
if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) {
|
||||
throw new SkipException("This test doesn't work when direct actor call is enabled.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait until the given condition is met.
|
||||
*
|
||||
|
||||
@@ -17,6 +17,7 @@ import org.ray.api.options.ActorCreationOptions;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@Test(groups = {"directCall"})
|
||||
public class ActorReconstructionTest extends BaseTest {
|
||||
|
||||
@RayRemote()
|
||||
@@ -44,7 +45,6 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActorReconstruction() throws InterruptedException, IOException {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
ActorCreationOptions options =
|
||||
@@ -65,7 +65,7 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
|
||||
// Try calling increase on this actor again and check the value is now 4.
|
||||
int value = Ray.call(Counter::increase, actor).get();
|
||||
Assert.assertEquals(value, 4);
|
||||
Assert.assertEquals(value, options.useDirectCall ? 1 : 4);
|
||||
|
||||
Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get());
|
||||
|
||||
@@ -125,7 +125,6 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActorCheckpointing() throws IOException, InterruptedException {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
ActorCreationOptions options =
|
||||
|
||||
@@ -8,6 +8,7 @@ import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.TestUtils.LargeObject;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
@@ -16,6 +17,7 @@ import org.ray.runtime.object.NativeRayObject;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@Test(groups = {"directCall"})
|
||||
public class ActorTest extends BaseTest {
|
||||
|
||||
@RayRemote
|
||||
@@ -39,9 +41,13 @@ public class ActorTest extends BaseTest {
|
||||
value += delta;
|
||||
return value;
|
||||
}
|
||||
|
||||
public int accessLargeObject(LargeObject largeObject) {
|
||||
value += largeObject.data.length;
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateAndCallActor() {
|
||||
// Test creating an actor from a constructor
|
||||
RayActor<Counter> actor = Ray.createActor(Counter::new, 1);
|
||||
@@ -52,12 +58,18 @@ public class ActorTest extends BaseTest {
|
||||
Assert.assertEquals(Integer.valueOf(3), Ray.call(Counter::increaseAndGet, actor, 1).get());
|
||||
}
|
||||
|
||||
public void testCallActorWithLargeObject() {
|
||||
RayActor<Counter> actor = Ray.createActor(Counter::new, 1);
|
||||
LargeObject largeObject = new LargeObject();
|
||||
Assert.assertEquals(Integer.valueOf(largeObject.data.length + 1),
|
||||
Ray.call(Counter::accessLargeObject, actor, largeObject).get());
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static Counter factory(int initValue) {
|
||||
static Counter factory(int initValue) {
|
||||
return new Counter(initValue);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateActorFromFactory() {
|
||||
// Test creating an actor from a factory method
|
||||
RayActor<Counter> actor = Ray.createActor(ActorTest::factory, 1);
|
||||
@@ -67,24 +79,23 @@ public class ActorTest extends BaseTest {
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
|
||||
static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
|
||||
static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
|
||||
static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor.get(0), delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPassActorAsParameter() {
|
||||
RayActor<Counter> actor = Ray.createActor(Counter::new, 0);
|
||||
Assert.assertEquals(Integer.valueOf(1),
|
||||
@@ -96,7 +107,6 @@ public class ActorTest extends BaseTest {
|
||||
.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testForkingActorHandle() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
|
||||
@@ -105,9 +115,11 @@ public class ActorTest extends BaseTest {
|
||||
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increaseAndGet, counter2, 2).get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnreconstructableActorObject() throws InterruptedException {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
// The UnreconstructableException is created by raylet.
|
||||
// TODO (kfstorm): This should be supported by direct actor call.
|
||||
TestUtils.skipTestIfDirectActorCallEnabled();
|
||||
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
|
||||
// Call an actor method.
|
||||
RayObject value = Ray.call(Counter::getValue, counter);
|
||||
|
||||
@@ -45,7 +45,7 @@ public abstract class BaseMultiLanguageTest {
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
@BeforeClass(alwaysRun = true)
|
||||
public void setUp() {
|
||||
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
|
||||
LOGGER.info("Skip Multi-language tests because environment variable "
|
||||
@@ -100,7 +100,7 @@ public abstract class BaseMultiLanguageTest {
|
||||
return ImmutableMap.of();
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
@AfterClass(alwaysRun = true)
|
||||
public void tearDown() {
|
||||
// Disconnect to the cluster.
|
||||
Ray.shutdown();
|
||||
|
||||
@@ -16,7 +16,7 @@ public class BaseTest {
|
||||
|
||||
private List<File> filesToDelete;
|
||||
|
||||
@BeforeMethod
|
||||
@BeforeMethod(alwaysRun = true)
|
||||
public void setUpBase(Method method) {
|
||||
LOGGER.info("===== Running test: "
|
||||
+ method.getDeclaringClass().getName() + "." + method.getName());
|
||||
@@ -34,7 +34,7 @@ public class BaseTest {
|
||||
filesToDelete.forEach(File::deleteOnExit);
|
||||
}
|
||||
|
||||
@AfterMethod
|
||||
@AfterMethod(alwaysRun = true)
|
||||
public void tearDownBase() {
|
||||
Ray.shutdown();
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -45,8 +46,10 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
||||
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(groups = {"directCall"})
|
||||
public void testCallingPythonActor() {
|
||||
// Python worker doesn't support direct call yet.
|
||||
TestUtils.skipTestIfDirectActorCallEnabled();
|
||||
RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes());
|
||||
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
|
||||
Assert.assertEquals(res.get(), "2".getBytes());
|
||||
|
||||
@@ -76,14 +76,14 @@ public class FailureTest extends BaseTest {
|
||||
assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(groups = {"directCall"})
|
||||
public void testActorCreationFailure() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<BadActor> actor = Ray.createActor(BadActor::new, true);
|
||||
assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(groups = {"directCall"})
|
||||
public void testActorTaskFailure() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<BadActor> actor = Ray.createActor(BadActor::new, false);
|
||||
@@ -102,9 +102,12 @@ public class FailureTest extends BaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(groups = {"directCall"})
|
||||
public void testActorProcessDying() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
// This test case hangs if the worker to worker connection is implemented with grpc.
|
||||
// TODO (kfstorm): Should be fixed.
|
||||
TestUtils.skipTestIfDirectActorCallEnabled();
|
||||
RayActor<BadActor> actor = Ray.createActor(BadActor::new, false);
|
||||
try {
|
||||
Ray.call(BadActor::badMethod2, actor).get();
|
||||
|
||||
@@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
|
||||
@Test(groups = {"directCall"})
|
||||
public class MultiThreadingTest extends BaseTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class);
|
||||
@@ -30,7 +30,7 @@ public class MultiThreadingTest extends BaseTest {
|
||||
private static final int NUM_THREADS = 20;
|
||||
|
||||
@RayRemote
|
||||
public static Integer echo(int num) {
|
||||
static Integer echo(int num) {
|
||||
return num;
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ public class MultiThreadingTest extends BaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
public static String testMultiThreading() {
|
||||
static String testMultiThreading() {
|
||||
Random random = new Random();
|
||||
// Test calling normal functions.
|
||||
runTestCaseInMultipleThreads(() -> {
|
||||
@@ -123,12 +123,10 @@ public class MultiThreadingTest extends BaseTest {
|
||||
return "ok";
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInDriver() {
|
||||
testMultiThreading();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInWorker() {
|
||||
// Single-process mode doesn't have real workers.
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
@@ -136,7 +134,6 @@ public class MultiThreadingTest extends BaseTest {
|
||||
Assert.assertEquals("ok", obj.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetCurrentActorId() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<ActorIdTester> actorIdTester = Ray.createActor(ActorIdTester::new);
|
||||
|
||||
@@ -2,11 +2,11 @@ package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.TestUtils.LargeObject;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.testng.Assert;
|
||||
@@ -67,11 +67,6 @@ public class RayCallTest extends BaseTest {
|
||||
return val;
|
||||
}
|
||||
|
||||
public static class LargeObject implements Serializable {
|
||||
|
||||
private byte[] data = new byte[1024 * 1024];
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
private static LargeObject testLargeObject(LargeObject largeObject) {
|
||||
return largeObject;
|
||||
|
||||
@@ -72,7 +72,7 @@ public class StressTest extends BaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(groups = {"directCall"})
|
||||
public void testSubmittingManyTasksToOneActor() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<Actor> actor = Ray.createActor(Actor::new);
|
||||
|
||||
+9
-6
@@ -1,10 +1,13 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE suite SYSTEM "http://testng.org/testng-1.0.dtd">
|
||||
<suite name="RAY suite" verbose="2">
|
||||
<test name = "RAY test" >
|
||||
<packages>
|
||||
<package name = "org.ray.api.test.*" />
|
||||
<package name = "org.ray.runtime.*" />
|
||||
</packages>
|
||||
</test>
|
||||
<test name = "RAY test">
|
||||
<packages>
|
||||
<package name = "org.ray.api.test.*" />
|
||||
<package name = "org.ray.runtime.*" />
|
||||
</packages>
|
||||
</test>
|
||||
<listeners>
|
||||
<listener class-name="org.ray.api.RayAlterSuiteListener" />
|
||||
</listeners>
|
||||
</suite>
|
||||
|
||||
@@ -92,7 +92,8 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil:
|
||||
|
||||
TaskSpecBuilder &SetActorCreationTaskSpec(
|
||||
const CActorID &actor_id, uint64_t max_reconstructions,
|
||||
const c_vector[c_string] &dynamic_worker_options)
|
||||
const c_vector[c_string] &dynamic_worker_options,
|
||||
c_bool is_direct_call)
|
||||
|
||||
TaskSpecBuilder &SetActorTaskSpec(
|
||||
const CActorID &actor_id, const CActorHandleID &actor_handle_id,
|
||||
|
||||
@@ -82,6 +82,7 @@ cdef class TaskSpec:
|
||||
actor_creation_id.native(),
|
||||
max_actor_reconstructions,
|
||||
[],
|
||||
False,
|
||||
)
|
||||
elif not actor_id.is_nil():
|
||||
# Actor task.
|
||||
|
||||
@@ -154,6 +154,11 @@ std::vector<ActorHandleID> TaskSpecification::NewActorHandles() const {
|
||||
message_->actor_task_spec().new_actor_handles());
|
||||
}
|
||||
|
||||
bool TaskSpecification::IsDirectCall() const {
|
||||
RAY_CHECK(IsActorCreationTask());
|
||||
return message_->actor_creation_task_spec().is_direct_call();
|
||||
}
|
||||
|
||||
std::string TaskSpecification::DebugString() const {
|
||||
std::ostringstream stream;
|
||||
stream << "Type=" << TaskType_Name(message_->type())
|
||||
@@ -177,7 +182,8 @@ std::string TaskSpecification::DebugString() const {
|
||||
if (IsActorCreationTask()) {
|
||||
// Print actor creation task spec.
|
||||
stream << ", actor_creation_task_spec={actor_id=" << ActorCreationId()
|
||||
<< ", max_reconstructions=" << MaxActorReconstructions() << "}";
|
||||
<< ", max_reconstructions=" << MaxActorReconstructions()
|
||||
<< ", is_direct_call=" << IsDirectCall() << "}";
|
||||
} else if (IsActorTask()) {
|
||||
// Print actor task spec.
|
||||
stream << ", actor_task_spec={actor_id=" << ActorId()
|
||||
|
||||
@@ -131,6 +131,8 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
|
||||
|
||||
std::vector<ActorHandleID> NewActorHandles() const;
|
||||
|
||||
bool IsDirectCall() const;
|
||||
|
||||
ObjectID ActorDummyObject() const;
|
||||
|
||||
std::string DebugString() const;
|
||||
|
||||
@@ -90,7 +90,8 @@ class TaskSpecBuilder {
|
||||
/// \return Reference to the builder object itself.
|
||||
TaskSpecBuilder &SetActorCreationTaskSpec(
|
||||
const ActorID &actor_id, uint64_t max_reconstructions = 0,
|
||||
const std::vector<std::string> &dynamic_worker_options = {}) {
|
||||
const std::vector<std::string> &dynamic_worker_options = {},
|
||||
bool is_direct_call = false) {
|
||||
message_->set_type(TaskType::ACTOR_CREATION_TASK);
|
||||
auto actor_creation_spec = message_->mutable_actor_creation_task_spec();
|
||||
actor_creation_spec->set_actor_id(actor_id.Binary());
|
||||
@@ -98,6 +99,7 @@ class TaskSpecBuilder {
|
||||
for (const auto &option : dynamic_worker_options) {
|
||||
actor_creation_spec->add_dynamic_worker_options(option);
|
||||
}
|
||||
actor_creation_spec->set_is_direct_call(is_direct_call);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
|
||||
if (task_spec.IsActorCreationTask()) {
|
||||
RAY_CHECK(current_actor_id_.IsNil());
|
||||
current_actor_id_ = task_spec.ActorCreationId();
|
||||
current_actor_use_direct_call_ = task_spec.IsDirectCall();
|
||||
}
|
||||
if (task_spec.IsActorTask()) {
|
||||
RAY_CHECK(current_actor_id_ == task_spec.ActorId());
|
||||
@@ -91,6 +92,10 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
|
||||
|
||||
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
|
||||
|
||||
bool WorkerContext::CurrentActorUseDirectCall() const {
|
||||
return current_actor_use_direct_call_;
|
||||
}
|
||||
|
||||
WorkerThreadContext &WorkerContext::GetThreadContext() {
|
||||
if (thread_context_ == nullptr) {
|
||||
thread_context_ = std::unique_ptr<WorkerThreadContext>(new WorkerThreadContext());
|
||||
|
||||
@@ -26,6 +26,8 @@ class WorkerContext {
|
||||
|
||||
const ActorID &GetCurrentActorID() const;
|
||||
|
||||
bool CurrentActorUseDirectCall() const;
|
||||
|
||||
int GetNextTaskIndex();
|
||||
|
||||
int GetNextPutIndex();
|
||||
@@ -43,6 +45,9 @@ class WorkerContext {
|
||||
/// ID of current actor.
|
||||
ActorID current_actor_id_;
|
||||
|
||||
/// Whether current actor accepts direct calls.
|
||||
bool current_actor_use_direct_call_;
|
||||
|
||||
private:
|
||||
static WorkerThreadContext &GetThreadContext();
|
||||
|
||||
|
||||
@@ -49,7 +49,9 @@ jclass java_base_task_options_class;
|
||||
jfieldID java_base_task_options_resources;
|
||||
|
||||
jclass java_actor_creation_options_class;
|
||||
jfieldID java_actor_creation_options_default_use_direct_call;
|
||||
jfieldID java_actor_creation_options_max_reconstructions;
|
||||
jfieldID java_actor_creation_options_use_direct_call;
|
||||
jfieldID java_actor_creation_options_jvm_options;
|
||||
|
||||
jclass java_gcs_client_options_class;
|
||||
@@ -146,8 +148,12 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
|
||||
java_actor_creation_options_class =
|
||||
LoadClass(env, "org/ray/api/options/ActorCreationOptions");
|
||||
java_actor_creation_options_default_use_direct_call = env->GetStaticFieldID(
|
||||
java_actor_creation_options_class, "DEFAULT_USE_DIRECT_CALL", "Z");
|
||||
java_actor_creation_options_max_reconstructions =
|
||||
env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I");
|
||||
java_actor_creation_options_use_direct_call =
|
||||
env->GetFieldID(java_actor_creation_options_class, "useDirectCall", "Z");
|
||||
java_actor_creation_options_jvm_options = env->GetFieldID(
|
||||
java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;");
|
||||
|
||||
|
||||
@@ -92,8 +92,12 @@ extern jfieldID java_base_task_options_resources;
|
||||
|
||||
/// ActorCreationOptions class
|
||||
extern jclass java_actor_creation_options_class;
|
||||
/// DEFAULT_USE_DIRECT_CALL field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_default_use_direct_call;
|
||||
/// maxReconstructions field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_max_reconstructions;
|
||||
/// useDirectCall field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_use_direct_call;
|
||||
/// jvmOptions field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_jvm_options;
|
||||
|
||||
|
||||
@@ -129,6 +129,25 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(
|
||||
ray::RayLog::ShutDownRayLog();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName,
|
||||
jdouble capacity, jbyteArray nodeId) {
|
||||
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
|
||||
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
|
||||
|
||||
auto &raylet_client =
|
||||
reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)->GetRayletClient();
|
||||
auto status = raylet_client.SetResource(native_resource_name,
|
||||
static_cast<double>(capacity), node_id);
|
||||
env->ReleaseStringUTFChars(resourceName, native_resource_name);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -48,6 +48,14 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *,
|
||||
jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
|
||||
JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -57,6 +57,16 @@ JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLangua
|
||||
return (jint)GetActorHandle(nativeActorHandle).ActorLanguage();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeIsDirectCallActor
|
||||
* Signature: (J)Z
|
||||
*/
|
||||
JNIEXPORT jboolean JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor(
|
||||
JNIEnv *env, jclass o, jlong nativeActorHandle) {
|
||||
return GetActorHandle(nativeActorHandle).IsDirectCallActor();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeGetActorCreationTaskFunctionDescriptor
|
||||
|
||||
@@ -40,6 +40,14 @@ Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorHandleId(JNIEnv *, jclas
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeIsDirectCallActor
|
||||
* Signature: (J)Z
|
||||
*/
|
||||
JNIEXPORT jboolean JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeGetActorCreationTaskFunctionDescriptor
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
#include "ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h"
|
||||
#include <jni.h>
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
inline RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) {
|
||||
return reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)->GetRayletClient();
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
using ray::ClientID;
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_NativeRayletClient
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
ActorCheckpointID checkpoint_id;
|
||||
auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer)
|
||||
.PrepareActorCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
|
||||
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_NativeRayletClient
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId,
|
||||
jbyteArray checkpointId) {
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
|
||||
auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer)
|
||||
.NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_NativeRayletClient
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName,
|
||||
jdouble capacity, jbyteArray nodeId) {
|
||||
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
|
||||
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
|
||||
|
||||
auto status =
|
||||
GetRayletClientFromPointer(nativeCoreWorkerPointer)
|
||||
.SetResource(native_resource_name, static_cast<double>(capacity), node_id);
|
||||
env->ReleaseStringUTFChars(resourceName, native_resource_name);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,39 +0,0 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_runtime_raylet_NativeRayletClient */
|
||||
|
||||
#ifndef _Included_org_ray_runtime_raylet_NativeRayletClient
|
||||
#define _Included_org_ray_runtime_raylet_NativeRayletClient
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_NativeRayletClient
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint(JNIEnv *, jclass,
|
||||
jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_NativeRayletClient
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *, jclass, jlong, jbyteArray, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_NativeRayletClient
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource(
|
||||
JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -0,0 +1,55 @@
|
||||
#include "ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h"
|
||||
#include <jni.h>
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
using ray::ClientID;
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskExecutor
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
auto &core_worker = *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID();
|
||||
const auto &task_spec = core_worker.GetWorkerContext().GetCurrentTask();
|
||||
RAY_CHECK(task_spec->IsActorTask());
|
||||
ActorCheckpointID checkpoint_id;
|
||||
auto status = core_worker.GetRayletClient().PrepareActorCheckpoint(
|
||||
actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
|
||||
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskExecutor
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray checkpointId) {
|
||||
auto &core_worker = *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID();
|
||||
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
|
||||
auto status = core_worker.GetRayletClient().NotifyActorResumedFromCheckpoint(
|
||||
actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,33 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_runtime_task_NativeTaskExecutor */
|
||||
|
||||
#ifndef _Included_org_ray_runtime_task_NativeTaskExecutor
|
||||
#define _Included_org_ray_runtime_task_NativeTaskExecutor
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#undef org_ray_runtime_task_NativeTaskExecutor_NUM_ACTOR_CHECKPOINTS_TO_KEEP
|
||||
#define org_ray_runtime_task_NativeTaskExecutor_NUM_ACTOR_CHECKPOINTS_TO_KEEP 20L
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskExecutor
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass,
|
||||
jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskExecutor
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -77,11 +77,14 @@ inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject call
|
||||
inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
|
||||
jobject actorCreationOptions) {
|
||||
uint64_t max_reconstructions = 0;
|
||||
bool use_direct_call;
|
||||
std::unordered_map<std::string, double> resources;
|
||||
std::vector<std::string> dynamic_worker_options;
|
||||
if (actorCreationOptions) {
|
||||
max_reconstructions = static_cast<uint64_t>(env->GetIntField(
|
||||
actorCreationOptions, java_actor_creation_options_max_reconstructions));
|
||||
use_direct_call = env->GetBooleanField(actorCreationOptions,
|
||||
java_actor_creation_options_use_direct_call);
|
||||
jobject java_resources =
|
||||
env->GetObjectField(actorCreationOptions, java_base_task_options_resources);
|
||||
resources = ToResources(env, java_resources);
|
||||
@@ -91,10 +94,14 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
|
||||
std::string jvm_options = JavaStringToNativeString(env, java_jvm_options);
|
||||
dynamic_worker_options.emplace_back(jvm_options);
|
||||
}
|
||||
} else {
|
||||
use_direct_call =
|
||||
env->GetStaticBooleanField(java_actor_creation_options_class,
|
||||
java_actor_creation_options_default_use_direct_call);
|
||||
}
|
||||
|
||||
ray::ActorCreationOptions action_creation_options{
|
||||
static_cast<uint64_t>(max_reconstructions), false, resources,
|
||||
static_cast<uint64_t>(max_reconstructions), use_direct_call, resources,
|
||||
dynamic_worker_options};
|
||||
return action_creation_options;
|
||||
}
|
||||
|
||||
@@ -22,12 +22,13 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface(
|
||||
task_receivers_.emplace(
|
||||
TaskTransportType::RAYLET,
|
||||
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
|
||||
raylet_client, object_interface_, *main_service_, worker_server_, func)));
|
||||
worker_context_, raylet_client, object_interface_, *main_service_,
|
||||
worker_server_, func)));
|
||||
task_receivers_.emplace(
|
||||
TaskTransportType::DIRECT_ACTOR,
|
||||
std::unique_ptr<CoreWorkerDirectActorTaskReceiver>(
|
||||
new CoreWorkerDirectActorTaskReceiver(object_interface_, *main_service_,
|
||||
worker_server_, func)));
|
||||
new CoreWorkerDirectActorTaskReceiver(worker_context_, object_interface_,
|
||||
*main_service_, worker_server_, func)));
|
||||
|
||||
// Start RPC server after all the task receivers are properly initialized.
|
||||
worker_server_.Run();
|
||||
|
||||
@@ -173,7 +173,8 @@ Status CoreWorkerTaskInterface::CreateActor(
|
||||
actor_creation_options.resources, actor_creation_options.resources,
|
||||
TaskTransportType::RAYLET, &return_ids);
|
||||
builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions,
|
||||
actor_creation_options.dynamic_worker_options);
|
||||
actor_creation_options.dynamic_worker_options,
|
||||
actor_creation_options.is_direct_call);
|
||||
|
||||
*actor_handle = std::unique_ptr<ActorHandle>(new ActorHandle(
|
||||
actor_id, ActorHandleID::Nil(), function.language,
|
||||
|
||||
@@ -21,12 +21,11 @@ CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
|
||||
: io_service_(io_service),
|
||||
gcs_client_(gcs_client),
|
||||
client_call_manager_(io_service),
|
||||
store_provider_(std::move(store_provider)) {
|
||||
RAY_CHECK_OK(SubscribeActorUpdates());
|
||||
}
|
||||
store_provider_(std::move(store_provider)) {}
|
||||
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
const TaskSpecification &task_spec) {
|
||||
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
||||
if (HasByReferenceArgs(task_spec)) {
|
||||
return Status::Invalid("direct actor call only supports by-value arguments");
|
||||
}
|
||||
@@ -41,6 +40,12 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
if (subscribed_actors_.find(actor_id) == subscribed_actors_.end()) {
|
||||
RAY_CHECK_OK(SubscribeActorUpdates(actor_id));
|
||||
subscribed_actors_.insert(actor_id);
|
||||
}
|
||||
|
||||
auto iter = actor_states_.find(actor_id);
|
||||
if (iter == actor_states_.end() ||
|
||||
iter->second.state_ == ActorTableData::RECONSTRUCTING) {
|
||||
@@ -51,6 +56,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
// to have a timeout to mark it as invalid if it doesn't show up in the
|
||||
// specified time.
|
||||
pending_requests_[actor_id].emplace_back(std::move(request));
|
||||
RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created.";
|
||||
return Status::OK();
|
||||
} else if (iter->second.state_ == ActorTableData::ALIVE) {
|
||||
// Actor is alive, submit the request.
|
||||
@@ -62,17 +68,19 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
|
||||
// Submit request.
|
||||
auto &client = rpc_clients_[actor_id];
|
||||
PushTask(*client, *request, task_id, num_returns);
|
||||
PushTask(*client, *request, actor_id, task_id, num_returns);
|
||||
return Status::OK();
|
||||
} else {
|
||||
// Actor is dead, treat the task as failure.
|
||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
return Status::IOError("Actor is dead.");
|
||||
// Return OK here so that we can get the error from store with get operation.
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() {
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates(
|
||||
const ActorID &actor_id) {
|
||||
// Register a callback to handle actor notifications.
|
||||
auto actor_notification_callback = [this](const ActorID &actor_id,
|
||||
const ActorTableData &actor_data) {
|
||||
@@ -92,6 +100,19 @@ Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() {
|
||||
} else {
|
||||
// Remove rpc client if it's dead or being reconstructed.
|
||||
rpc_clients_.erase(actor_id);
|
||||
|
||||
// For tasks that have been sent and are waiting for replies, treat them
|
||||
// as failed when the destination actor is dead or reconstructing.
|
||||
auto iter = waiting_reply_tasks_.find(actor_id);
|
||||
if (iter != waiting_reply_tasks_.end()) {
|
||||
for (const auto &entry : iter->second) {
|
||||
const auto &task_id = entry.first;
|
||||
const auto num_returns = entry.second;
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
}
|
||||
waiting_reply_tasks_.erase(actor_id);
|
||||
}
|
||||
|
||||
// If this actor is permanently dead and there are pending requests, treat
|
||||
// the pending tasks as failed.
|
||||
if (actor_data.state() == ActorTableData::DEAD &&
|
||||
@@ -111,7 +132,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() {
|
||||
<< ", port: " << actor_data.port();
|
||||
};
|
||||
|
||||
return gcs_client_.Actors().AsyncSubscribe(actor_notification_callback, nullptr);
|
||||
return gcs_client_.Actors().AsyncSubscribe(actor_id, actor_notification_callback,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
@@ -125,7 +147,8 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
auto &requests = pending_requests_[actor_id];
|
||||
while (!requests.empty()) {
|
||||
const auto &request = *requests.front();
|
||||
PushTask(*client, request, TaskID::FromBinary(request.task_spec().task_id()),
|
||||
PushTask(*client, request, actor_id,
|
||||
TaskID::FromBinary(request.task_spec().task_id()),
|
||||
request.task_spec().num_returns());
|
||||
requests.pop_front();
|
||||
}
|
||||
@@ -133,11 +156,18 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client,
|
||||
const rpc::PushTaskRequest &request,
|
||||
const ActorID &actor_id,
|
||||
const TaskID &task_id,
|
||||
int num_returns) {
|
||||
auto status = client.PushTask(
|
||||
request,
|
||||
[this, task_id, num_returns](Status status, const rpc::PushTaskReply &reply) {
|
||||
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
||||
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
|
||||
auto status =
|
||||
client.PushTask(request, [this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::PushTaskReply &reply) {
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
waiting_reply_tasks_[actor_id].erase(task_id);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
return;
|
||||
@@ -170,6 +200,8 @@ void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed(
|
||||
const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) {
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
@@ -181,16 +213,24 @@ void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed(
|
||||
}
|
||||
}
|
||||
|
||||
bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) const {
|
||||
bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
if (subscribed_actors_.find(actor_id) == subscribed_actors_.end()) {
|
||||
RAY_CHECK_OK(SubscribeActorUpdates(actor_id));
|
||||
subscribed_actors_.insert(actor_id);
|
||||
}
|
||||
|
||||
auto iter = actor_states_.find(actor_id);
|
||||
return (iter != actor_states_.end() && iter->second.state_ == ActorTableData::ALIVE);
|
||||
}
|
||||
|
||||
CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver(
|
||||
CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service,
|
||||
rpc::GrpcServer &server, const TaskHandler &task_handler)
|
||||
: object_interface_(object_interface),
|
||||
WorkerContext &worker_context, CoreWorkerObjectInterface &object_interface,
|
||||
boost::asio::io_service &io_service, rpc::GrpcServer &server,
|
||||
const TaskHandler &task_handler)
|
||||
: worker_context_(worker_context),
|
||||
object_interface_(object_interface),
|
||||
task_service_(io_service, *this),
|
||||
task_handler_(task_handler) {
|
||||
server.RegisterService(task_service_);
|
||||
@@ -200,12 +240,18 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
|
||||
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
const TaskSpecification task_spec(request.task_spec());
|
||||
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
|
||||
if (HasByReferenceArgs(task_spec)) {
|
||||
send_reply_callback(
|
||||
Status::Invalid("direct actor call only supports by value arguments"), nullptr,
|
||||
nullptr);
|
||||
return;
|
||||
}
|
||||
if (task_spec.IsActorTask() && !worker_context_.CurrentActorUseDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."),
|
||||
nullptr, nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto num_returns = task_spec.NumReturns();
|
||||
RAY_CHECK(task_spec.IsActorCreationTask() || task_spec.IsActorTask());
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define RAY_CORE_WORKER_DIRECT_ACTOR_TRANSPORT_H
|
||||
|
||||
#include <list>
|
||||
#include <set>
|
||||
|
||||
#include "ray/core_worker/object_interface.h"
|
||||
#include "ray/core_worker/transport/transport.h"
|
||||
@@ -39,8 +40,8 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter {
|
||||
Status SubmitTask(const TaskSpecification &task_spec) override;
|
||||
|
||||
private:
|
||||
/// Subscribe to all actor updates.
|
||||
Status SubscribeActorUpdates();
|
||||
/// Subscribe to updates of an actor.
|
||||
Status SubscribeActorUpdates(const ActorID &actor_id);
|
||||
|
||||
/// Push a task to a remote actor via the given client.
|
||||
/// Note, this function doesn't return any error status code. If an error occurs while
|
||||
@@ -48,11 +49,12 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter {
|
||||
///
|
||||
/// \param[in] client The RPC client to send tasks to an actor.
|
||||
/// \param[in] request The request to send.
|
||||
/// \param[in] actor_id Actor ID.
|
||||
/// \param[in] task_id The ID of a task.
|
||||
/// \param[in] num_returns Number of return objects.
|
||||
/// \return Void.
|
||||
void PushTask(rpc::DirectActorClient &client, const rpc::PushTaskRequest &request,
|
||||
const TaskID &task_id, int num_returns);
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns);
|
||||
|
||||
/// Treat a task as failed.
|
||||
///
|
||||
@@ -78,7 +80,7 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter {
|
||||
///
|
||||
/// \param[in] actor_id The actor ID.
|
||||
/// \return Whether this actor is alive.
|
||||
bool IsActorAlive(const ActorID &actor_id) const;
|
||||
bool IsActorAlive(const ActorID &actor_id);
|
||||
|
||||
/// The IO event loop.
|
||||
boost::asio::io_service &io_service_;
|
||||
@@ -92,24 +94,22 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter {
|
||||
/// Mutex to proect the various maps below.
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
/// Map from actor id to actor state. This currently includes all actors in the system.
|
||||
///
|
||||
/// TODO(zhijunfu): this map currently keeps track of all the actors in the system,
|
||||
/// like `actor_registry_` in raylet. Later after new GCS client interface supports
|
||||
/// subscribing updates for a specific actor, this will be updated to only include
|
||||
/// entries for actors that the transport submits tasks to.
|
||||
/// Map from actor id to actor state. This only includes actors that we send tasks to.
|
||||
std::unordered_map<ActorID, ActorStateData> actor_states_;
|
||||
|
||||
/// Map from actor id to rpc client. This only includes actors that we send tasks to.
|
||||
///
|
||||
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
|
||||
/// subscribe updates for a specific actor.
|
||||
std::unordered_map<ActorID, std::unique_ptr<rpc::DirectActorClient>> rpc_clients_;
|
||||
|
||||
/// Map from actor id to the actor's pending requests.
|
||||
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
|
||||
pending_requests_;
|
||||
|
||||
/// Map from actor id to the tasks that are waiting for reply.
|
||||
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
|
||||
|
||||
/// The set of actors which are subscribed for further updates.
|
||||
std::unordered_set<ActorID> subscribed_actors_;
|
||||
|
||||
/// The store provider.
|
||||
std::unique_ptr<CoreWorkerStoreProvider> store_provider_;
|
||||
|
||||
@@ -119,7 +119,8 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter {
|
||||
class CoreWorkerDirectActorTaskReceiver : public CoreWorkerTaskReceiver,
|
||||
public rpc::DirectActorHandler {
|
||||
public:
|
||||
CoreWorkerDirectActorTaskReceiver(CoreWorkerObjectInterface &object_interface,
|
||||
CoreWorkerDirectActorTaskReceiver(WorkerContext &worker_context,
|
||||
CoreWorkerObjectInterface &object_interface,
|
||||
boost::asio::io_service &io_service,
|
||||
rpc::GrpcServer &server,
|
||||
const TaskHandler &task_handler);
|
||||
@@ -135,6 +136,8 @@ class CoreWorkerDirectActorTaskReceiver : public CoreWorkerTaskReceiver,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
private:
|
||||
// Worker context.
|
||||
WorkerContext &worker_context_;
|
||||
// Object interface.
|
||||
CoreWorkerObjectInterface &object_interface_;
|
||||
/// The rpc service for `DirectActorService`.
|
||||
|
||||
@@ -14,10 +14,11 @@ Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpecification &task)
|
||||
}
|
||||
|
||||
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
|
||||
std::unique_ptr<RayletClient> &raylet_client,
|
||||
WorkerContext &worker_context, std::unique_ptr<RayletClient> &raylet_client,
|
||||
CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service,
|
||||
rpc::GrpcServer &server, const TaskHandler &task_handler)
|
||||
: raylet_client_(raylet_client),
|
||||
: worker_context_(worker_context),
|
||||
raylet_client_(raylet_client),
|
||||
object_interface_(object_interface),
|
||||
task_service_(io_service, *this),
|
||||
task_handler_(task_handler) {
|
||||
@@ -30,6 +31,12 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask(
|
||||
const Task task(request.task());
|
||||
const auto &task_spec = task.GetTaskSpecification();
|
||||
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
|
||||
if (task_spec.IsActorTask() && worker_context_.CurrentActorUseDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
|
||||
nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
auto status = task_handler_(task_spec, &results);
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter {
|
||||
class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver,
|
||||
public rpc::WorkerTaskHandler {
|
||||
public:
|
||||
CoreWorkerRayletTaskReceiver(std::unique_ptr<RayletClient> &raylet_client,
|
||||
CoreWorkerRayletTaskReceiver(WorkerContext &worker_context,
|
||||
std::unique_ptr<RayletClient> &raylet_client,
|
||||
CoreWorkerObjectInterface &object_interface,
|
||||
boost::asio::io_service &io_service,
|
||||
rpc::GrpcServer &server, const TaskHandler &task_handler);
|
||||
@@ -49,6 +50,8 @@ class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
private:
|
||||
// Worker context.
|
||||
WorkerContext &worker_context_;
|
||||
/// Raylet client.
|
||||
std::unique_ptr<RayletClient> &raylet_client_;
|
||||
// Object interface.
|
||||
|
||||
@@ -90,6 +90,8 @@ message ActorCreationTaskSpec {
|
||||
// the placeholder strings (`RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0`,
|
||||
// `RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1`, etc) in the worker command.
|
||||
repeated string dynamic_worker_options = 4;
|
||||
// Whether direct actor call is used.
|
||||
bool is_direct_call = 5;
|
||||
}
|
||||
|
||||
// Task spec of an actor task.
|
||||
|
||||
@@ -109,6 +109,8 @@ message ActorTableData {
|
||||
string ip_address = 9;
|
||||
// The port that the actor is listening on.
|
||||
int32 port = 10;
|
||||
// Whether direct actor call is used.
|
||||
bool is_direct_call = 11;
|
||||
}
|
||||
|
||||
message ErrorTableData {
|
||||
|
||||
@@ -98,16 +98,19 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id,
|
||||
int ActorRegistration::NumHandles() const { return frontier_.size(); }
|
||||
|
||||
std::shared_ptr<ActorCheckpointData> ActorRegistration::GenerateCheckpointData(
|
||||
const ActorID &actor_id, const Task &task) {
|
||||
const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId();
|
||||
const auto dummy_object = task.GetTaskSpecification().ActorDummyObject();
|
||||
// Make a copy of the actor registration, and extend its frontier to include
|
||||
// the most recent task.
|
||||
// Note(hchen): this is needed because this method is called before
|
||||
// `FinishAssignedTask`, which will be called when the worker tries to fetch
|
||||
// the next task.
|
||||
const ActorID &actor_id, const Task *task) {
|
||||
// Make a copy of the actor registration
|
||||
ActorRegistration copy = *this;
|
||||
copy.ExtendFrontier(actor_handle_id, dummy_object);
|
||||
if (task) {
|
||||
const auto actor_handle_id = task->GetTaskSpecification().ActorHandleId();
|
||||
const auto dummy_object = task->GetTaskSpecification().ActorDummyObject();
|
||||
// Extend its frontier to include the most recent task.
|
||||
// NOTE(hchen): For non-direct-call actors, this is needed because this method is
|
||||
// called before `FinishAssignedTask`, which will be called when the worker tries to
|
||||
// fetch the next task. For direct-call actors, checkpoint data doesn't contain
|
||||
// frontier info, so we don't need to do `ExtendFrontier` here.
|
||||
copy.ExtendFrontier(actor_handle_id, dummy_object);
|
||||
}
|
||||
|
||||
// Use actor's current state to generate checkpoint data.
|
||||
auto checkpoint_data = std::make_shared<ActorCheckpointData>();
|
||||
|
||||
@@ -133,10 +133,11 @@ class ActorRegistration {
|
||||
/// Generate checkpoint data based on actor's current state.
|
||||
///
|
||||
/// \param actor_id ID of this actor.
|
||||
/// \param task The task that just finished on the actor.
|
||||
/// \param task The task that just finished on the actor. (nullptr when it's direct
|
||||
/// call.)
|
||||
/// \return A shared pointer to the generated checkpoint data.
|
||||
std::shared_ptr<ActorCheckpointData> GenerateCheckpointData(const ActorID &actor_id,
|
||||
const Task &task);
|
||||
const Task *task);
|
||||
|
||||
private:
|
||||
/// Information from the global actor table about this actor, including the
|
||||
|
||||
@@ -1,292 +0,0 @@
|
||||
#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h"
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeInit
|
||||
* Signature: (Ljava/lang/String;[BZ[B)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
|
||||
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
|
||||
jbyteArray jobId) {
|
||||
const auto worker_id = JavaByteArrayToId<WorkerID>(env, workerId);
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
|
||||
auto raylet_client = new std::unique_ptr<RayletClient>(
|
||||
new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA));
|
||||
env->ReleaseStringUTFChars(sockName, nativeString);
|
||||
return reinterpret_cast<jlong>(raylet_client);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeSubmitTask
|
||||
* Signature: (J[BLjava/nio/ByteBuffer;II)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
|
||||
jsize size = env->GetArrayLength(taskSpec);
|
||||
ray::rpc::TaskSpec task_spec_message;
|
||||
task_spec_message.ParseFromArray(data, size);
|
||||
env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
|
||||
|
||||
ray::TaskSpecification task_spec(task_spec_message);
|
||||
auto status = raylet_client->SubmitTask(task_spec);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGetTask
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(
|
||||
JNIEnv *env, jclass, jlong client) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
std::unique_ptr<ray::TaskSpecification> spec;
|
||||
auto status = raylet_client->GetTask(&spec);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
// Serialize the task spec and copy to Java byte array.
|
||||
auto task_data = spec->Serialize();
|
||||
|
||||
jbyteArray result = env->NewByteArray(task_data.size());
|
||||
if (result == nullptr) {
|
||||
return nullptr; /* out of memory error thrown */
|
||||
}
|
||||
|
||||
env->SetByteArrayRegion(result, 0, task_data.size(),
|
||||
reinterpret_cast<const jbyte *>(task_data.data()));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeDestroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
|
||||
JNIEnv *env, jclass, jlong client) {
|
||||
auto raylet_client = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = (*raylet_client)->Disconnect();
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
delete raylet_client;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeWaitObject
|
||||
* Signature: (J[[BIIZ[B)[Z
|
||||
*/
|
||||
JNIEXPORT jbooleanArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
|
||||
JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns,
|
||||
jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) {
|
||||
std::vector<ObjectID> object_ids;
|
||||
auto len = env->GetArrayLength(objectIds);
|
||||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
|
||||
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
// Invoke wait.
|
||||
WaitResultPair result;
|
||||
auto status =
|
||||
raylet_client->Wait(object_ids, numReturns, timeoutMillis,
|
||||
static_cast<bool>(isWaitLocal), current_task_id, &result);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
// Convert result to java object.
|
||||
jboolean put_value = true;
|
||||
jbooleanArray resultArray = env->NewBooleanArray(object_ids.size());
|
||||
for (uint i = 0; i < result.first.size(); ++i) {
|
||||
for (uint j = 0; j < object_ids.size(); ++j) {
|
||||
if (result.first[i] == object_ids[j]) {
|
||||
env->SetBooleanArrayRegion(resultArray, j, 1, &put_value);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
put_value = false;
|
||||
for (uint i = 0; i < result.second.size(); ++i) {
|
||||
for (uint j = 0; j < object_ids.size(); ++j) {
|
||||
if (result.second[i] == object_ids[j]) {
|
||||
env->SetBooleanArrayRegion(resultArray, j, 1, &put_value);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return resultArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGenerateActorCreationTaskId
|
||||
* Signature: ([B[BI)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter) {
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
|
||||
const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter);
|
||||
const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id);
|
||||
jbyteArray result = env->NewByteArray(actor_creation_task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(actor_creation_task_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGenerateActorTaskId
|
||||
* Signature: ([B[BI[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter, jbyteArray actorId) {
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
const TaskID actor_task_id =
|
||||
ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id);
|
||||
|
||||
jbyteArray result = env->NewByteArray(actor_task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, actor_task_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(actor_task_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGenerateNormalTaskId
|
||||
* Signature: ([B[BI)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter) {
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
const TaskID task_id =
|
||||
ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter);
|
||||
|
||||
jbyteArray result = env->NewByteArray(task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, task_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(task_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeFreePlasmaObjects
|
||||
* Signature: (J[[BZZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
|
||||
JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly,
|
||||
jboolean deleteCreatingTasks) {
|
||||
std::vector<ObjectID> object_ids;
|
||||
auto len = env->GetArrayLength(objectIds);
|
||||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass,
|
||||
jlong client,
|
||||
jbyteArray actorId) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
ActorCheckpointID checkpoint_id;
|
||||
auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
|
||||
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
|
||||
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(
|
||||
JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity,
|
||||
jbyteArray nodeId) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
|
||||
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
|
||||
|
||||
auto status = raylet_client->SetResource(native_resource_name,
|
||||
static_cast<double>(capacity), node_id);
|
||||
env->ReleaseStringUTFChars(resourceName, native_resource_name);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1225,13 +1225,19 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
|
||||
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
|
||||
RAY_CHECK(worker && worker->GetActorId() == actor_id);
|
||||
|
||||
// Find the task that is running on this actor.
|
||||
const auto task_id = worker->GetAssignedTaskId();
|
||||
const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING);
|
||||
// Generate checkpoint id and data.
|
||||
ActorCheckpointID checkpoint_id = ActorCheckpointID::FromRandom();
|
||||
auto checkpoint_data =
|
||||
actor_entry->second.GenerateCheckpointData(actor_entry->first, task);
|
||||
std::shared_ptr<ActorCheckpointData> checkpoint_data;
|
||||
if (actor_entry->second.GetTableData().is_direct_call()) {
|
||||
checkpoint_data =
|
||||
actor_entry->second.GenerateCheckpointData(actor_entry->first, nullptr);
|
||||
} else {
|
||||
// Find the task that is running on this actor.
|
||||
const auto task_id = worker->GetAssignedTaskId();
|
||||
const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING);
|
||||
// Generate checkpoint data.
|
||||
checkpoint_data =
|
||||
actor_entry->second.GenerateCheckpointData(actor_entry->first, &task);
|
||||
}
|
||||
|
||||
// Write checkpoint data to GCS.
|
||||
RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add(
|
||||
@@ -1914,6 +1920,7 @@ std::shared_ptr<ActorTableData> NodeManager::CreateActorTableDataFromCreationTas
|
||||
// This is the first time that the actor has been created, so the number
|
||||
// of remaining reconstructions is the max.
|
||||
actor_info_ptr->set_remaining_reconstructions(task_spec.MaxActorReconstructions());
|
||||
actor_info_ptr->set_is_direct_call(task_spec.IsDirectCall());
|
||||
} else {
|
||||
// If we've already seen this actor, it means that this actor was reconstructed.
|
||||
// Thus, its previous state must be RECONSTRUCTING.
|
||||
|
||||
@@ -128,12 +128,15 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
|
||||
task.GetTaskExecutionSpec().GetMessage());
|
||||
request.set_resource_ids(resource_id_set.Serialize());
|
||||
|
||||
auto status = rpc_client_->AssignTask(
|
||||
request, [](Status status, const rpc::AssignTaskReply &reply) {
|
||||
// Worker has finished this task. There's nothing to do here
|
||||
// and assigning new task will be done when raylet receives
|
||||
// `TaskDone` message.
|
||||
});
|
||||
auto status = rpc_client_->AssignTask(request, [](Status status,
|
||||
const rpc::AssignTaskReply &reply) {
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Worker failed to finish executing task: " << status.ToString();
|
||||
}
|
||||
// Worker has finished this task. There's nothing to do here
|
||||
// and assigning new task will be done when raylet receives
|
||||
// `TaskDone` message.
|
||||
});
|
||||
finish_assign_callback(status);
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId()
|
||||
|
||||
Reference in New Issue
Block a user