[Java] Named java actor (#9037)

This commit is contained in:
chaokunyang
2020-07-16 11:31:18 +08:00
committed by GitHub
parent 5e2571e214
commit 9318e76b81
21 changed files with 405 additions and 40 deletions
@@ -13,6 +13,7 @@ import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyFunction;
import io.ray.api.function.RayFunc;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
@@ -154,6 +155,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return (PyActorHandle) createActorImpl(functionDescriptor, args, options);
}
@SuppressWarnings("unchecked")
@Override
public <T extends BaseActorHandle> T getActorHandle(ActorId actorId) {
return (T) taskSubmitter.getActor(actorId);
}
@Override
public void setAsyncContext(Object asyncContext) {
@@ -9,6 +9,7 @@ import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.object.LocalModeObjectStore;
import io.ray.runtime.task.LocalModeTaskExecutor;
import io.ray.runtime.task.LocalModeTaskSubmitter;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -66,6 +67,12 @@ public class RayDevRuntime extends AbstractRayRuntime {
throw new UnsupportedOperationException();
}
@SuppressWarnings("unchecked")
@Override
public <T extends BaseActorHandle> Optional<T> getActor(String name, boolean global) {
return (Optional<T>) ((LocalModeTaskSubmitter)taskSubmitter).getActor(name, global);
}
@Override
public Object getAsyncContext() {
return null;
@@ -2,6 +2,7 @@ package io.ray.runtime;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.config.RayConfig;
@@ -19,6 +20,7 @@ import io.ray.runtime.util.JniUtils;
import java.io.File;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -140,6 +142,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
nativeSetResource(resourceName, capacity, nodeId.getBytes());
}
@SuppressWarnings("unchecked")
@Override
public <T extends BaseActorHandle> Optional<T> getActor(String name, boolean global) {
byte[] actorIdBytes = nativeGetActorIdOfNamedActor(name, global);
ActorId actorId = ActorId.fromBytes(actorIdBytes);
if (actorId.isNil()) {
return Optional.empty();
} else {
return Optional.of((T) getActorHandle(actorId));
}
}
@Override
public void killActor(BaseActorHandle actor, boolean noRestart) {
nativeKillActor(actor.getId().getBytes(), noRestart);
@@ -164,7 +178,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
nativeRunTaskExecutor(taskExecutor);
}
private static native void nativeInitialize(int workerMode, String ndoeIpAddress,
private static native void nativeInitialize(
int workerMode, String ndoeIpAddress,
int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
String logDir, Map<String, String> rayletConfigParameters);
@@ -177,6 +192,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private static native void nativeKillActor(byte[] actorId, boolean noRestart);
private static native byte[] nativeGetActorIdOfNamedActor(String actorName, boolean global);
private static native void nativeSetCoreWorker(byte[] workerId);
static class AsyncContext {
@@ -38,6 +38,10 @@ public class LocalModeActorHandle implements ActorHandle, Externalizable {
return this.previousActorTaskDummyObjectId.getAndSet(previousActorTaskDummyObjectId);
}
public LocalModeActorHandle copy() {
return new LocalModeActorHandle(this.actorId, this.previousActorTaskDummyObjectId.get());
}
@Override
public synchronized void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(actorId);
@@ -35,6 +35,12 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
NativeActorHandle() {
}
public static NativeActorHandle create(byte[] actorId) {
Language language = Language.forNumber(nativeGetLanguage(actorId));
Preconditions.checkState(language != null, "Language shouldn't be null");
return create(actorId, language);
}
public static NativeActorHandle create(byte[] actorId, Language language) {
switch (language) {
case JAVA:
@@ -219,7 +219,9 @@ public class RunManager {
// Register the number of Redis shards in the primary shard, so that clients
// know how many redis shards to expect under RedisShards.
client.set("NumRedisShards", Integer.toString(rayConfig.numberRedisShards));
// Set session dir for this cluster, so that the drivers which connected to this
// cluster will fetch this session dir as its self's session dir.
client.set("session_dir", rayConfig.getSessionDir());
// start redis shards
for (int i = 0; i < rayConfig.numberRedisShards; i++) {
String shard = startRedisInstance(rayConfig.nodeIp,
@@ -3,7 +3,9 @@ package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.TaskId;
@@ -32,6 +34,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@@ -39,6 +42,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -61,11 +65,14 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
/// The thread pool to execute normal tasks.
private final ExecutorService normalTaskExecutorService;
private final Map<ActorId, LocalModeActorHandle> actorHandles = new ConcurrentHashMap<>();
private final Map<String, ActorHandle> namedActors = new ConcurrentHashMap<>();
private final Map<ActorId, TaskExecutor.ActorContext> actorContexts = new ConcurrentHashMap<>();
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
LocalModeObjectStore objectStore) {
LocalModeObjectStore objectStore) {
this.runtime = runtime;
this.taskExecutor = taskExecutor;
this.objectStore = objectStore;
@@ -126,11 +133,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
.setTaskId(ByteString.copyFrom(taskIdBytes))
.setFunctionDescriptor(Common.FunctionDescriptor.newBuilder()
.setJavaFunctionDescriptor(
Common.JavaFunctionDescriptor.newBuilder()
.setClassName(functionDescriptorList.get(0))
.setFunctionName(functionDescriptorList.get(1))
.setSignature(functionDescriptorList.get(2))))
.setJavaFunctionDescriptor(
Common.JavaFunctionDescriptor.newBuilder()
.setClassName(functionDescriptorList.get(0))
.setFunctionName(functionDescriptorList.get(1))
.setSignature(functionDescriptorList.get(2))))
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
.setObjectRef(ObjectReference.newBuilder().setObjectId(
ByteString.copyFrom(arg.id.getBytes()))).build()
@@ -152,8 +159,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
@Override
public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
public BaseActorHandle createActor(
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) throws IllegalArgumentException {
ActorId actorId = ActorId.fromRandom();
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
.setNumReturns(1)
@@ -162,7 +170,15 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.build())
.build();
submitTaskSpec(taskSpec);
return new LocalModeActorHandle(actorId, getReturnIds(taskSpec).get(0));
final LocalModeActorHandle actorHandle
= new LocalModeActorHandle(actorId, getReturnIds(taskSpec).get(0));
actorHandles.put(actorId, actorHandle.copy());
if (StringUtils.isNotBlank(options.name)) {
Preconditions.checkArgument(!namedActors.containsKey(options.name),
String.format("Actor of name %s exists", options.name));
namedActors.put(options.name, actorHandle);
}
return actorHandle;
}
@Override
@@ -191,6 +207,21 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
@Override
public BaseActorHandle getActor(ActorId actorId) {
return actorHandles.get(actorId).copy();
}
public Optional<BaseActorHandle> getActor(String name, boolean global) {
String fullName = global ? name :
String.format("%s-%s", Ray.getRuntimeContext().getCurrentJobId(), name);
if (namedActors.containsKey(fullName)) {
return Optional.of(namedActors.get(fullName));
} else {
return Optional.empty();
}
}
public void shutdown() {
// Shutdown actor task executor service.
synchronized (actorTaskExecutorServices) {
@@ -300,7 +331,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[]{1}, null);
putObject = new NativeRayObject(new byte[] {1}, null);
} else {
putObject = returnObjects.get(i);
}
@@ -310,13 +341,13 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
Common.FunctionDescriptor functionDescriptor =
taskSpec.getFunctionDescriptor();
taskSpec.getFunctionDescriptor();
if (functionDescriptor.getFunctionDescriptorCase() ==
Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
return new JavaFunctionDescriptor(
functionDescriptor.getJavaFunctionDescriptor().getClassName(),
functionDescriptor.getJavaFunctionDescriptor().getFunctionName(),
functionDescriptor.getJavaFunctionDescriptor().getSignature());
functionDescriptor.getJavaFunctionDescriptor().getClassName(),
functionDescriptor.getJavaFunctionDescriptor().getFunctionName(),
functionDescriptor.getJavaFunctionDescriptor().getSignature());
} else {
throw new RuntimeException("Can't build non java function descriptor");
}
@@ -3,13 +3,17 @@ package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
/**
* Task submitter for cluster mode. This is a wrapper class for core worker task interface.
@@ -29,12 +33,23 @@ public class NativeTaskSubmitter implements TaskSubmitter {
@Override
public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
ActorCreationOptions options) throws IllegalArgumentException {
if (StringUtils.isNotBlank(options.name)) {
Optional<BaseActorHandle> actor =
options.global ? Ray.getGlobalActor(options.name) : Ray.getActor(options.name);
Preconditions.checkArgument(!actor.isPresent(),
String.format("Actor of name %s exists", options.name));
}
byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args,
options);
return NativeActorHandle.create(actorId, functionDescriptor.getLanguage());
}
@Override
public BaseActorHandle getActor(ActorId actorId) {
return NativeActorHandle.create(actorId.getBytes());
}
@Override
public List<ObjectId> submitActorTask(
BaseActorHandle actor, FunctionDescriptor functionDescriptor,
@@ -1,6 +1,7 @@
package io.ray.runtime.task;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
@@ -29,9 +30,10 @@ public interface TaskSubmitter {
* @param args Arguments of this task.
* @param options Options for this actor creation task.
* @return Handle to the actor.
* @throws IllegalArgumentException if actor of specified name exists
*/
BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options);
ActorCreationOptions options) throws IllegalArgumentException;
/**
* Submit an actor task.
@@ -44,4 +46,7 @@ public interface TaskSubmitter {
*/
List<ObjectId> submitActorTask(BaseActorHandle actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options);
BaseActorHandle getActor(ActorId actorId);
}