diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index c29a9420d..c4044486e 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -7,6 +7,7 @@ import io.ray.api.runtime.RayRuntimeFactory; import io.ray.api.runtimecontext.RuntimeContext; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.concurrent.Callable; /** @@ -137,6 +138,34 @@ public final class Ray extends RayCall { return runtime.wait(waitList, waitList.size(), Integer.MAX_VALUE); } + /** + * Get a handle to a named actor of current job. + *

+ * Gets a handle to a named actor with the given name. The actor must + * have been created with name specified. + * + * @param name The name of the named actor. + * @return an ActorHandle to the actor if the actor of specified name exists or an + * Optional.empty() + */ + public static Optional getActor(String name) { + return runtime.getActor(name, false); + } + + /** + * Get a handle to a global named actor. + *

+ * Gets a handle to a global named actor with the given name. The actor must + * have been created with global name specified. + * + * @param name The global name of the named actor. + * @return an ActorHandle to the actor if the actor of specified name exists or an + * Optional.empty() + */ + public static Optional getGlobalActor(String name) { + return runtime.getActor(name, true); + } + /** * If users want to use Ray API in their own threads, call this method to get the async context * and then call {@link #setAsyncContext} at the beginning of the new thread. diff --git a/java/api/src/main/java/io/ray/api/call/ActorCreator.java b/java/api/src/main/java/io/ray/api/call/ActorCreator.java index f6f9da8fe..16a24bda6 100644 --- a/java/api/src/main/java/io/ray/api/call/ActorCreator.java +++ b/java/api/src/main/java/io/ray/api/call/ActorCreator.java @@ -19,6 +19,12 @@ public class ActorCreator extends BaseActorCreator> { } /** + * Set the JVM options for the Java worker that this actor is running in. + * + * Note, if this is set, this actor won't share Java worker with other actors or tasks. + * + * @param jvmOptions JVM options for the Java worker that this actor is running in. + * @return self * @see io.ray.api.options.ActorCreationOptions.Builder#setJvmOptions(java.lang.String) */ public ActorCreator setJvmOptions(String jvmOptions) { diff --git a/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java b/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java index a924243a3..ec281705b 100644 --- a/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java +++ b/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java @@ -1,5 +1,6 @@ package io.ray.api.call; +import io.ray.api.Ray; import io.ray.api.options.ActorCreationOptions; import java.util.Map; @@ -11,6 +12,35 @@ import java.util.Map; public class BaseActorCreator { protected ActorCreationOptions.Builder builder = new ActorCreationOptions.Builder(); + /** + * Set the actor name of a named actor. + * This named actor is only accessible from this job by this name via + * {@link Ray#getActor(java.lang.String)}. If you want create a named actor that is accessible + * from all jobs, use {@link BaseActorCreator#setGlobalName(java.lang.String)} instead. + * + * @param name The name of the named actor. + * @return self + * @see io.ray.api.options.ActorCreationOptions.Builder#setName(String) + */ + public T setName(String name) { + builder.setName(name); + return self(); + } + + /** + * Set the name of this actor. This actor will be accessible from all jobs by this name via + * {@link Ray#getGlobalActor(java.lang.String)}. If you want to create a named actor that is + * only accessible from this job, use {@link BaseActorCreator#setName(java.lang.String)} instead. + * + * @param name The name of the named actor. + * @return self + * @see io.ray.api.options.ActorCreationOptions.Builder#setGlobalName(String) + */ + public T setGlobalName(String name) { + builder.setGlobalName(name); + return self(); + } + /** * Set a custom resource requirement to reserve for the lifetime of this actor. * This method can be called multiple times. If the same resource is set multiple times, @@ -55,9 +85,9 @@ public class BaseActorCreator { } /** - /** + * /** * Set the max number of concurrent calls to allow for this actor. - * + *

* The max concurrency defaults to 1 for threaded execution. * Note that the execution order is not guaranteed when max_concurrency > 1. * diff --git a/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java index 53be5c6b4..363e915e9 100644 --- a/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java @@ -1,5 +1,6 @@ package io.ray.api.options; +import io.ray.api.Ray; import java.util.HashMap; import java.util.Map; @@ -7,15 +8,17 @@ import java.util.Map; * The options for creating actor. */ public class ActorCreationOptions extends BaseTaskOptions { + public final boolean global; + public final String name; public final int maxRestarts; - public final String jvmOptions; - public final int maxConcurrency; - private ActorCreationOptions(Map resources, int maxRestarts, - String jvmOptions, int maxConcurrency) { + private ActorCreationOptions(boolean global, String name, Map resources, + int maxRestarts, String jvmOptions, int maxConcurrency) { super(resources); + this.global = global; + this.name = name; this.maxRestarts = maxRestarts; this.jvmOptions = jvmOptions; this.maxConcurrency = maxConcurrency; @@ -25,12 +28,42 @@ public class ActorCreationOptions extends BaseTaskOptions { * The inner class for building ActorCreationOptions. */ public static class Builder { - + private boolean global; + private String name; private Map resources = new HashMap<>(); private int maxRestarts = 0; private String jvmOptions = null; private int maxConcurrency = 1; + /** + * Set the actor name of a named actor. + * This named actor is only accessible from this job by this name via + * {@link Ray#getActor(java.lang.String)}. If you want create a named actor that is accessible + * from all jobs, use {@link Builder#setGlobalName(java.lang.String)} instead. + * + * @param name The name of the named actor. + * @return self + */ + public Builder setName(String name) { + this.name = name; + this.global = false; + return this; + } + + /** + * Set the name of this actor. This actor will be accessible from all jobs by this name via + * {@link Ray#getGlobalActor(java.lang.String)}. If you want to create a named actor that is + * only accessible from this job, use {@link Builder#setName(java.lang.String)} instead. + * + * @param name The name of the named actor. + * @return self + */ + public Builder setGlobalName(String name) { + this.name = name; + this.global = true; + return this; + } + /** * Set a custom resource requirement to reserve for the lifetime of this actor. * This method can be called multiple times. If the same resource is set multiple times, @@ -73,7 +106,7 @@ public class ActorCreationOptions extends BaseTaskOptions { /** * Set the JVM options for the Java worker that this actor is running in. - * + *

* Note, if this is set, this actor won't share Java worker with other actors or tasks. * * @param jvmOptions JVM options for the Java worker that this actor is running in. @@ -86,7 +119,7 @@ public class ActorCreationOptions extends BaseTaskOptions { /** * Set the max number of concurrent calls to allow for this actor. - * + *

* The max concurrency defaults to 1 for threaded execution. * Note that the execution order is not guaranteed when max_concurrency > 1. * @@ -104,7 +137,7 @@ public class ActorCreationOptions extends BaseTaskOptions { public ActorCreationOptions build() { return new ActorCreationOptions( - resources, maxRestarts, jvmOptions, maxConcurrency); + global, name, resources, maxRestarts, jvmOptions, maxConcurrency); } } diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index 07222790e..318b880e8 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -9,12 +9,14 @@ import io.ray.api.function.PyActorClass; import io.ray.api.function.PyActorMethod; import io.ray.api.function.PyFunction; import io.ray.api.function.RayFunc; +import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; import io.ray.api.runtimecontext.RuntimeContext; import java.util.List; +import java.util.Optional; import java.util.concurrent.Callable; /** @@ -82,6 +84,20 @@ public interface RayRuntime { */ void setResource(String resourceName, double capacity, UniqueId nodeId); + T getActorHandle(ActorId actorId); + + /** + * Get a handle to a named actor. + *

+ * Gets a handle to a named actor with the given name. The actor must + * have been created with name specified. + * + * @param name The name of the named actor. + * @param global Whether the named actor is global. + * @return ActorHandle to the actor. + */ + Optional getActor(String name, boolean global); + /** * Kill the actor immediately. * diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index c7ed5b009..b86698c60 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -13,6 +13,7 @@ import io.ray.api.function.PyActorClass; import io.ray.api.function.PyActorMethod; import io.ray.api.function.PyFunction; import io.ray.api.function.RayFunc; +import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; @@ -154,6 +155,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return (PyActorHandle) createActorImpl(functionDescriptor, args, options); } + @SuppressWarnings("unchecked") + @Override + public T getActorHandle(ActorId actorId) { + return (T) taskSubmitter.getActor(actorId); + } @Override public void setAsyncContext(Object asyncContext) { diff --git a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java index ed9bb4f51..778f3f07e 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java @@ -9,6 +9,7 @@ import io.ray.runtime.context.LocalModeWorkerContext; import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.task.LocalModeTaskExecutor; import io.ray.runtime.task.LocalModeTaskSubmitter; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,6 +67,12 @@ public class RayDevRuntime extends AbstractRayRuntime { throw new UnsupportedOperationException(); } + @SuppressWarnings("unchecked") + @Override + public Optional getActor(String name, boolean global) { + return (Optional) ((LocalModeTaskSubmitter)taskSubmitter).getActor(name, global); + } + @Override public Object getAsyncContext() { return null; diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 4d8b72c9e..fb506c518 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -2,6 +2,7 @@ package io.ray.runtime; import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; import io.ray.api.id.JobId; import io.ray.api.id.UniqueId; import io.ray.runtime.config.RayConfig; @@ -19,6 +20,7 @@ import io.ray.runtime.util.JniUtils; import java.io.File; import java.io.IOException; import java.util.Map; +import java.util.Optional; import org.apache.commons.io.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -140,6 +142,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime { nativeSetResource(resourceName, capacity, nodeId.getBytes()); } + @SuppressWarnings("unchecked") + @Override + public Optional getActor(String name, boolean global) { + byte[] actorIdBytes = nativeGetActorIdOfNamedActor(name, global); + ActorId actorId = ActorId.fromBytes(actorIdBytes); + if (actorId.isNil()) { + return Optional.empty(); + } else { + return Optional.of((T) getActorHandle(actorId)); + } + } + @Override public void killActor(BaseActorHandle actor, boolean noRestart) { nativeKillActor(actor.getId().getBytes(), noRestart); @@ -164,7 +178,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime { nativeRunTaskExecutor(taskExecutor); } - private static native void nativeInitialize(int workerMode, String ndoeIpAddress, + private static native void nativeInitialize( + int workerMode, String ndoeIpAddress, int nodeManagerPort, String driverName, String storeSocket, String rayletSocket, byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess, String logDir, Map rayletConfigParameters); @@ -177,6 +192,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private static native void nativeKillActor(byte[] actorId, boolean noRestart); + private static native byte[] nativeGetActorIdOfNamedActor(String actorName, boolean global); + private static native void nativeSetCoreWorker(byte[] workerId); static class AsyncContext { diff --git a/java/runtime/src/main/java/io/ray/runtime/actor/LocalModeActorHandle.java b/java/runtime/src/main/java/io/ray/runtime/actor/LocalModeActorHandle.java index c880f330b..ad48590ca 100644 --- a/java/runtime/src/main/java/io/ray/runtime/actor/LocalModeActorHandle.java +++ b/java/runtime/src/main/java/io/ray/runtime/actor/LocalModeActorHandle.java @@ -38,6 +38,10 @@ public class LocalModeActorHandle implements ActorHandle, Externalizable { return this.previousActorTaskDummyObjectId.getAndSet(previousActorTaskDummyObjectId); } + public LocalModeActorHandle copy() { + return new LocalModeActorHandle(this.actorId, this.previousActorTaskDummyObjectId.get()); + } + @Override public synchronized void writeExternal(ObjectOutput out) throws IOException { out.writeObject(actorId); diff --git a/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java b/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java index f8010e99c..73c88dc93 100644 --- a/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java +++ b/java/runtime/src/main/java/io/ray/runtime/actor/NativeActorHandle.java @@ -35,6 +35,12 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab NativeActorHandle() { } + public static NativeActorHandle create(byte[] actorId) { + Language language = Language.forNumber(nativeGetLanguage(actorId)); + Preconditions.checkState(language != null, "Language shouldn't be null"); + return create(actorId, language); + } + public static NativeActorHandle create(byte[] actorId, Language language) { switch (language) { case JAVA: diff --git a/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java index 6e2ed7345..4c8ea51b8 100644 --- a/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java @@ -219,7 +219,9 @@ public class RunManager { // Register the number of Redis shards in the primary shard, so that clients // know how many redis shards to expect under RedisShards. client.set("NumRedisShards", Integer.toString(rayConfig.numberRedisShards)); - + // Set session dir for this cluster, so that the drivers which connected to this + // cluster will fetch this session dir as its self's session dir. + client.set("session_dir", rayConfig.getSessionDir()); // start redis shards for (int i = 0; i < rayConfig.numberRedisShards; i++) { String shard = startRedisInstance(rayConfig.nodeIp, diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 05a622170..2358982a6 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -3,7 +3,9 @@ package io.ray.runtime.task; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; +import io.ray.api.ActorHandle; import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.id.TaskId; @@ -32,6 +34,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -39,6 +42,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,11 +65,14 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { /// The thread pool to execute normal tasks. private final ExecutorService normalTaskExecutorService; + private final Map actorHandles = new ConcurrentHashMap<>(); + + private final Map namedActors = new ConcurrentHashMap<>(); private final Map actorContexts = new ConcurrentHashMap<>(); public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor, - LocalModeObjectStore objectStore) { + LocalModeObjectStore objectStore) { this.runtime = runtime; this.taskExecutor = taskExecutor; this.objectStore = objectStore; @@ -126,11 +133,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes())) .setTaskId(ByteString.copyFrom(taskIdBytes)) .setFunctionDescriptor(Common.FunctionDescriptor.newBuilder() - .setJavaFunctionDescriptor( - Common.JavaFunctionDescriptor.newBuilder() - .setClassName(functionDescriptorList.get(0)) - .setFunctionName(functionDescriptorList.get(1)) - .setSignature(functionDescriptorList.get(2)))) + .setJavaFunctionDescriptor( + Common.JavaFunctionDescriptor.newBuilder() + .setClassName(functionDescriptorList.get(0)) + .setFunctionName(functionDescriptorList.get(1)) + .setSignature(functionDescriptorList.get(2)))) .addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder() .setObjectRef(ObjectReference.newBuilder().setObjectId( ByteString.copyFrom(arg.id.getBytes()))).build() @@ -152,8 +159,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } @Override - public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List args, - ActorCreationOptions options) { + public BaseActorHandle createActor( + FunctionDescriptor functionDescriptor, List args, + ActorCreationOptions options) throws IllegalArgumentException { ActorId actorId = ActorId.fromRandom(); TaskSpec taskSpec = getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args) .setNumReturns(1) @@ -162,7 +170,15 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { .build()) .build(); submitTaskSpec(taskSpec); - return new LocalModeActorHandle(actorId, getReturnIds(taskSpec).get(0)); + final LocalModeActorHandle actorHandle + = new LocalModeActorHandle(actorId, getReturnIds(taskSpec).get(0)); + actorHandles.put(actorId, actorHandle.copy()); + if (StringUtils.isNotBlank(options.name)) { + Preconditions.checkArgument(!namedActors.containsKey(options.name), + String.format("Actor of name %s exists", options.name)); + namedActors.put(options.name, actorHandle); + } + return actorHandle; } @Override @@ -191,6 +207,21 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } } + @Override + public BaseActorHandle getActor(ActorId actorId) { + return actorHandles.get(actorId).copy(); + } + + public Optional getActor(String name, boolean global) { + String fullName = global ? name : + String.format("%s-%s", Ray.getRuntimeContext().getCurrentJobId(), name); + if (namedActors.containsKey(fullName)) { + return Optional.of(namedActors.get(fullName)); + } else { + return Optional.empty(); + } + } + public void shutdown() { // Shutdown actor task executor service. synchronized (actorTaskExecutorServices) { @@ -300,7 +331,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { // If the task is an actor task or an actor creation task, // put the dummy object in object store, so those tasks which depends on it // can be executed. - putObject = new NativeRayObject(new byte[]{1}, null); + putObject = new NativeRayObject(new byte[] {1}, null); } else { putObject = returnObjects.get(i); } @@ -310,13 +341,13 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) { Common.FunctionDescriptor functionDescriptor = - taskSpec.getFunctionDescriptor(); + taskSpec.getFunctionDescriptor(); if (functionDescriptor.getFunctionDescriptorCase() == - Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) { + Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) { return new JavaFunctionDescriptor( - functionDescriptor.getJavaFunctionDescriptor().getClassName(), - functionDescriptor.getJavaFunctionDescriptor().getFunctionName(), - functionDescriptor.getJavaFunctionDescriptor().getSignature()); + functionDescriptor.getJavaFunctionDescriptor().getClassName(), + functionDescriptor.getJavaFunctionDescriptor().getFunctionName(), + functionDescriptor.getJavaFunctionDescriptor().getSignature()); } else { throw new RuntimeException("Can't build non java function descriptor"); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java index 5b0d79efe..eba97f1fb 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java @@ -3,13 +3,17 @@ package io.ray.runtime.task; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; import io.ray.runtime.actor.NativeActorHandle; import io.ray.runtime.functionmanager.FunctionDescriptor; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; /** * Task submitter for cluster mode. This is a wrapper class for core worker task interface. @@ -29,12 +33,23 @@ public class NativeTaskSubmitter implements TaskSubmitter { @Override public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List args, - ActorCreationOptions options) { + ActorCreationOptions options) throws IllegalArgumentException { + if (StringUtils.isNotBlank(options.name)) { + Optional actor = + options.global ? Ray.getGlobalActor(options.name) : Ray.getActor(options.name); + Preconditions.checkArgument(!actor.isPresent(), + String.format("Actor of name %s exists", options.name)); + } byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args, options); return NativeActorHandle.create(actorId, functionDescriptor.getLanguage()); } + @Override + public BaseActorHandle getActor(ActorId actorId) { + return NativeActorHandle.create(actorId.getBytes()); + } + @Override public List submitActorTask( BaseActorHandle actor, FunctionDescriptor functionDescriptor, diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java index c642a8630..1ea6b86bb 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java @@ -1,6 +1,7 @@ package io.ray.runtime.task; import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; @@ -29,9 +30,10 @@ public interface TaskSubmitter { * @param args Arguments of this task. * @param options Options for this actor creation task. * @return Handle to the actor. + * @throws IllegalArgumentException if actor of specified name exists */ BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List args, - ActorCreationOptions options); + ActorCreationOptions options) throws IllegalArgumentException; /** * Submit an actor task. @@ -44,4 +46,7 @@ public interface TaskSubmitter { */ List submitActorTask(BaseActorHandle actor, FunctionDescriptor functionDescriptor, List args, int numReturns, CallOptions options); + + BaseActorHandle getActor(ActorId actorId); + } diff --git a/java/test/src/main/java/io/ray/test/NamedActorTest.java b/java/test/src/main/java/io/ray/test/NamedActorTest.java new file mode 100644 index 000000000..b9bdb2dfe --- /dev/null +++ b/java/test/src/main/java/io/ray/test/NamedActorTest.java @@ -0,0 +1,99 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.runtime.config.RayConfig; +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class NamedActorTest extends BaseTest { + + public static class Counter { + + private int value = 0; + + public int increment() { + this.value += 1; + return this.value; + } + } + + @Test + public void testNamedActor() { + String name = "named-actor-counter"; + // Create an actor. + ActorHandle actor = Ray.actor(Counter::new).setName(name).remote(); + Assert.assertEquals(actor.task(Counter::increment).remote().get(), Integer.valueOf(1)); + // Get the named actor. + Assert.assertTrue(Ray.getActor(name).isPresent()); + Optional> namedActor = Ray.getActor(name); + Assert.assertTrue(namedActor.isPresent()); + // Verify that this handle is correct. + Assert.assertEquals(namedActor.get().task(Counter::increment).remote().get(), + Integer.valueOf(2)); + } + + @Test + public void testGlobalActor() throws IOException, InterruptedException { + String name = "global-actor-counter"; + // Create an actor. + ActorHandle actor = Ray.actor(Counter::new).setGlobalName(name).remote(); + Assert.assertEquals(actor.task(Counter::increment).remote().get(), Integer.valueOf(1)); + + Assert.assertFalse(Ray.getActor(name).isPresent()); + + // Get the global actor. + Optional> namedActor = Ray.getGlobalActor(name); + Assert.assertTrue(namedActor.isPresent()); + // Verify that this handle is correct. + Assert.assertEquals(namedActor.get().task(Counter::increment).remote().get(), + Integer.valueOf(2)); + + // Get the global actor from another driver. + RayConfig rayConfig = TestUtils.getRuntime().getRayConfig(); + ProcessBuilder builder = new ProcessBuilder( + "java", + "-cp", + System.getProperty("java.class.path"), + "-Dray.redis.address=" + rayConfig.getRedisAddress(), + "-Dray.object-store.socket-name=" + rayConfig.objectStoreSocketName, + "-Dray.raylet.socket-name=" + rayConfig.rayletSocketName, + "-Dray.raylet.node-manager-port=" + rayConfig.getNodeManagerPort(), + NamedActorTest.class.getName(), + name); + builder.redirectError(ProcessBuilder.Redirect.INHERIT); + Process driver = builder.start(); + Assert.assertTrue(driver.waitFor(60, TimeUnit.SECONDS)); + Assert.assertEquals(driver.exitValue(), 0, + "The driver exited with code " + driver.exitValue()); + + Assert.assertEquals(namedActor.get().task(Counter::increment).remote().get(), + Integer.valueOf(4)); + } + + public static void main(String[] args) { + Ray.init(); + String actorName = args[0]; + // Get the global actor. + Optional> namedActor = Ray.getGlobalActor(actorName); + Assert.assertTrue(namedActor.isPresent()); + // Verify that this handle is correct. + Assert.assertEquals(namedActor.get().task(Counter::increment).remote().get(), + Integer.valueOf(3)); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testActorDuplicatedName() { + String name = "named-actor-counter"; + // Create an actor. + ActorHandle actor = Ray.actor(Counter::new).setName(name).remote(); + // Ensure async actor creation is finished. + Assert.assertEquals(actor.task(Counter::increment).remote().get(), Integer.valueOf(1)); + // Registering with the same name should fail. + Ray.actor(Counter::new).setName(name).remote(); + } + +} diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 6c8455411..8f9e40de4 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -19,6 +19,7 @@ #include #include "ray/common/id.h" +#include "ray/core_worker/actor_handle.h" #include "ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" @@ -177,6 +178,27 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeSetResource( THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *env, jclass, + jstring actor_name, + jboolean global) { + const char *native_actor_name = env->GetStringUTFChars(actor_name, JNI_FALSE); + auto full_name = GetActorFullName(global, native_actor_name); + + auto *actor_handle = + ray::CoreWorkerProcess::GetCoreWorker().GetNamedActorHandle(full_name); + ray::ActorID actor_id; + if (actor_handle) { + actor_id = actor_handle->GetActorID(); + } else { + actor_id = ray::ActorID::Nil(); + } + jbyteArray bytes = env->NewByteArray(actor_id.Size()); + env->SetByteArrayRegion(bytes, 0, actor_id.Size(), + reinterpret_cast(actor_id.Data())); + return bytes; +} + JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeKillActor( JNIEnv *env, jclass, jbyteArray actorId, jboolean noRestart) { auto status = ray::CoreWorkerProcess::GetCoreWorker().KillActor( diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index 338a7dfd8..fa509697a 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -65,6 +65,15 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeKillActor(JNIE jbyteArray, jboolean); +/* + * Class: io_ray_runtime_RayNativeRuntime + * Method: nativeGetActorIdOfNamedActor + * Signature: (Ljava/lang/String;Z)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *, jclass, + jstring, jboolean); + /* * Class: io_ray_runtime_RayNativeRuntime * Method: nativeSetCoreWorker diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h index ddd248dbe..b46edddcd 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.h @@ -26,32 +26,34 @@ extern "C" { * Method: nativeGetLanguage * Signature: ([B)I */ -JNIEXPORT jint JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeGetLanguage - (JNIEnv *, jclass, jbyteArray); +JNIEXPORT jint JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeGetLanguage( + JNIEnv *, jclass, jbyteArray); /* * Class: io_ray_runtime_actor_NativeActorHandle * Method: nativeGetActorCreationTaskFunctionDescriptor * Signature: ([B)Ljava/util/List; */ -JNIEXPORT jobject JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeGetActorCreationTaskFunctionDescriptor - (JNIEnv *, jclass, jbyteArray); +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_actor_NativeActorHandle_nativeGetActorCreationTaskFunctionDescriptor( + JNIEnv *, jclass, jbyteArray); /* * Class: io_ray_runtime_actor_NativeActorHandle * Method: nativeSerialize * Signature: ([B)[B */ -JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeSerialize - (JNIEnv *, jclass, jbyteArray); +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_actor_NativeActorHandle_nativeSerialize(JNIEnv *, jclass, jbyteArray); /* * Class: io_ray_runtime_actor_NativeActorHandle * Method: nativeDeserialize * Signature: ([B)[B */ -JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize - (JNIEnv *, jclass, jbyteArray); +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *, jclass, + jbyteArray); #ifdef __cplusplus } diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index af0e97d99..33884f05a 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -103,11 +103,20 @@ inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject call inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, jobject actorCreationOptions) { + bool global = false; + std::string name = ""; int64_t max_restarts = 0; std::unordered_map resources; std::vector dynamic_worker_options; uint64_t max_concurrency = 1; if (actorCreationOptions) { + global = + env->GetBooleanField(actorCreationOptions, java_actor_creation_options_global); + auto java_name = (jstring)env->GetObjectField(actorCreationOptions, + java_actor_creation_options_name); + if (java_name) { + name = JavaStringToNativeString(env, java_name); + } max_restarts = env->GetIntField(actorCreationOptions, java_actor_creation_options_max_restarts); jobject java_resources = @@ -123,7 +132,7 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, actorCreationOptions, java_actor_creation_options_max_concurrency)); } - std::string name = ""; + auto full_name = GetActorFullName(global, name); ray::ActorCreationOptions actor_creation_options{ max_restarts, 0, // TODO: Allow setting max_task_retries from Java. @@ -132,7 +141,7 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, resources, dynamic_worker_options, /*is_detached=*/false, - name, + full_name, /*is_asyncio=*/false}; return actor_creation_options; } @@ -195,7 +204,6 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( ray::CoreWorkerProcess::GetCoreWorker().SubmitActorTask( actor_id, ray_function, task_args, task_options, &return_ids); - // This is to avoid creating an empty java list and boost performance. if (return_ids.empty()) { return nullptr; diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 1805112f6..b2c3065fd 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -69,6 +69,8 @@ jclass java_base_task_options_class; jfieldID java_base_task_options_resources; jclass java_actor_creation_options_class; +jfieldID java_actor_creation_options_global; +jfieldID java_actor_creation_options_name; jfieldID java_actor_creation_options_max_restarts; jfieldID java_actor_creation_options_jvm_options; jfieldID java_actor_creation_options_max_concurrency; @@ -176,6 +178,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_class = LoadClass(env, "io/ray/api/options/ActorCreationOptions"); + java_actor_creation_options_global = + env->GetFieldID(java_actor_creation_options_class, "global", "Z"); + java_actor_creation_options_name = + env->GetFieldID(java_actor_creation_options_class, "name", "Ljava/lang/String;"); java_actor_creation_options_max_restarts = env->GetFieldID(java_actor_creation_options_class, "maxRestarts", "I"); java_actor_creation_options_jvm_options = env->GetFieldID( diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index f13d6f7b4..b18cfa2fc 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -21,6 +21,7 @@ #include "ray/common/id.h" #include "ray/common/ray_object.h" #include "ray/common/status.h" +#include "ray/core_worker/core_worker.h" /// Boolean class extern jclass java_boolean_class; @@ -116,6 +117,10 @@ extern jfieldID java_base_task_options_resources; /// ActorCreationOptions class extern jclass java_actor_creation_options_class; +/// global field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_global; +/// name field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_name; /// maxRestarts field of ActorCreationOptions class extern jfieldID java_actor_creation_options_max_restarts; /// jvmOptions field of ActorCreationOptions class @@ -426,3 +431,10 @@ inline jobject NativeRayFunctionDescriptorToJavaStringList( RAY_LOG(FATAL) << "Unknown function descriptor type: " << function_descriptor->Type(); return NativeStringVectorToJavaStringList(env, std::vector()); } + +// Return an actor fullname with job id prepended if this tis a global actor. +inline std::string GetActorFullName(bool global, std::string name) { + return global ? name + : ::ray::CoreWorkerProcess::GetCoreWorker().GetCurrentJobId().Hex() + "-" + + name; +} \ No newline at end of file