mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 15:44:53 +08:00
[Java] Support concurrent actor calls API. (#7022)
* WIP Temp change Attach native thread to jvm * Fix run mode * Address comments.
This commit is contained in:
@@ -21,12 +21,15 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
||||
|
||||
public final String jvmOptions;
|
||||
|
||||
public final int maxConcurrency;
|
||||
|
||||
private ActorCreationOptions(Map<String, Double> resources, int maxReconstructions,
|
||||
boolean useDirectCall, String jvmOptions) {
|
||||
boolean useDirectCall, String jvmOptions, int maxConcurrency) {
|
||||
super(resources);
|
||||
this.maxReconstructions = maxReconstructions;
|
||||
this.useDirectCall = useDirectCall;
|
||||
this.jvmOptions = jvmOptions;
|
||||
this.maxConcurrency = maxConcurrency;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -38,6 +41,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
||||
private int maxReconstructions = NO_RECONSTRUCTION;
|
||||
private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL;
|
||||
private String jvmOptions = null;
|
||||
private int maxConcurrency = 1;
|
||||
|
||||
public Builder setResources(Map<String, Double> resources) {
|
||||
this.resources = resources;
|
||||
@@ -62,8 +66,23 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
||||
return this;
|
||||
}
|
||||
|
||||
// The max number of concurrent calls to allow for this actor.
|
||||
//
|
||||
// This only works with direct actor calls. The max concurrency defaults to 1
|
||||
// for threaded execution. Note that the execution order is not guaranteed
|
||||
// when max_concurrency > 1.
|
||||
public Builder setMaxConcurrency(int maxConcurrency) {
|
||||
if (maxConcurrency <= 0) {
|
||||
throw new IllegalArgumentException("maxConcurrency must be greater than 0.");
|
||||
}
|
||||
|
||||
this.maxConcurrency = maxConcurrency;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ActorCreationOptions createActorCreationOptions() {
|
||||
return new ActorCreationOptions(resources, maxReconstructions, useDirectCall, jvmOptions);
|
||||
return new ActorCreationOptions(
|
||||
resources, maxReconstructions, useDirectCall, jvmOptions, maxConcurrency);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -95,8 +95,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
new GcsClientOptions(rayConfig));
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
|
||||
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
|
||||
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
|
||||
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
|
||||
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
|
||||
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
|
||||
|
||||
@@ -153,13 +153,17 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer);
|
||||
}
|
||||
|
||||
public long getNativeCoreWorkerPointer() {
|
||||
return nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
public TaskExecutor getTaskExecutor() {
|
||||
return taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register this worker or driver to GCS.
|
||||
*/
|
||||
@@ -189,8 +193,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId,
|
||||
GcsClientOptions gcsClientOptions);
|
||||
|
||||
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer,
|
||||
TaskExecutor taskExecutor);
|
||||
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);
|
||||
|
||||
|
||||
@@ -3,11 +3,15 @@ package org.ray.runtime.task;
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
@@ -23,6 +27,10 @@ public abstract class TaskExecutor {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
|
||||
|
||||
// A helper map to help we get the corresponding executor for the given worker in JNI.
|
||||
private static ConcurrentHashMap<UniqueId, TaskExecutor> taskExecutors
|
||||
= new ConcurrentHashMap<>();
|
||||
|
||||
protected final AbstractRayRuntime runtime;
|
||||
|
||||
/**
|
||||
@@ -37,6 +45,13 @@ public abstract class TaskExecutor {
|
||||
|
||||
protected TaskExecutor(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
if (RayConfig.getInstance().runMode == RunMode.CLUSTER) {
|
||||
taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this);
|
||||
}
|
||||
}
|
||||
|
||||
public static TaskExecutor get(byte[] workerId) {
|
||||
return taskExecutors.get(new UniqueId(workerId));
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
|
||||
@Test(groups = {"directCall"})
|
||||
public class ActorConcurrentCallTest extends BaseTest {
|
||||
|
||||
@RayRemote
|
||||
public static class ConcurrentActor {
|
||||
private final CountDownLatch countDownLatch = new CountDownLatch(3);
|
||||
|
||||
public String countDown() {
|
||||
countDownLatch.countDown();
|
||||
try {
|
||||
countDownLatch.await();
|
||||
return "ok";
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testConcurrentCall() {
|
||||
TestUtils.skipTestIfDirectActorCallDisabled();
|
||||
|
||||
ActorCreationOptions op = new ActorCreationOptions.Builder()
|
||||
.setMaxConcurrency(3)
|
||||
.createActorCreationOptions();
|
||||
RayActor<ConcurrentActor> actor = Ray.createActor(ConcurrentActor::new, op);
|
||||
RayObject<String> obj1 = Ray.call(ConcurrentActor::countDown, actor);
|
||||
RayObject<String> obj2 = Ray.call(ConcurrentActor::countDown, actor);
|
||||
RayObject<String> obj3 = Ray.call(ConcurrentActor::countDown, actor);
|
||||
|
||||
List<Integer> expectedResult = ImmutableList.of(1, 2, 3);
|
||||
Assert.assertEquals(obj1.get(), "ok");
|
||||
Assert.assertEquals(obj2.get(), "ok");
|
||||
Assert.assertEquals(obj3.get(), "ok");
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user