diff --git a/BUILD.bazel b/BUILD.bazel index 80b6c1b96..05da0f0a7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -748,8 +748,6 @@ cc_binary( srcs = glob([ "src/ray/core_worker/lib/java/*.h", "src/ray/core_worker/lib/java/*.cc", - "src/ray/raylet/lib/java/*.h", - "src/ray/raylet/lib/java/*.cc", ]) + [ "@bazel_tools//tools/jdk:jni_header", ] + select({ diff --git a/java/api/src/main/java/org/ray/api/ObjectType.java b/java/api/src/main/java/org/ray/api/ObjectType.java deleted file mode 100644 index c0dd63f22..000000000 --- a/java/api/src/main/java/org/ray/api/ObjectType.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.ray.api; - -public enum ObjectType { - PUT_OBJECT, - RETURN_OBJECT, -} diff --git a/java/api/src/main/java/org/ray/api/id/ActorId.java b/java/api/src/main/java/org/ray/api/id/ActorId.java index 1953b2403..340f5c1d1 100644 --- a/java/api/src/main/java/org/ray/api/id/ActorId.java +++ b/java/api/src/main/java/org/ray/api/id/ActorId.java @@ -2,14 +2,11 @@ package org.ray.api.id; import java.io.Serializable; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.Arrays; import java.util.Random; public class ActorId extends BaseId implements Serializable { - private static final int UNIQUE_BYTES_LENGTH = 4; - - public static final int LENGTH = UNIQUE_BYTES_LENGTH + JobId.LENGTH; + public static final int LENGTH = 8; public static final ActorId NIL = nil(); @@ -25,19 +22,6 @@ public class ActorId extends BaseId implements Serializable { return new ActorId(bytes); } - public static ActorId generateActorId(JobId jobId) { - byte[] uniqueBytes = new byte[ActorId.UNIQUE_BYTES_LENGTH]; - new Random().nextBytes(uniqueBytes); - - byte[] bytes = new byte[ActorId.LENGTH]; - ByteBuffer wbb = ByteBuffer.wrap(bytes); - wbb.order(ByteOrder.LITTLE_ENDIAN); - - System.arraycopy(uniqueBytes, 0, bytes, 0, ActorId.UNIQUE_BYTES_LENGTH); - System.arraycopy(jobId.getBytes(), 0, bytes, ActorId.UNIQUE_BYTES_LENGTH, JobId.LENGTH); - return new ActorId(bytes); - } - /** * Generate a nil ActorId. */ @@ -47,6 +31,15 @@ public class ActorId extends BaseId implements Serializable { return new ActorId(b); } + /** + * Generate an ActorId with random value. Used for local mode and test only. + */ + public static ActorId fromRandom() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new ActorId(b); + } + @Override public int size() { return LENGTH; diff --git a/java/api/src/main/java/org/ray/api/id/ObjectId.java b/java/api/src/main/java/org/ray/api/id/ObjectId.java index bf140ee90..83df2a8ca 100644 --- a/java/api/src/main/java/org/ray/api/id/ObjectId.java +++ b/java/api/src/main/java/org/ray/api/id/ObjectId.java @@ -2,10 +2,8 @@ package org.ray.api.id; import java.io.Serializable; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.Arrays; import java.util.Random; -import org.ray.api.ObjectType; /** * Represents the id of a Ray object. @@ -16,20 +14,6 @@ public class ObjectId extends BaseId implements Serializable { public static final ObjectId NIL = genNil(); - private static int CREATED_BY_TASK_FLAG_BITS_OFFSET = 15; - - private static int OBJECT_TYPE_FLAG_BITS_OFFSET = 14; - - private static int TRANSPORT_TYPE_FLAG_BITS_OFFSET = 11; - - private static int FLAGS_BYTES_POS = TaskId.LENGTH; - - private static int FLAGS_BYTES_LENGTH = 2; - - private static int INDEX_BYTES_POS = FLAGS_BYTES_POS + FLAGS_BYTES_LENGTH; - - private static int INDEX_BYTES_LENGTH = 4; - /** * Create an ObjectId from a ByteBuffer. */ @@ -55,48 +39,6 @@ public class ObjectId extends BaseId implements Serializable { return new ObjectId(b); } - /** - * Compute the object ID of an object put by the task. - */ - public static ObjectId forPut(TaskId taskId, int putIndex) { - short flags = 0; - flags = setCreatedByTaskFlag(flags, true); - // Set a default transport type with value 0. - flags = (short) (flags | (0x0 << TRANSPORT_TYPE_FLAG_BITS_OFFSET)); - flags = setObjectTypeFlag(flags, ObjectType.PUT_OBJECT); - - byte[] bytes = new byte[ObjectId.LENGTH]; - System.arraycopy(taskId.getBytes(), 0, bytes, 0, TaskId.LENGTH); - - ByteBuffer wbb = ByteBuffer.wrap(bytes); - wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putShort(FLAGS_BYTES_POS, flags); - - wbb.putInt(INDEX_BYTES_POS, putIndex); - return new ObjectId(bytes); - } - - /** - * Compute the object ID of an object return by the task. - */ - public static ObjectId forReturn(TaskId taskId, int returnIndex) { - short flags = 0; - flags = setCreatedByTaskFlag(flags, true); - // Set a default transport type with value 0. - flags = (short) (flags | (0x0 << TRANSPORT_TYPE_FLAG_BITS_OFFSET)); - flags = setObjectTypeFlag(flags, ObjectType.RETURN_OBJECT); - - byte[] bytes = new byte[ObjectId.LENGTH]; - System.arraycopy(taskId.getBytes(), 0, bytes, 0, TaskId.LENGTH); - - ByteBuffer wbb = ByteBuffer.wrap(bytes); - wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putShort(FLAGS_BYTES_POS, flags); - - wbb.putInt(INDEX_BYTES_POS, returnIndex); - return new ObjectId(bytes); - } - public ObjectId(byte[] id) { super(id); } @@ -106,25 +48,4 @@ public class ObjectId extends BaseId implements Serializable { return LENGTH; } - public TaskId getTaskId() { - byte[] taskIdBytes = Arrays.copyOf(getBytes(), TaskId.LENGTH); - return TaskId.fromBytes(taskIdBytes); - } - - private static short setCreatedByTaskFlag(short flags, boolean createdByTask) { - if (createdByTask) { - return (short) (flags | (0x1 << CREATED_BY_TASK_FLAG_BITS_OFFSET)); - } else { - return (short) (flags | (0x0 << CREATED_BY_TASK_FLAG_BITS_OFFSET)); - } - } - - private static short setObjectTypeFlag(short flags, ObjectType objectType) { - if (objectType == ObjectType.RETURN_OBJECT) { - return (short)(flags | (0x1 << OBJECT_TYPE_FLAG_BITS_OFFSET)); - } else { - return (short)(flags | (0x0 << OBJECT_TYPE_FLAG_BITS_OFFSET)); - } - } - } diff --git a/java/api/src/main/java/org/ray/api/id/TaskId.java b/java/api/src/main/java/org/ray/api/id/TaskId.java index 0f2ee1e03..517bb2e32 100644 --- a/java/api/src/main/java/org/ray/api/id/TaskId.java +++ b/java/api/src/main/java/org/ray/api/id/TaskId.java @@ -11,9 +11,7 @@ import java.util.Random; */ public class TaskId extends BaseId implements Serializable { - private static final int UNIQUE_BYTES_LENGTH = 6; - - public static final int LENGTH = UNIQUE_BYTES_LENGTH + ActorId.LENGTH; + public static final int LENGTH = 14; public static final TaskId NIL = genNil(); @@ -38,15 +36,6 @@ public class TaskId extends BaseId implements Serializable { return new TaskId(bytes); } - /** - * Get the id of the actor to which this task belongs - */ - public ActorId getActorId() { - byte[] actorIdBytes = new byte[ActorId.LENGTH]; - System.arraycopy(getBytes(), UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH); - return ActorId.fromByteBuffer(ByteBuffer.wrap(actorIdBytes)); - } - /** * Generate a nil TaskId. */ diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 28ebe56ab..55d5fc5f9 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -1,45 +1,35 @@ package org.ray.runtime; import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardCopyOption; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.RayPyActor; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; -import org.ray.api.id.ActorId; -import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; -import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; -import org.ray.api.options.BaseTaskOptions; import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; import org.ray.api.runtimecontext.RuntimeContext; import org.ray.runtime.config.RayConfig; +import org.ray.runtime.context.RuntimeContextImpl; +import org.ray.runtime.context.WorkerContext; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.functionmanager.PyFunctionDescriptor; import org.ray.runtime.gcs.GcsClient; -import org.ray.runtime.objectstore.ObjectStoreProxy; +import org.ray.runtime.generated.Common.Language; +import org.ray.runtime.object.ObjectStore; +import org.ray.runtime.object.RayObjectImpl; import org.ray.runtime.raylet.RayletClient; -import org.ray.runtime.raylet.RayletClientImpl; import org.ray.runtime.task.ArgumentsBuilder; -import org.ray.runtime.task.TaskLanguage; -import org.ray.runtime.task.TaskSpec; +import org.ray.runtime.task.FunctionArg; +import org.ray.runtime.task.TaskExecutor; +import org.ray.runtime.task.TaskSubmitter; import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,71 +40,24 @@ import org.slf4j.LoggerFactory; public abstract class AbstractRayRuntime implements RayRuntime { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); - + public static final String PYTHON_INIT_METHOD_NAME = "__init__"; protected RayConfig rayConfig; - protected WorkerContext workerContext; - protected Worker worker; - protected RayletClient rayletClient; - protected ObjectStoreProxy objectStoreProxy; + protected TaskExecutor taskExecutor; protected FunctionManager functionManager; protected RuntimeContext runtimeContext; protected GcsClient gcsClient; - static { - try { - LOGGER.debug("Loading native libraries."); - // Load native libraries. - String[] libraries = new String[]{"core_worker_library_java"}; - for (String library : libraries) { - String fileName = System.mapLibraryName(library); - // Copy the file from resources to a temp dir, and load the native library. - File file = File.createTempFile(fileName, ""); - file.deleteOnExit(); - InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName); - Preconditions.checkNotNull(in, "{} doesn't exist.", fileName); - Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING); - System.load(file.getAbsolutePath()); - } - LOGGER.debug("Native libraries loaded."); - } catch (IOException e) { - throw new RuntimeException("Couldn't load native libraries.", e); - } - } + protected ObjectStore objectStore; + protected TaskSubmitter taskSubmitter; + protected RayletClient rayletClient; + protected WorkerContext workerContext; public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; functionManager = new FunctionManager(rayConfig.jobResourcePath); - worker = new Worker(this); runtimeContext = new RuntimeContextImpl(this); } - protected void resetLibraryPath() { - if (rayConfig.libraryPath.isEmpty()) { - return; - } - - String path = System.getProperty("java.library.path"); - if (Strings.isNullOrEmpty(path)) { - path = ""; - } else { - path += ":"; - } - path += String.join(":", rayConfig.libraryPath); - - // This is a hack to reset library path at runtime, - // see https://stackoverflow.com/questions/15409223/. - System.setProperty("java.library.path", path); - // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed. - final Field sysPathsField; - try { - sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); - sysPathsField.setAccessible(true); - sysPathsField.set(null, null); - } catch (NoSuchFieldException | IllegalAccessException e) { - LOGGER.error("Failed to set library path.", e); - } - } - /** * Start runtime. */ @@ -125,31 +68,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public RayObject put(T obj) { - ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(), - workerContext.nextPutIndex()); - put(objectId, obj); - return new RayObjectImpl<>(objectId); - } - - public void put(ObjectId objectId, T obj) { - TaskId taskId = workerContext.getCurrentTaskId(); - LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); - objectStoreProxy.put(objectId, obj); - } - - - /** - * Store a serialized object in the object store. - * - * @param obj The serialized Java object to be stored. - * @return A RayObject instance that represents the in-store object. - */ - public RayObject putSerialized(byte[] obj) { - ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(), - workerContext.nextPutIndex()); - TaskId taskId = workerContext.getCurrentTaskId(); - LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); - objectStoreProxy.putSerialized(objectId, obj); + ObjectId objectId = objectStore.put(obj); return new RayObjectImpl<>(objectId); } @@ -161,12 +80,12 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public List get(List objectIds) { - return objectStoreProxy.get(objectIds); + return objectStore.get(objectIds); } @Override public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks); + objectStore.delete(objectIds, localOnly, deleteCreatingTasks); } @Override @@ -180,44 +99,33 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { - return rayletClient.wait(waitList, numReturns, - timeoutMs, workerContext.getCurrentTaskId()); + return objectStore.wait(waitList, numReturns, timeoutMs); } @Override public RayObject call(RayFunc func, Object[] args, CallOptions options) { - TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, false, options); - rayletClient.submitTask(spec); - return new RayObjectImpl(spec.returnIds[0]); + FunctionDescriptor functionDescriptor = + functionManager.getFunction(workerContext.getCurrentJobId(), func) + .functionDescriptor; + return callNormalFunction(functionDescriptor, args, options); } @Override public RayObject call(RayFunc func, RayActor actor, Object[] args) { - if (!(actor instanceof RayActorImpl)) { - throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName()); - } - RayActorImpl actorImpl = (RayActorImpl) actor; - TaskSpec spec; - synchronized (actor) { - spec = createTaskSpec(func, null, actorImpl, args, false, true, null); - actorImpl.setTaskCursor(spec.returnIds[1]); - actorImpl.clearNewActorHandles(); - } - rayletClient.submitTask(spec); - return new RayObjectImpl(spec.returnIds[0]); + FunctionDescriptor functionDescriptor = + functionManager.getFunction(workerContext.getCurrentJobId(), func) + .functionDescriptor; + return callActorFunction(actor, functionDescriptor, args); } @Override @SuppressWarnings("unchecked") public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { - TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL, - args, true, false, options); - RayActorImpl actor = new RayActorImpl(spec.taskId.getActorId()); - actor.increaseTaskCounter(); - actor.setTaskCursor(spec.returnIds[0]); - rayletClient.submitTask(spec); - return (RayActor) actor; + FunctionDescriptor functionDescriptor = + functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc) + .functionDescriptor; + return (RayActor) createActorImpl(functionDescriptor, args, options); } private void checkPyArguments(Object[] args) { @@ -233,146 +141,69 @@ public abstract class AbstractRayRuntime implements RayRuntime { public RayObject callPy(String moduleName, String functionName, Object[] args, CallOptions options) { checkPyArguments(args); - PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName); - TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, false, options); - rayletClient.submitTask(spec); - return new RayObjectImpl(spec.returnIds[0]); + PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, "", + functionName); + return callNormalFunction(functionDescriptor, args, options); } @Override public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) { checkPyArguments(args); - PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(), + PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(), pyActor.getClassName(), functionName); - RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor; - TaskSpec spec; - synchronized (pyActor) { - spec = createTaskSpec(null, desc, actorImpl, args, false, true, null); - actorImpl.setTaskCursor(spec.returnIds[1]); - actorImpl.clearNewActorHandles(); - } - rayletClient.submitTask(spec); - return new RayObjectImpl(spec.returnIds[0]); + return callActorFunction(pyActor, functionDescriptor, args); } @Override public RayPyActor createPyActor(String moduleName, String className, Object[] args, ActorCreationOptions options) { checkPyArguments(args); - PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__"); - TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, false, options); - RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className); - actor.increaseTaskCounter(); - actor.setTaskCursor(spec.returnIds[0]); - rayletClient.submitTask(spec); + PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, className, + PYTHON_INIT_METHOD_NAME); + return (RayPyActor) createActorImpl(functionDescriptor, args, options); + } + + private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, + Object[] args, CallOptions options) { + List functionArgs = ArgumentsBuilder + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + List returnIds = taskSubmitter.submitTask(functionDescriptor, + functionArgs, 1, options); + return new RayObjectImpl(returnIds.get(0)); + } + + private RayObject callActorFunction(RayActor rayActor, + FunctionDescriptor functionDescriptor, Object[] args) { + List functionArgs = ArgumentsBuilder + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + List returnIds = taskSubmitter.submitActorTask(rayActor, + functionDescriptor, functionArgs, 1, null); + return new RayObjectImpl(returnIds.get(0)); + } + + private RayActor createActorImpl(FunctionDescriptor functionDescriptor, + Object[] args, ActorCreationOptions options) { + List functionArgs = ArgumentsBuilder + .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + if (functionDescriptor.getLanguage() != Language.JAVA && options != null) { + Preconditions.checkState(StringUtil.isNullOrEmpty(options.jvmOptions)); + } + RayActor actor = taskSubmitter + .createActor(functionDescriptor, functionArgs, + options); return actor; } - /** - * Create the task specification. - * - * @param func The target remote function. - * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python - * task. - * @param actor The actor handle. If the task is not an actor task, actor id must be NIL. - * @param args The arguments for the remote function. - * @param isActorCreationTask Whether this task is an actor creation task. - * @param isActorTask Whether this task is an actor task. - * @return A TaskSpec object. - */ - private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor, - RayActorImpl actor, Object[] args, - boolean isActorCreationTask, boolean isActorTask, BaseTaskOptions taskOptions) { - Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); - - ActorId actorCreationId = ActorId.NIL; - TaskId taskId = null; - final JobId currentJobId = workerContext.getCurrentJobId(); - final TaskId currentTaskId = workerContext.getCurrentTaskId(); - final int taskIndex = workerContext.nextTaskIndex(); - if (isActorCreationTask) { - taskId = RayletClientImpl.generateActorCreationTaskId(currentJobId, currentTaskId, taskIndex); - actorCreationId = taskId.getActorId(); - } else if (isActorTask) { - taskId = RayletClientImpl.generateActorTaskId(currentJobId, currentTaskId, taskIndex, actor.getId()); - } else { - taskId = RayletClientImpl.generateNormalTaskId(currentJobId, currentTaskId, taskIndex); - } - - int numReturns = actor.getId().isNil() ? 1 : 2; - - Map resources; - if (null == taskOptions) { - resources = new HashMap<>(); - } else { - resources = new HashMap<>(taskOptions.resources); - } - - int maxActorReconstruction = 0; - List dynamicWorkerOptions = ImmutableList.of(); - if (taskOptions instanceof ActorCreationOptions) { - maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; - String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions; - if (!StringUtil.isNullOrEmpty(jvmOptions)) { - dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions); - } - } - - TaskLanguage language; - FunctionDescriptor functionDescriptor; - if (func != null) { - language = TaskLanguage.JAVA; - functionDescriptor = functionManager.getFunction(workerContext.getCurrentJobId(), func) - .getFunctionDescriptor(); - } else { - language = TaskLanguage.PYTHON; - functionDescriptor = pyFunctionDescriptor; - } - - ObjectId previousActorTaskDummyObjectId = ObjectId.NIL; - if (isActorTask) { - previousActorTaskDummyObjectId = actor.getTaskCursor(); - } - - return new TaskSpec( - workerContext.getCurrentJobId(), - taskId, - workerContext.getCurrentTaskId(), - -1, - actorCreationId, - maxActorReconstruction, - actor.getId(), - actor.getHandleId(), - actor.increaseTaskCounter(), - previousActorTaskDummyObjectId, - actor.getNewActorHandles().toArray(new UniqueId[0]), - ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON), - numReturns, - resources, - language, - functionDescriptor, - dynamicWorkerOptions - ); - } - - public void loop() { - worker.loop(); - } - - public Worker getWorker() { - return worker; - } - public WorkerContext getWorkerContext() { return workerContext; } - public RayletClient getRayletClient() { - return rayletClient; + public ObjectStore getObjectStore() { + return objectStore; } - public ObjectStoreProxy getObjectStoreProxy() { - return objectStoreProxy; + public RayletClient getRayletClient() { + return rayletClient; } public FunctionManager getFunctionManager() { diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java deleted file mode 100644 index 97fea9d56..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ /dev/null @@ -1,130 +0,0 @@ -package org.ray.runtime; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.ArrayList; -import java.util.List; -import org.ray.api.RayActor; -import org.ray.api.id.ActorId; -import org.ray.api.id.ObjectId; -import org.ray.api.id.UniqueId; -import org.ray.runtime.util.Sha1Digestor; - -public class RayActorImpl implements RayActor, Externalizable { - - public static final RayActorImpl NIL = new RayActorImpl(); - - /** - * Id of this actor. - */ - protected ActorId id; - /** - * Handle id of this actor. - */ - protected UniqueId handleId; - /** - * The number of tasks that have been invoked on this actor. - */ - protected int taskCounter; - /** - * The unique id of the last return of the last task. - * It's used as a dependency for the next task. - */ - protected ObjectId taskCursor; - /** - * The number of times that this actor handle has been forked. - * It's used to make sure ids of actor handles are unique. - */ - protected int numForks; - - /** - * The new actor handles that were created from this handle - * since the last task on this handle was submitted. This is - * used to garbage-collect dummy objects that are no longer - * necessary in the backend. - */ - protected List newActorHandles; - - public RayActorImpl() { - this(ActorId.NIL, UniqueId.NIL); - } - - public RayActorImpl(ActorId id) { - this(id, UniqueId.NIL); - } - - public RayActorImpl(ActorId id, UniqueId handleId) { - this.id = id; - this.handleId = handleId; - this.taskCounter = 0; - this.taskCursor = null; - this.newActorHandles = new ArrayList<>(); - numForks = 0; - } - - @Override - public ActorId getId() { - return id; - } - - @Override - public UniqueId getHandleId() { - return handleId; - } - - public void setTaskCursor(ObjectId taskCursor) { - this.taskCursor = taskCursor; - } - - public List getNewActorHandles() { - return this.newActorHandles; - } - - public void clearNewActorHandles() { - this.newActorHandles.clear(); - } - - public ObjectId getTaskCursor() { - return taskCursor; - } - - public int increaseTaskCounter() { - return taskCounter++; - } - - public RayActorImpl fork() { - RayActorImpl ret = new RayActorImpl<>(); - ret.id = this.id; - ret.taskCounter = 0; - ret.numForks = 0; - ret.taskCursor = this.taskCursor; - ret.handleId = this.computeNextActorHandleId(); - newActorHandles.add(ret.handleId); - return ret; - } - - protected UniqueId computeNextActorHandleId() { - byte[] bytes = Sha1Digestor.digest(handleId.getBytes(), ++numForks); - return new UniqueId(bytes); - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - out.writeObject(this.id); - out.writeObject(this.handleId); - out.writeObject(this.taskCursor); - out.writeObject(this.taskCounter); - out.writeObject(this.numForks); - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.id = (ActorId) in.readObject(); - this.handleId = (UniqueId) in.readObject(); - this.taskCursor = (ObjectId) in.readObject(); - this.taskCounter = (int) in.readObject(); - this.numForks = (int) in.readObject(); - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index a491d89e5..7653177a1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -3,9 +3,11 @@ package org.ray.runtime; import java.util.concurrent.atomic.AtomicInteger; import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; -import org.ray.runtime.objectstore.MockObjectInterface; -import org.ray.runtime.objectstore.ObjectStoreProxy; -import org.ray.runtime.raylet.MockRayletClient; +import org.ray.runtime.context.LocalModeWorkerContext; +import org.ray.runtime.object.LocalModeObjectStore; +import org.ray.runtime.raylet.LocalModeRayletClient; +import org.ray.runtime.task.LocalModeTaskSubmitter; +import org.ray.runtime.task.TaskExecutor; public class RayDevRuntime extends AbstractRayRuntime { @@ -13,37 +15,26 @@ public class RayDevRuntime extends AbstractRayRuntime { super(rayConfig); } - private MockObjectInterface objectInterface; - private AtomicInteger jobCounter = new AtomicInteger(0); @Override public void start() { - // Reset library path at runtime. - resetLibraryPath(); - - objectInterface = new MockObjectInterface(workerContext); if (rayConfig.getJobId().isNil()) { rayConfig.setJobId(nextJobId()); } - workerContext = new WorkerContext(rayConfig.workerMode, - rayConfig.getJobId(), rayConfig.runMode); - objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface); - rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime); + taskExecutor = new TaskExecutor(this); + workerContext = new LocalModeWorkerContext(rayConfig.getJobId()); + objectStore = new LocalModeObjectStore(workerContext); + taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore, + rayConfig.numberExecThreadsForDevRuntime); + ((LocalModeObjectStore) objectStore).addObjectPutCallback( + objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId)); + rayletClient = new LocalModeRayletClient(); } @Override public void shutdown() { - rayletClient.destroy(); - } - - public MockObjectInterface getObjectInterface() { - return objectInterface; - } - - @Override - public Worker getWorker() { - return ((MockRayletClient) rayletClient).getCurrentWorker(); + taskExecutor = null; } private JobId nextJobId() { diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index cf804ee02..ab1b6d0b7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -1,16 +1,28 @@ package org.ray.runtime; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; import java.util.HashMap; import java.util.Map; import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; +import org.ray.runtime.context.NativeWorkerContext; import org.ray.runtime.gcs.GcsClient; +import org.ray.runtime.gcs.GcsClientOptions; import org.ray.runtime.gcs.RedisClient; import org.ray.runtime.generated.Common.WorkerType; -import org.ray.runtime.objectstore.ObjectInterfaceImpl; -import org.ray.runtime.objectstore.ObjectStoreProxy; -import org.ray.runtime.raylet.RayletClientImpl; +import org.ray.runtime.object.NativeObjectStore; +import org.ray.runtime.raylet.NativeRayletClient; import org.ray.runtime.runner.RunManager; +import org.ray.runtime.task.NativeTaskSubmitter; +import org.ray.runtime.task.TaskExecutor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,12 +35,65 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private RunManager manager = null; - private ObjectInterfaceImpl objectInterfaceImpl = null; + /** + * The native pointer of core worker. + */ + private long nativeCoreWorkerPointer; + + static { + try { + LOGGER.debug("Loading native libraries."); + // Load native libraries. + String[] libraries = new String[]{"core_worker_library_java"}; + for (String library : libraries) { + String fileName = System.mapLibraryName(library); + // Copy the file from resources to a temp dir, and load the native library. + File file = File.createTempFile(fileName, ""); + file.deleteOnExit(); + InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName); + Preconditions.checkNotNull(in, "{} doesn't exist.", fileName); + Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING); + System.load(file.getAbsolutePath()); + } + LOGGER.debug("Native libraries loaded."); + } catch (IOException e) { + throw new RuntimeException("Couldn't load native libraries.", e); + } + nativeSetup(RayConfig.create().logDir); + Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook)); + } public RayNativeRuntime(RayConfig rayConfig) { super(rayConfig); } + protected void resetLibraryPath() { + if (rayConfig.libraryPath.isEmpty()) { + return; + } + + String path = System.getProperty("java.library.path"); + if (Strings.isNullOrEmpty(path)) { + path = ""; + } else { + path += ":"; + } + path += String.join(":", rayConfig.libraryPath); + + // This is a hack to reset library path at runtime, + // see https://stackoverflow.com/questions/15409223/. + System.setProperty("java.library.path", path); + // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed. + final Field sysPathsField; + try { + sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); + sysPathsField.setAccessible(true); + sysPathsField.set(null, null); + } catch (NoSuchFieldException | IllegalAccessException e) { + LOGGER.error("Failed to set library path.", e); + } + } + @Override public void start() { // Reset library path at runtime. @@ -44,20 +109,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (rayConfig.getJobId() == JobId.NIL) { rayConfig.setJobId(gcsClient.nextJobId()); } - - workerContext = new WorkerContext(rayConfig.workerMode, - rayConfig.getJobId(), rayConfig.runMode); - rayletClient = new RayletClientImpl( - rayConfig.rayletSocketName, - workerContext.getCurrentWorkerId(), - rayConfig.workerMode == WorkerType.WORKER, - workerContext.getCurrentJobId() - ); - // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. - objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient, - rayConfig.objectStoreSocketName); - objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl); + nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(), + rayConfig.objectStoreSocketName, rayConfig.rayletSocketName, + (rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(), + new GcsClientOptions(rayConfig)); + Preconditions.checkState(nativeCoreWorkerPointer != 0); + + taskExecutor = new TaskExecutor(this); + workerContext = new NativeWorkerContext(nativeCoreWorkerPointer); + objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer); + taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer); + rayletClient = new NativeRayletClient(nativeCoreWorkerPointer); // register registerWorker(); @@ -71,8 +134,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (null != manager) { manager.cleanup(); } - objectInterfaceImpl.destroy(); - workerContext.destroy(); + if (nativeCoreWorkerPointer != 0) { + nativeDestroyCoreWorker(nativeCoreWorkerPointer); + nativeCoreWorkerPointer = 0; + } + } + + public void run() { + nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor); } /** @@ -99,4 +168,16 @@ public final class RayNativeRuntime extends AbstractRayRuntime { redisClient.hmset("Workers:" + workerId, workerInfo); } } + + private static native long nativeInitCoreWorker(int workerMode, String storeSocket, + String rayletSocket, byte[] jobId, GcsClientOptions gcsClientOptions); + + private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer, + TaskExecutor taskExecutor); + + private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer); + + private static native void nativeSetup(String logDir); + + private static native void nativeShutdownHook(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java deleted file mode 100644 index 817a3ffca..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java +++ /dev/null @@ -1,71 +0,0 @@ -package org.ray.runtime; - -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import org.ray.api.RayPyActor; -import org.ray.api.id.ActorId; - -public class RayPyActorImpl extends RayActorImpl implements RayPyActor { - - public static final RayPyActorImpl NIL = new RayPyActorImpl(ActorId.NIL, null, null); - - /** - * Module name of the Python actor class. - */ - private String moduleName; - - /** - * Name of the Python actor class. - */ - private String className; - - // Note that this empty constructor must be public - // since it'll be needed when deserializing. - public RayPyActorImpl() {} - - public RayPyActorImpl(ActorId id, String moduleName, String className) { - super(id); - this.moduleName = moduleName; - this.className = className; - } - - @Override - public String getModuleName() { - return moduleName; - } - - @Override - public String getClassName() { - return className; - } - - public RayPyActorImpl fork() { - RayPyActorImpl ret = new RayPyActorImpl(); - ret.id = this.id; - ret.taskCounter = 0; - ret.numForks = 0; - ret.taskCursor = this.taskCursor; - ret.moduleName = this.moduleName; - ret.className = this.className; - ret.handleId = this.computeNextActorHandleId(); - newActorHandles.add(ret.handleId); - return ret; - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - super.writeExternal(out); - out.writeObject(this.moduleName); - out.writeObject(this.className); - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - super.readExternal(in); - this.moduleName = (String) in.readObject(); - this.className = (String) in.readObject(); - } - -} - diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java deleted file mode 100644 index 9d2eeddaa..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ /dev/null @@ -1,139 +0,0 @@ -package org.ray.runtime; - -import com.google.common.base.Preconditions; -import java.nio.ByteBuffer; -import org.ray.api.id.JobId; -import org.ray.api.id.TaskId; -import org.ray.api.id.UniqueId; -import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.Common.WorkerType; -import org.ray.runtime.raylet.RayletClientImpl; -import org.ray.runtime.task.TaskSpec; - -/** - * This is a wrapper class for worker context of core worker. - */ -public class WorkerContext { - - /** - * The native pointer of worker context of core worker. - */ - private final long nativeWorkerContextPointer; - - private ClassLoader currentClassLoader; - - /** - * The ID of main thread which created the worker context. - */ - private long mainThreadId; - - /** - * The run-mode of this worker. - */ - private RunMode runMode; - - public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) { - this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes()); - mainThreadId = Thread.currentThread().getId(); - this.runMode = runMode; - currentClassLoader = null; - } - - public long getNativeWorkerContext() { - return nativeWorkerContextPointer; - } - - /** - * @return For the main thread, this method returns the ID of this worker's current running task; - * for other threads, this method returns a random ID. - */ - public TaskId getCurrentTaskId() { - return TaskId.fromBytes(nativeGetCurrentTaskId(nativeWorkerContextPointer)); - } - - /** - * Set the current task which is being executed by the current worker. Note, this method can only - * be called from the main thread. - */ - public void setCurrentTask(TaskSpec task, ClassLoader classLoader) { - if (runMode == RunMode.CLUSTER) { - Preconditions.checkState( - Thread.currentThread().getId() == mainThreadId, - "This method should only be called from the main thread." - ); - } - - Preconditions.checkNotNull(task); - byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task); - nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec); - currentClassLoader = classLoader; - } - - /** - * Increment the put index and return the new value. - */ - public int nextPutIndex() { - return nativeGetNextPutIndex(nativeWorkerContextPointer); - } - - /** - * Increment the task index and return the new value. - */ - public int nextTaskIndex() { - return nativeGetNextTaskIndex(nativeWorkerContextPointer); - } - - /** - * @return The ID of the current worker. - */ - public UniqueId getCurrentWorkerId() { - return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer)); - } - - /** - * The ID of the current job. - */ - public JobId getCurrentJobId() { - return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer)); - } - - /** - * @return The class loader which is associated with the current job. - */ - public ClassLoader getCurrentClassLoader() { - return currentClassLoader; - } - - /** - * Get the current task. - */ - public TaskSpec getCurrentTask() { - byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer); - if (bytes == null) { - return null; - } - return RayletClientImpl.parseTaskSpecFromProtobuf(bytes); - } - - public void destroy() { - nativeDestroy(nativeWorkerContextPointer); - } - - private static native long nativeCreateWorkerContext(int workerType, byte[] jobId); - - private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer); - - private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec); - - private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer); - - private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer); - - private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer); - - private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer); - - private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer); - - private static native void nativeDestroy(long nativeWorkerContextPointer); -} diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/LocalModeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/LocalModeRayActor.java new file mode 100644 index 000000000..4ffe36d3d --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/actor/LocalModeRayActor.java @@ -0,0 +1,58 @@ +package org.ray.runtime.actor; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.concurrent.atomic.AtomicReference; +import org.ray.api.RayActor; +import org.ray.api.id.ActorId; +import org.ray.api.id.ObjectId; +import org.ray.api.id.UniqueId; + +/** + * RayActor implementation for local mode. + */ +public class LocalModeRayActor implements RayActor, Externalizable { + + private ActorId actorId; + + private AtomicReference previousActorTaskDummyObjectId = new AtomicReference<>(); + + public LocalModeRayActor(ActorId actorId, ObjectId previousActorTaskDummyObjectId) { + this.actorId = actorId; + this.previousActorTaskDummyObjectId.set(previousActorTaskDummyObjectId); + } + + /** + * Required by FST + */ + public LocalModeRayActor() { + } + + @Override + public ActorId getId() { + return actorId; + } + + @Override + public UniqueId getHandleId() { + return UniqueId.NIL; + } + + public ObjectId exchangePreviousActorTaskDummyObjectId(ObjectId previousActorTaskDummyObjectId) { + return this.previousActorTaskDummyObjectId.getAndSet(previousActorTaskDummyObjectId); + } + + @Override + public synchronized void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(actorId); + out.writeObject(previousActorTaskDummyObjectId.get()); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + actorId = (ActorId) in.readObject(); + previousActorTaskDummyObjectId.set((ObjectId) in.readObject()); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java new file mode 100644 index 000000000..ecdf03053 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java @@ -0,0 +1,101 @@ +package org.ray.runtime.actor; + +import com.google.common.base.Preconditions; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.List; +import org.ray.api.RayActor; +import org.ray.api.RayPyActor; +import org.ray.api.id.ActorId; +import org.ray.api.id.UniqueId; +import org.ray.runtime.generated.Common.Language; + +/** + * RayActor implementation for cluster mode. This is a wrapper class for C++ ActorHandle. + */ +public class NativeRayActor implements RayActor, RayPyActor, Externalizable { + + /** + * Address of native actor handle. + */ + private long nativeActorHandle; + + public NativeRayActor(long nativeActorHandle) { + Preconditions.checkState(nativeActorHandle != 0); + this.nativeActorHandle = nativeActorHandle; + } + + /** + * Required by FST + */ + public NativeRayActor() { + } + + public long getNativeActorHandle() { + return nativeActorHandle; + } + + @Override + public ActorId getId() { + return ActorId.fromBytes(nativeGetActorId(nativeActorHandle)); + } + + @Override + public UniqueId getHandleId() { + return new UniqueId(nativeGetActorHandleId(nativeActorHandle)); + } + + public Language getLanguage() { + return Language.forNumber(nativeGetLanguage(nativeActorHandle)); + } + + @Override + public String getModuleName() { + Preconditions.checkState(getLanguage() == Language.PYTHON); + return nativeGetActorCreationTaskFunctionDescriptor(nativeActorHandle).get(0); + } + + @Override + public String getClassName() { + Preconditions.checkState(getLanguage() == Language.PYTHON); + return nativeGetActorCreationTaskFunctionDescriptor(nativeActorHandle).get(1); + } + + public NativeRayActor fork() { + return new NativeRayActor(nativeFork(nativeActorHandle)); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(nativeSerialize(nativeActorHandle)); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + nativeActorHandle = nativeDeserialize((byte[]) in.readObject()); + } + + @Override + protected void finalize() { + nativeFree(nativeActorHandle); + } + + private static native long nativeFork(long nativeActorHandle); + + private static native byte[] nativeGetActorId(long nativeActorHandle); + + private static native byte[] nativeGetActorHandleId(long nativeActorHandle); + + private static native int nativeGetLanguage(long nativeActorHandle); + + private static native List nativeGetActorCreationTaskFunctionDescriptor( + long nativeActorHandle); + + private static native byte[] nativeSerialize(long nativeActorHandle); + + private static native long nativeDeserialize(byte[] data); + + private static native void nativeFree(long nativeActorHandle); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/util/RayActorSerializer.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActorSerializer.java similarity index 65% rename from java/runtime/src/main/java/org/ray/runtime/util/RayActorSerializer.java rename to java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActorSerializer.java index 24c9a3284..11102cdd8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/RayActorSerializer.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActorSerializer.java @@ -1,4 +1,4 @@ -package org.ray.runtime.util; +package org.ray.runtime.actor; import java.io.IOException; import org.nustaq.serialization.FSTBasicObjectSerializer; @@ -6,20 +6,22 @@ import org.nustaq.serialization.FSTClazzInfo; import org.nustaq.serialization.FSTClazzInfo.FSTFieldInfo; import org.nustaq.serialization.FSTObjectInput; import org.nustaq.serialization.FSTObjectOutput; -import org.ray.runtime.RayActorImpl; -public class RayActorSerializer extends FSTBasicObjectSerializer { +/** + * To deal with serialization about {@link NativeRayActor}. + */ +public class NativeRayActorSerializer extends FSTBasicObjectSerializer { @Override public void writeObject(FSTObjectOutput out, Object toWrite, FSTClazzInfo clzInfo, FSTClazzInfo.FSTFieldInfo referencedBy, int streamPosition) throws IOException { - ((RayActorImpl) toWrite).fork().writeExternal(out); + ((NativeRayActor) toWrite).fork().writeExternal(out); } @Override public void readObject(FSTObjectInput in, Object toRead, FSTClazzInfo clzInfo, - FSTFieldInfo referencedBy) throws Exception { + FSTFieldInfo referencedBy) throws Exception { super.readObject(in, toRead, clzInfo, referencedBy); - ((RayActorImpl) toRead).readExternal(in); + ((NativeRayActor) toRead).readExternal(in); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java new file mode 100644 index 000000000..1f05c3d59 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java @@ -0,0 +1,70 @@ +package org.ray.runtime.context; + +import com.google.common.base.Preconditions; +import org.ray.api.id.ActorId; +import org.ray.api.id.JobId; +import org.ray.api.id.TaskId; +import org.ray.api.id.UniqueId; +import org.ray.runtime.generated.Common.TaskSpec; +import org.ray.runtime.generated.Common.TaskType; +import org.ray.runtime.task.LocalModeTaskSubmitter; +import sun.reflect.generics.reflectiveObjects.NotImplementedException; + +/** + * Worker context for local mode. + */ +public class LocalModeWorkerContext implements WorkerContext { + + private final JobId jobId; + private ThreadLocal currentTask = new ThreadLocal<>(); + + public LocalModeWorkerContext(JobId jobId) { + this.jobId = jobId; + } + + @Override + public UniqueId getCurrentWorkerId() { + throw new NotImplementedException(); + } + + @Override + public JobId getCurrentJobId() { + return jobId; + } + + @Override + public ActorId getCurrentActorId() { + TaskSpec taskSpec = currentTask.get(); + if (taskSpec == null) { + return ActorId.NIL; + } + return LocalModeTaskSubmitter.getActorId(taskSpec); + } + + @Override + public ClassLoader getCurrentClassLoader() { + return null; + } + + @Override + public void setCurrentClassLoader(ClassLoader currentClassLoader) { + } + + @Override + public TaskType getCurrentTaskType() { + TaskSpec taskSpec = currentTask.get(); + Preconditions.checkNotNull(taskSpec, "Current task is not set."); + return taskSpec.getType(); + } + + @Override + public TaskId getCurrentTaskId() { + TaskSpec taskSpec = currentTask.get(); + Preconditions.checkState(taskSpec != null); + return TaskId.fromBytes(taskSpec.getTaskId().toByteArray()); + } + + public void setCurrentTask(TaskSpec taskSpec) { + currentTask.set(taskSpec); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java new file mode 100644 index 000000000..b42a7b234 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java @@ -0,0 +1,72 @@ +package org.ray.runtime.context; + +import java.nio.ByteBuffer; +import org.ray.api.id.ActorId; +import org.ray.api.id.JobId; +import org.ray.api.id.TaskId; +import org.ray.api.id.UniqueId; +import org.ray.runtime.generated.Common.TaskType; + +/** + * Worker context for cluster mode. This is a wrapper class for worker context of core worker. + */ +public class NativeWorkerContext implements WorkerContext { + + /** + * The native pointer of core worker. + */ + private final long nativeCoreWorkerPointer; + + private ClassLoader currentClassLoader; + + public NativeWorkerContext(long nativeCoreWorkerPointer) { + this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + public UniqueId getCurrentWorkerId() { + return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer)); + } + + @Override + public JobId getCurrentJobId() { + return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer)); + } + + @Override + public ActorId getCurrentActorId() { + return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer)); + } + + @Override + public ClassLoader getCurrentClassLoader() { + return currentClassLoader; + } + + @Override + public void setCurrentClassLoader(ClassLoader currentClassLoader) { + if (this.currentClassLoader != currentClassLoader) { + this.currentClassLoader = currentClassLoader; + } + } + + @Override + public TaskType getCurrentTaskType() { + return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer)); + } + + @Override + public TaskId getCurrentTaskId() { + return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer)); + } + + private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer); + + private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer); + + private static native ByteBuffer nativeGetCurrentJobId(long nativeCoreWorkerPointer); + + private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer); + + private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java b/java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java similarity index 76% rename from java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java rename to java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java index 73d361393..0680128fe 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java @@ -1,14 +1,14 @@ -package org.ray.runtime; +package org.ray.runtime.context; import com.google.common.base.Preconditions; import java.util.List; - import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.api.runtimecontext.RuntimeContext; +import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.task.TaskSpec; +import org.ray.runtime.generated.Common.TaskType; public class RuntimeContextImpl implements RuntimeContext { @@ -25,16 +25,16 @@ public class RuntimeContextImpl implements RuntimeContext { @Override public ActorId getCurrentActorId() { - Worker worker = runtime.getWorker(); - Preconditions.checkState(worker != null && !worker.getCurrentActorId().isNil(), + ActorId actorId = runtime.getWorkerContext().getCurrentActorId(); + Preconditions.checkState(actorId != null && !actorId.isNil(), "This method should only be called from an actor."); - return worker.getCurrentActorId(); + return actorId; } @Override public boolean wasCurrentActorReconstructed() { - TaskSpec currentTask = runtime.getWorkerContext().getCurrentTask(); - Preconditions.checkState(currentTask != null && currentTask.isActorCreationTask(), + TaskType currentTaskType = runtime.getWorkerContext().getCurrentTaskType(); + Preconditions.checkState(currentTaskType == TaskType.ACTOR_CREATION_TASK, "This method can only be called from an actor creation task."); if (isSingleProcess()) { return false; diff --git a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java new file mode 100644 index 000000000..4a526c85e --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java @@ -0,0 +1,49 @@ +package org.ray.runtime.context; + +import org.ray.api.id.ActorId; +import org.ray.api.id.JobId; +import org.ray.api.id.TaskId; +import org.ray.api.id.UniqueId; +import org.ray.runtime.generated.Common.TaskType; + +/** + * The context of worker. + */ +public interface WorkerContext { + + /** + * ID of the current worker. + */ + UniqueId getCurrentWorkerId(); + + /** + * ID of the current job. + */ + JobId getCurrentJobId(); + + /** + * ID of the current actor. + */ + ActorId getCurrentActorId(); + + /** + * The class loader that is associated with the current job. It's used for locating classes when + * dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}. + */ + ClassLoader getCurrentClassLoader(); + + /** + * Set the current class loader. + */ + void setCurrentClassLoader(ClassLoader currentClassLoader); + + /** + * Type of the current task. + */ + TaskType getCurrentTaskType(); + + /** + * ID of the current task. + */ + TaskId getCurrentTaskId(); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java index 3d0b36b35..ec923e6f9 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java @@ -1,5 +1,8 @@ package org.ray.runtime.functionmanager; +import java.util.List; +import org.ray.runtime.generated.Common.Language; + /** * Base interface of a Ray task's function descriptor. * @@ -8,4 +11,13 @@ package org.ray.runtime.functionmanager; */ public interface FunctionDescriptor { + /** + * @return A list of strings represents the functions. + */ + List toList(); + + /** + * @return The language of the function. + */ + Language getLanguage(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java index aac416fa5..25b34539c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java @@ -1,6 +1,9 @@ package org.ray.runtime.functionmanager; import com.google.common.base.Objects; +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.ray.runtime.generated.Common.Language; /** * Represents metadata of Java function. @@ -49,4 +52,14 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor { public int hashCode() { return Objects.hashCode(className, name, typeDescriptor); } + + @Override + public List toList() { + return ImmutableList.of(className, name, typeDescriptor); + } + + @Override + public Language getLanguage() { + return Language.JAVA; + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java index 1fe13f0fb..6845e79c7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java @@ -1,5 +1,9 @@ package org.ray.runtime.functionmanager; +import java.util.Arrays; +import java.util.List; +import org.ray.runtime.generated.Common.Language; + /** * Represents metadata of a Python function. */ @@ -21,5 +25,15 @@ public class PyFunctionDescriptor implements FunctionDescriptor { public String toString() { return moduleName + "." + className + "." + functionName; } + + @Override + public List toList() { + return Arrays.asList(moduleName, className, functionName); + } + + @Override + public Language getLanguage() { + return Language.PYTHON; + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClientOptions.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClientOptions.java new file mode 100644 index 000000000..4f0b6b901 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClientOptions.java @@ -0,0 +1,18 @@ +package org.ray.runtime.gcs; + +import org.ray.runtime.config.RayConfig; + +/** + * Options to create GCS Client. + */ +public class GcsClientOptions { + public String ip; + public int port; + public String password; + + public GcsClientOptions(RayConfig rayConfig) { + ip = rayConfig.getRedisIp(); + port = rayConfig.getRedisPort(); + password = rayConfig.redisPassword; + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/object/LocalModeObjectStore.java similarity index 77% rename from java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java rename to java/runtime/src/main/java/org/ray/runtime/object/LocalModeObjectStore.java index 1f53f19d0..9034bec70 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/LocalModeObjectStore.java @@ -1,4 +1,4 @@ -package org.ray.runtime.objectstore; +package org.ray.runtime.object; import com.google.common.base.Preconditions; import java.util.ArrayList; @@ -8,22 +8,24 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; import org.ray.api.id.ObjectId; -import org.ray.runtime.WorkerContext; +import org.ray.runtime.context.WorkerContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class MockObjectInterface implements ObjectInterface { +/** + * Object store methods for local mode. + */ +public class LocalModeObjectStore extends ObjectStore { - private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeObjectStore.class); private static final int GET_CHECK_INTERVAL_MS = 100; private final Map pool = new ConcurrentHashMap<>(); private final List> objectPutCallbacks = new ArrayList<>(); - private final WorkerContext workerContext; - public MockObjectInterface(WorkerContext workerContext) { - this.workerContext = workerContext; + public LocalModeObjectStore(WorkerContext workerContext) { + super(workerContext); } public void addObjectPutCallback(Consumer callback) { @@ -35,15 +37,14 @@ public class MockObjectInterface implements ObjectInterface { } @Override - public ObjectId put(NativeRayObject obj) { - ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(), - workerContext.nextPutIndex()); - put(obj, objectId); + public ObjectId putRaw(NativeRayObject obj) { + ObjectId objectId = ObjectId.fromRandom(); + putRaw(obj, objectId); return objectId; } @Override - public void put(NativeRayObject obj, ObjectId objectId) { + public void putRaw(NativeRayObject obj, ObjectId objectId) { Preconditions.checkNotNull(obj); Preconditions.checkNotNull(objectId); pool.putIfAbsent(objectId, obj); @@ -53,7 +54,7 @@ public class MockObjectInterface implements ObjectInterface { } @Override - public List get(List objectIds, long timeoutMs) { + public List getRaw(List objectIds, long timeoutMs) { waitInternal(objectIds, objectIds.size(), timeoutMs); return objectIds.stream().map(pool::get).collect(Collectors.toList()); } diff --git a/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java new file mode 100644 index 000000000..9c8798f0f --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java @@ -0,0 +1,70 @@ +package org.ray.runtime.object; + +import java.util.List; +import java.util.stream.Collectors; +import org.ray.api.id.BaseId; +import org.ray.api.id.ObjectId; +import org.ray.runtime.context.WorkerContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Object store methods for cluster mode. This is a wrapper class for core worker object interface. + */ +public class NativeObjectStore extends ObjectStore { + + private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class); + + /** + * The native pointer of core worker. + */ + private final long nativeCoreWorkerPointer; + + public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) { + super(workerContext); + this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + public ObjectId putRaw(NativeRayObject obj) { + return new ObjectId(nativePut(nativeCoreWorkerPointer, obj)); + } + + @Override + public void putRaw(NativeRayObject obj, ObjectId objectId) { + nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj); + } + + @Override + public List getRaw(List objectIds, long timeoutMs) { + return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs); + } + + @Override + public List wait(List objectIds, int numObjects, long timeoutMs) { + return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs); + } + + @Override + public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks); + } + + private static List toBinaryList(List ids) { + return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); + } + + private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj); + + private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId, + NativeRayObject obj); + + private static native List nativeGet(long nativeCoreWorkerPointer, + List ids, long timeoutMs); + + private static native List nativeWait(long nativeCoreWorkerPointer, + List objectIds, int numObjects, long timeoutMs); + + private static native void nativeDelete(long nativeCoreWorkerPointer, List objectIds, + boolean localOnly, boolean deleteCreatingTasks); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java b/java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java similarity index 71% rename from java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java rename to java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java index 7146765c2..20111b7a6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java @@ -1,5 +1,8 @@ -package org.ray.runtime.objectstore; +package org.ray.runtime.object; +/** + * Binary representation of ray object. + */ public class NativeRayObject { public byte[] data; diff --git a/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java new file mode 100644 index 000000000..b25442605 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java @@ -0,0 +1,217 @@ +package org.ray.runtime.object; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import org.ray.api.RayObject; +import org.ray.api.WaitResult; +import org.ray.api.exception.RayActorException; +import org.ray.api.exception.RayException; +import org.ray.api.exception.RayTaskException; +import org.ray.api.exception.RayWorkerException; +import org.ray.api.exception.UnreconstructableException; +import org.ray.api.id.ObjectId; +import org.ray.runtime.context.WorkerContext; +import org.ray.runtime.generated.Gcs.ErrorType; +import org.ray.runtime.util.Serializer; + +/** + * A class that is used to put/get objects to/from the object store. + */ +public abstract class ObjectStore { + + private static final byte[] WORKER_EXCEPTION_META = String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); + private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); + + private static final byte[] TASK_EXECUTION_EXCEPTION_META = String + .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes(); + + private static final byte[] RAW_TYPE_META = "RAW".getBytes(); + + private final WorkerContext workerContext; + + public ObjectStore(WorkerContext workerContext) { + this.workerContext = workerContext; + } + + /** + * Put a raw object into object store. + * + * @param obj The ray object. + * @return Generated ID of the object. + */ + public abstract ObjectId putRaw(NativeRayObject obj); + + /** + * Put a raw object with specified ID into object store. + * + * @param obj The ray object. + * @param objectId Object ID specified by user. + */ + public abstract void putRaw(NativeRayObject obj, ObjectId objectId); + + /** + * Serialize and put an object to the object store. + * + * @param object The object to put. + * @return Id of the object. + */ + public ObjectId put(Object object) { + return putRaw(serialize(object)); + } + + /** + * Get a list of raw objects from the object store. + * + * @param objectIds IDs of the objects to get. + * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. + * @return Result list of objects data. + */ + public abstract List getRaw(List objectIds, long timeoutMs); + + /** + * Get a list of objects from the object store. + * + * @param ids List of the object ids. + * @param Type of these objects. + * @return A list of GetResult objects. + */ + @SuppressWarnings("unchecked") + public List get(List ids) { + // Pass -1 as timeout to wait until all objects are available in object store. + List dataAndMetaList = getRaw(ids, -1); + + List results = new ArrayList<>(); + for (int i = 0; i < dataAndMetaList.size(); i++) { + NativeRayObject dataAndMeta = dataAndMetaList.get(i); + Object object = null; + if (dataAndMeta != null) { + object = deserialize(dataAndMeta, ids.get(i)); + } + if (object instanceof RayException) { + // If the object is a `RayException`, it means that an error occurred during task + // execution. + throw (RayException) object; + } + results.add((T) object); + } + // This check must be placed after the throw exception statement. + // Because if there was any exception, The get operation would return early + // and wouldn't wait until all objects exist. + Preconditions.checkState(dataAndMetaList.stream().allMatch(Objects::nonNull)); + return results; + } + + /** + * Wait for a list of objects to appear in the object store. + * + * @param objectIds IDs of the objects to wait for. + * @param numObjects Number of objects that should appear. + * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. + * @return A bitset that indicates each object has appeared or not. + */ + public abstract List wait(List objectIds, int numObjects, long timeoutMs); + + /** + * Wait for a list of RayObjects to be locally available, until specified number of objects are + * ready, or specified timeout has passed. + * + * @param waitList A list of RayObject to wait for. + * @param numReturns The number of objects that should be returned. + * @param timeoutMs The maximum time in milliseconds to wait before returning. + * @return Two lists, one containing locally available objects, one containing the rest. + */ + public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { + Preconditions.checkNotNull(waitList); + if (waitList.isEmpty()) { + return new WaitResult<>(Collections.emptyList(), Collections.emptyList()); + } + + List ids = waitList.stream().map(RayObject::getId).collect(Collectors.toList()); + + List ready = wait(ids, numReturns, timeoutMs); + List> readyList = new ArrayList<>(); + List> unreadyList = new ArrayList<>(); + + for (int i = 0; i < ready.size(); i++) { + if (ready.get(i)) { + readyList.add(waitList.get(i)); + } else { + unreadyList.add(waitList.get(i)); + } + } + + return new WaitResult<>(readyList, unreadyList); + } + + /** + * Delete a list of objects from the object store. + * + * @param objectIds IDs of the objects to delete. + * @param localOnly Whether only delete the objects in local node, or all nodes in the cluster. + * @param deleteCreatingTasks Whether also delete the tasks that created these objects. + */ + public abstract void delete(List objectIds, boolean localOnly, + boolean deleteCreatingTasks); + + /** + * Deserialize an object. + * + * @param nativeRayObject The object to deserialize. + * @param objectId The associated object ID of the object. + * @return The deserialized object. + */ + public Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId) { + byte[] meta = nativeRayObject.metadata; + byte[] data = nativeRayObject.data; + + // If meta is not null, deserialize the object from meta. + if (meta != null && meta.length > 0) { + // If meta is not null, deserialize the object from meta. + if (Arrays.equals(meta, RAW_TYPE_META)) { + return data; + } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { + return RayWorkerException.INSTANCE; + } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { + return RayActorException.INSTANCE; + } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { + return new UnreconstructableException(objectId); + } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { + return Serializer.decode(data, workerContext.getCurrentClassLoader()); + } + throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); + } else { + // If data is not null, deserialize the Java object. + return Serializer.decode(data, workerContext.getCurrentClassLoader()); + } + } + + /** + * Serialize an object. + * + * @param object The object to serialize. + * @return The serialized object. + */ + public NativeRayObject serialize(Object object) { + if (object instanceof NativeRayObject) { + return (NativeRayObject) object; + } else if (object instanceof byte[]) { + // If the object is a byte array, skip serializing it and use a special metadata to + // indicate it's raw binary. So that this object can also be read by Python. + return new NativeRayObject((byte[]) object, RAW_TYPE_META); + } else if (object instanceof RayTaskException) { + return new NativeRayObject(Serializer.encode(object), + TASK_EXECUTION_EXCEPTION_META); + } else { + return new NativeRayObject(Serializer.encode(object), null); + } + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java b/java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java similarity index 83% rename from java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java rename to java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java index 9f8e567f8..197445831 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java @@ -1,10 +1,13 @@ -package org.ray.runtime; +package org.ray.runtime.object; import java.io.Serializable; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.id.ObjectId; +/** + * Implementation of {@link RayObject}. + */ public final class RayObjectImpl implements RayObject, Serializable { private final ObjectId id; diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java deleted file mode 100644 index 5780dbd6c..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java +++ /dev/null @@ -1,54 +0,0 @@ -package org.ray.runtime.objectstore; - -import java.util.List; -import org.ray.api.id.ObjectId; - -/** - * The interface that contains all worker methods that are related to object store. - */ -public interface ObjectInterface { - - /** - * Put an object into object store. - * - * @param obj The ray object. - * @return Generated ID of the object. - */ - ObjectId put(NativeRayObject obj); - - /** - * Put an object with specified ID into object store. - * - * @param obj The ray object. - * @param objectId Object ID specified by user. - */ - void put(NativeRayObject obj, ObjectId objectId); - - /** - * Get a list of objects from the object store. - * - * @param objectIds IDs of the objects to get. - * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. - * @return Result list of objects data. - */ - List get(List objectIds, long timeoutMs); - - /** - * Wait for a list of objects to appear in the object store. - * - * @param objectIds IDs of the objects to wait for. - * @param numObjects Number of objects that should appear. - * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. - * @return A bitset that indicates each object has appeared or not. - */ - List wait(List objectIds, int numObjects, long timeoutMs); - - /** - * Delete a list of objects from the object store. - * - * @param objectIds IDs of the objects to delete. - * @param localOnly Whether only delete the objects in local node, or all nodes in the cluster. - * @param deleteCreatingTasks Whether also delete the tasks that created these objects. - */ - void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks); -} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java deleted file mode 100644 index 736df1192..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java +++ /dev/null @@ -1,92 +0,0 @@ -package org.ray.runtime.objectstore; - -import java.util.List; -import java.util.stream.Collectors; -import org.ray.api.exception.RayException; -import org.ray.api.id.BaseId; -import org.ray.api.id.ObjectId; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.WorkerContext; -import org.ray.runtime.raylet.RayletClient; -import org.ray.runtime.raylet.RayletClientImpl; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This is a wrapper class for core worker object interface. - */ -public class ObjectInterfaceImpl implements ObjectInterface { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); - - /** - * The native pointer of core worker object interface. - */ - private final long nativeObjectInterfacePointer; - - public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient, - String storeSocketName) { - this.nativeObjectInterfacePointer = - nativeCreateObjectInterface(workerContext.getNativeWorkerContext(), - ((RayletClientImpl) rayletClient).getClient(), storeSocketName); - } - - @Override - public ObjectId put(NativeRayObject obj) { - return new ObjectId(nativePut(nativeObjectInterfacePointer, obj)); - } - - @Override - public void put(NativeRayObject obj, ObjectId objectId) { - try { - nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj); - } catch (RayException e) { - LOGGER.warn(e.getMessage()); - } - } - - @Override - public List get(List objectIds, long timeoutMs) { - return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs); - } - - @Override - public List wait(List objectIds, int numObjects, long timeoutMs) { - return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs); - } - - @Override - public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - nativeDelete(nativeObjectInterfacePointer, - toBinaryList(objectIds), localOnly, deleteCreatingTasks); - } - - public void destroy() { - nativeDestroy(nativeObjectInterfacePointer); - } - - private static List toBinaryList(List ids) { - return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); - } - - private static native long nativeCreateObjectInterface(long nativeObjectInterface, - long nativeRayletClient, - String storeSocketName); - - private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj); - - private static native void nativePut(long nativeObjectInterface, byte[] objectId, - NativeRayObject obj); - - private static native List nativeGet(long nativeObjectInterface, - List ids, - long timeoutMs); - - private static native List nativeWait(long nativeObjectInterface, List objectIds, - int numObjects, long timeoutMs); - - private static native void nativeDelete(long nativeObjectInterface, List objectIds, - boolean localOnly, boolean deleteCreatingTasks); - - private static native void nativeDestroy(long nativeObjectInterface); -} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java deleted file mode 100644 index 3c3696a9c..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ /dev/null @@ -1,152 +0,0 @@ -package org.ray.runtime.objectstore; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; -import org.ray.api.exception.RayActorException; -import org.ray.api.exception.RayException; -import org.ray.api.exception.RayTaskException; -import org.ray.api.exception.RayWorkerException; -import org.ray.api.exception.UnreconstructableException; -import org.ray.api.id.ObjectId; -import org.ray.runtime.WorkerContext; -import org.ray.runtime.generated.Gcs.ErrorType; -import org.ray.runtime.util.Serializer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A class that is used to put/get objects to/from the object store. - */ -public class ObjectStoreProxy { - - private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - - private static final byte[] WORKER_EXCEPTION_META = String - .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String - .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); - private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); - - private static final byte[] TASK_EXECUTION_EXCEPTION_META = String - .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes(); - - private static final byte[] RAW_TYPE_META = "RAW".getBytes(); - - private final WorkerContext workerContext; - - private final ObjectInterface objectInterface; - - public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) { - this.workerContext = workerContext; - this.objectInterface = objectInterface; - } - - /** - * Get an object from the object store. - * - * @param id Id of the object. - * @param Type of the object. - * @return The GetResult object. - */ - public T get(ObjectId id) { - List list = get(ImmutableList.of(id)); - return list.get(0); - } - - /** - * Get a list of objects from the object store. - * - * @param ids List of the object ids. - * @param Type of these objects. - * @return A list of GetResult objects. - */ - @SuppressWarnings("unchecked") - public List get(List ids) { - // Pass -1 as timeout to wait until all objects are available in object store. - List dataAndMetaList = objectInterface.get(ids, -1); - - List results = new ArrayList<>(); - for (int i = 0; i < dataAndMetaList.size(); i++) { - NativeRayObject dataAndMeta = dataAndMetaList.get(i); - Object object = null; - if (dataAndMeta != null) { - byte[] meta = dataAndMeta.metadata; - byte[] data = dataAndMeta.data; - if (meta != null && meta.length > 0) { - // If meta is not null, deserialize the object from meta. - object = deserializeFromMeta(meta, data, - workerContext.getCurrentClassLoader(), ids.get(i)); - } else { - // If data is not null, deserialize the Java object. - object = Serializer.decode(data, workerContext.getCurrentClassLoader()); - } - if (object instanceof RayException) { - // If the object is a `RayException`, it means that an error occurred during task - // execution. - throw (RayException) object; - } - } - - results.add((T) object); - } - // This check must be placed after the throw exception statement. - // Because if there was any exception, The get operation would return early - // and wouldn't wait until all objects exist. - Preconditions.checkState(dataAndMetaList.stream().allMatch(Objects::nonNull)); - return results; - } - - private Object deserializeFromMeta(byte[] meta, byte[] data, - ClassLoader classLoader, ObjectId objectId) { - if (Arrays.equals(meta, RAW_TYPE_META)) { - return data; - } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { - return RayWorkerException.INSTANCE; - } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { - return RayActorException.INSTANCE; - } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { - return new UnreconstructableException(objectId); - } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { - return Serializer.decode(data, classLoader); - } - throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); - } - - /** - * Serialize and put an object to the object store. - * - * @param id Id of the object. - * @param object The object to put. - */ - public void put(ObjectId id, Object object) { - if (object instanceof byte[]) { - // If the object is a byte array, skip serializing it and use a special metadata to - // indicate it's raw binary. So that this object can also be read by Python. - objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id); - } else if (object instanceof RayTaskException) { - objectInterface - .put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id); - } else { - objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id); - } - } - - /** - * Put an already serialized object to the object store. - * - * @param id Id of the object. - * @param serializedObject The serialized object to put. - */ - public void putSerialized(ObjectId id, byte[] serializedObject) { - objectInterface.put(new NativeRayObject(serializedObject, null), id); - } - - public ObjectInterface getObjectInterface() { - return objectInterface; - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java new file mode 100644 index 000000000..9d43244c3 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java @@ -0,0 +1,29 @@ +package org.ray.runtime.raylet; + +import org.apache.commons.lang3.NotImplementedException; +import org.ray.api.id.ActorId; +import org.ray.api.id.UniqueId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Raylet client for local mode. + */ +public class LocalModeRayletClient implements RayletClient { + private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class); + + @Override + public UniqueId prepareCheckpoint(ActorId actorId) { + throw new NotImplementedException("Not implemented."); + } + + @Override + public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { + throw new NotImplementedException("Not implemented."); + } + + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + LOGGER.error("Not implemented under SINGLE_PROCESS mode."); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java deleted file mode 100644 index 913ab57d0..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ /dev/null @@ -1,204 +0,0 @@ -package org.ray.runtime.raylet; - -import com.google.common.collect.ImmutableList; - -import java.util.ArrayList; -import java.util.Deque; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; - -import java.util.stream.Collectors; -import org.apache.commons.lang3.NotImplementedException; -import org.ray.api.RayObject; -import org.ray.api.WaitResult; -import org.ray.api.id.ActorId; -import org.ray.api.id.ObjectId; -import org.ray.api.id.TaskId; -import org.ray.api.id.UniqueId; -import org.ray.runtime.RayDevRuntime; -import org.ray.runtime.Worker; -import org.ray.runtime.objectstore.MockObjectInterface; -import org.ray.runtime.objectstore.NativeRayObject; -import org.ray.runtime.task.FunctionArg; -import org.ray.runtime.task.TaskSpec; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A mock implementation of RayletClient, used in single process mode. - */ -public class MockRayletClient implements RayletClient { - - private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class); - - private final Map> waitingTasks = new ConcurrentHashMap<>(); - private final MockObjectInterface objectInterface; - private final RayDevRuntime runtime; - private final ExecutorService exec; - private final Deque idleWorkers; - private final Map actorWorkers; - private final ThreadLocal currentWorker; - - public MockRayletClient(RayDevRuntime runtime, int numberThreads) { - this.runtime = runtime; - this.objectInterface = runtime.getObjectInterface(); - objectInterface.addObjectPutCallback(this::onObjectPut); - // The thread pool that executes tasks in parallel. - exec = Executors.newFixedThreadPool(numberThreads); - idleWorkers = new ConcurrentLinkedDeque<>(); - actorWorkers = new HashMap<>(); - currentWorker = new ThreadLocal<>(); - } - - public synchronized void onObjectPut(ObjectId id) { - Set tasks = waitingTasks.get(id); - if (tasks != null) { - waitingTasks.remove(id); - for (TaskSpec taskSpec : tasks) { - submitTask(taskSpec); - } - } - } - - public Worker getCurrentWorker() { - return currentWorker.get(); - } - - /** - * Get a worker from the worker pool to run the given task. - */ - private synchronized Worker getWorker(TaskSpec task) { - Worker worker; - if (task.isActorTask()) { - worker = actorWorkers.get(task.actorId); - } else { - if (task.isActorCreationTask()) { - worker = new Worker(runtime); - actorWorkers.put(task.actorCreationId, worker); - } else if (idleWorkers.size() > 0) { - worker = idleWorkers.pop(); - } else { - worker = new Worker(runtime); - } - } - currentWorker.set(worker); - return worker; - } - - /** - * Return the worker to the worker pool. - */ - private void returnWorker(Worker worker) { - currentWorker.remove(); - idleWorkers.push(worker); - } - - @Override - public synchronized void submitTask(TaskSpec task) { - LOGGER.debug("Submitting task: {}.", task); - Set unreadyObjects = getUnreadyObjects(task); - if (unreadyObjects.isEmpty()) { - // If all dependencies are ready, execute this task. - exec.submit(() -> { - Worker worker = getWorker(task); - try { - worker.execute(task); - // If the task is an actor task or an actor creation task, - // put the dummy object in object store, so those tasks which depends on it - // can be executed. - if (task.isActorCreationTask() || task.isActorTask()) { - ObjectId[] returnIds = task.returnIds; - objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}), - returnIds[returnIds.length - 1]); - } - } finally { - returnWorker(worker); - } - }); - } else { - // If some dependencies aren't ready yet, put this task in waiting list. - for (ObjectId id : unreadyObjects) { - waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task); - } - } - } - - private Set getUnreadyObjects(TaskSpec spec) { - Set unreadyObjects = new HashSet<>(); - // Check whether task arguments are ready. - for (FunctionArg arg : spec.args) { - if (arg.id != null) { - if (!objectInterface.isObjectReady(arg.id)) { - unreadyObjects.add(arg.id); - } - } - } - if (spec.isActorTask()) { - if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) { - unreadyObjects.add(spec.previousActorTaskDummyObjectId); - } - } - return unreadyObjects; - } - - - @Override - public TaskSpec getTask() { - throw new RuntimeException("invalid execution flow here"); - } - - @Override - public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, TaskId currentTaskId) { - if (waitFor == null || waitFor.isEmpty()) { - return new WaitResult<>(ImmutableList.of(), ImmutableList.of()); - } - - List ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList()); - List> readyList = new ArrayList<>(); - List> unreadyList = new ArrayList<>(); - List result = objectInterface.wait(ids, ids.size(), timeoutMs); - for (int i = 0; i < waitFor.size(); i++) { - if (result.get(i)) { - readyList.add(waitFor.get(i)); - } else { - unreadyList.add(waitFor.get(i)); - } - } - return new WaitResult<>(readyList, unreadyList); - } - - @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, - boolean deleteCreatingTasks) { - objectInterface.delete(objectIds, localOnly, deleteCreatingTasks); - } - - - @Override - public UniqueId prepareCheckpoint(ActorId actorId) { - throw new NotImplementedException("Not implemented."); - } - - @Override - public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { - throw new NotImplementedException("Not implemented."); - } - - @Override - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - LOGGER.error("Not implemented under SINGLE_PROCESS mode."); - } - - @Override - public void destroy() { - exec.shutdown(); - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java new file mode 100644 index 000000000..ed5f10f12 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java @@ -0,0 +1,57 @@ +package org.ray.runtime.raylet; + +import org.ray.api.exception.RayException; +import org.ray.api.id.ActorId; +import org.ray.api.id.UniqueId; + +/** + * Raylet client for cluster mode. This is a wrapper class for C++ RayletClient. + */ +public class NativeRayletClient implements RayletClient { + + /** + * The native pointer of core worker. + */ + private long nativeCoreWorkerPointer = 0; + + public NativeRayletClient(long nativeCoreWorkerPointer) { + this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + public UniqueId prepareCheckpoint(ActorId actorId) { + return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes())); + } + + @Override + public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { + nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(), + checkpointId.getBytes()); + } + + + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); + } + + /// Native method declarations. + /// + /// If you change the signature of any native methods, please re-generate + /// the C++ header file and update the C++ implementation accordingly: + /// + /// Suppose that $Dir is your ray root directory. + /// 1) pushd $Dir/java/runtime/target/classes + /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient + /// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h + /// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/ + /// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc + /// 6) popd + + private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId); + + private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, + byte[] checkpointId); + + private static native void nativeSetResource(long conn, String resourceName, double capacity, + byte[] nodeId) throws RayException; +} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index ea398004a..144187b6b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -1,33 +1,16 @@ package org.ray.runtime.raylet; -import java.util.List; -import org.ray.api.RayObject; -import org.ray.api.WaitResult; import org.ray.api.id.ActorId; -import org.ray.api.id.ObjectId; -import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; -import org.ray.runtime.task.TaskSpec; /** * Client to the Raylet backend. */ public interface RayletClient { - void submitTask(TaskSpec task); - - TaskSpec getTask(); - - WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, TaskId currentTaskId); - - void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); - UniqueId prepareCheckpoint(ActorId actorId); void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId); void setResource(String resourceName, double capacity, UniqueId nodeId); - - void destroy(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java deleted file mode 100644 index 1577270b1..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ /dev/null @@ -1,343 +0,0 @@ -package org.ray.runtime.raylet; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import org.ray.api.RayObject; -import org.ray.api.WaitResult; -import org.ray.api.exception.RayException; -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; -import org.ray.api.id.JobId; -import org.ray.api.id.TaskId; -import org.ray.api.id.ObjectId; -import org.ray.runtime.functionmanager.JavaFunctionDescriptor; -import org.ray.runtime.generated.Common; -import org.ray.runtime.generated.Common.TaskType; -import org.ray.runtime.task.FunctionArg; -import org.ray.runtime.task.TaskLanguage; -import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.IdUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class RayletClientImpl implements RayletClient { - - private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class); - - /** - * The pointer to c++'s raylet client. - */ - private long client = 0; - - // TODO(qwang): JobId parameter can be removed once we embed jobId in driverId. - public RayletClientImpl(String schedulerSockName, UniqueId workerId, - boolean isWorker, JobId jobId) { - client = nativeInit(schedulerSockName, workerId.getBytes(), - isWorker, jobId.getBytes()); - } - - public long getClient() { - return client; - } - - @Override - public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, TaskId currentTaskId) { - Preconditions.checkNotNull(waitFor); - if (waitFor.isEmpty()) { - return new WaitResult<>(new ArrayList<>(), new ArrayList<>()); - } - - List ids = new ArrayList<>(); - for (RayObject element : waitFor) { - ids.add(element.getId()); - } - - boolean[] ready = nativeWaitObject(client, IdUtil.getIdBytes(ids), - numReturns, timeoutMs, false, currentTaskId.getBytes()); - List> readyList = new ArrayList<>(); - List> unreadyList = new ArrayList<>(); - - for (int i = 0; i < ready.length; i++) { - if (ready[i]) { - readyList.add(waitFor.get(i)); - } else { - unreadyList.add(waitFor.get(i)); - } - } - - return new WaitResult<>(readyList, unreadyList); - } - - @Override - public void submitTask(TaskSpec spec) { - LOGGER.debug("Submitting task: {}", spec); - Preconditions.checkState(!spec.parentTaskId.isNil()); - Preconditions.checkState(!spec.jobId.isNil()); - - byte[] taskSpec = convertTaskSpecToProtobuf(spec); - nativeSubmitTask(client, taskSpec); - } - - @Override - public TaskSpec getTask() { - byte[] bytes = nativeGetTask(client); - assert (null != bytes); - return parseTaskSpecFromProtobuf(bytes); - } - - @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, - boolean deleteCreatingTasks) { - byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds); - nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks); - } - - @Override - public UniqueId prepareCheckpoint(ActorId actorId) { - return new UniqueId(nativePrepareCheckpoint(client, actorId.getBytes())); - } - - @Override - public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { - nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes()); - } - - public static TaskId generateActorCreationTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { - byte[] bytes = nativeGenerateActorCreationTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); - return TaskId.fromBytes(bytes); - } - - public static TaskId generateActorTaskId(JobId jobId, TaskId parentTaskId, int taskIndex, ActorId actorId) { - byte[] bytes = nativeGenerateActorTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex, actorId.getBytes()); - return TaskId.fromBytes(bytes); - } - - public static TaskId generateNormalTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { - byte[] bytes = nativeGenerateNormalTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); - return TaskId.fromBytes(bytes); - } - - /** - * Parse `TaskSpec` protobuf bytes. - */ - public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) { - Common.TaskSpec taskSpec; - try { - taskSpec = Common.TaskSpec.parseFrom(bytes); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Invalid protobuf data."); - } - - // Parse common fields. - JobId jobId = JobId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer()); - TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer()); - TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer()); - int parentCounter = (int) taskSpec.getParentCounter(); - int numReturns = (int) taskSpec.getNumReturns(); - Map resources = taskSpec.getRequiredResourcesMap(); - - // Parse args. - FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()]; - for (int i = 0; i < args.length; i++) { - Common.TaskArg arg = taskSpec.getArgs(i); - int objectIdsLength = arg.getObjectIdsCount(); - if (objectIdsLength > 0) { - Preconditions.checkArgument(objectIdsLength == 1, - "This arg has more than one id: {}", objectIdsLength); - args[i] = FunctionArg - .passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer())); - } else { - args[i] = FunctionArg.passByValue(arg.getData().toByteArray()); - } - } - - // Parse function descriptor - Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA); - Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3); - JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( - taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()), - taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()), - taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset()) - ); - - // Parse ActorCreationTaskSpec. - ActorId actorCreationId = ActorId.NIL; - int maxActorReconstructions = 0; - UniqueId[] newActorHandles = new UniqueId[0]; - List dynamicWorkerOptions = new ArrayList<>(); - if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) { - Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec(); - actorCreationId = ActorId - .fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer()); - maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions(); - dynamicWorkerOptions = ImmutableList - .copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList()); - } - - // Parse ActorTaskSpec. - ActorId actorId = ActorId.NIL; - UniqueId actorHandleId = UniqueId.NIL; - ObjectId previousActorTaskDummyObjectId = ObjectId.NIL; - int actorCounter = 0; - if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) { - Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec(); - actorId = ActorId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer()); - actorHandleId = UniqueId - .fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer()); - actorCounter = (int) actorTaskSpec.getActorCounter(); - previousActorTaskDummyObjectId = ObjectId.fromByteBuffer( - actorTaskSpec.getPreviousActorTaskDummyObjectId().asReadOnlyByteBuffer()); - newActorHandles = actorTaskSpec.getNewActorHandlesList().stream() - .map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer())) - .toArray(UniqueId[]::new); - } - - return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId, - maxActorReconstructions, actorId, actorHandleId, actorCounter, - previousActorTaskDummyObjectId, newActorHandles, args, numReturns, resources, - TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); - } - - /** - * Convert a `TaskSpec` to protobuf-serialized bytes. - */ - public static byte[] convertTaskSpecToProtobuf(TaskSpec task) { - // Set common fields. - Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder() - .setJobId(ByteString.copyFrom(task.jobId.getBytes())) - .setTaskId(ByteString.copyFrom(task.taskId.getBytes())) - .setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes())) - .setParentCounter(task.parentCounter) - .setNumReturns(task.numReturns) - .putAllRequiredResources(task.resources); - - // Set args - builder.addAllArgs( - Arrays.stream(task.args).map(arg -> { - Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder(); - if (arg.id != null) { - argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build(); - } else { - argBuilder.setData(ByteString.copyFrom(arg.data)).build(); - } - return argBuilder.build(); - }).collect(Collectors.toList()) - ); - - // Set function descriptor and language. - if (task.language == TaskLanguage.JAVA) { - builder.setLanguage(Common.Language.JAVA); - builder.addAllFunctionDescriptor(ImmutableList.of( - ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()), - ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()), - ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes()) - )); - } else { - builder.setLanguage(Common.Language.PYTHON); - builder.addAllFunctionDescriptor(ImmutableList.of( - ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()), - ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()), - ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()), - ByteString.EMPTY - )); - } - - if (!task.actorCreationId.isNil()) { - // Actor creation task. - builder.setType(TaskType.ACTOR_CREATION_TASK); - builder.setActorCreationTaskSpec( - Common.ActorCreationTaskSpec.newBuilder() - .setActorId(ByteString.copyFrom(task.actorCreationId.getBytes())) - .setMaxActorReconstructions(task.maxActorReconstructions) - .addAllDynamicWorkerOptions(task.dynamicWorkerOptions) - ); - } else if (!task.actorId.isNil()) { - // Actor task. - builder.setType(TaskType.ACTOR_TASK); - List newHandles = Arrays.stream(task.newActorHandles) - .map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList()); - final ObjectId actorCreationDummyObjectId = IdUtil.computeActorCreationDummyObjectId( - ActorId.fromByteBuffer(ByteBuffer.wrap(task.actorId.getBytes()))); - builder.setActorTaskSpec( - Common.ActorTaskSpec.newBuilder() - .setActorId(ByteString.copyFrom(task.actorId.getBytes())) - .setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes())) - .setActorCreationDummyObjectId( - ByteString.copyFrom(actorCreationDummyObjectId.getBytes())) - .setPreviousActorTaskDummyObjectId( - ByteString.copyFrom(task.previousActorTaskDummyObjectId.getBytes())) - .setActorCounter(task.actorCounter) - .addAllNewActorHandles(newHandles) - ); - } else { - // Normal task. - builder.setType(TaskType.NORMAL_TASK); - } - - return builder.build().toByteArray(); - } - - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - nativeSetResource(client, resourceName, capacity, nodeId.getBytes()); - } - - public void destroy() { - nativeDestroy(client); - } - - /// Native method declarations. - /// - /// If you change the signature of any native methods, please re-generate - /// the C++ header file and update the C++ implementation accordingly: - /// - /// Suppose that $Dir is your ray root directory. - /// 1) pushd $Dir/java/runtime/target/classes - /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.RayletClientImpl - /// 3) clang-format -i org_ray_runtime_raylet_RayletClientImpl.h - /// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/ray/raylet/lib/java/ - /// 5) vim $Dir/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc - /// 6) popd - - private static native long nativeInit(String localSchedulerSocket, byte[] workerId, - boolean isWorker, byte[] driverTaskId); - - private static native void nativeSubmitTask(long client, byte[] taskSpec) - throws RayException; - - private static native byte[] nativeGetTask(long client) throws RayException; - - private static native void nativeDestroy(long client) throws RayException; - - private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds, - int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException; - - private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds, - boolean localOnly, boolean deleteCreatingTasks) throws RayException; - - private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId); - - private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, - byte[] checkpointId); - - private static native void nativeSetResource(long conn, String resourceName, double capacity, - byte[] nodeId) throws RayException; - - private static native byte[] nativeGenerateActorCreationTaskId(byte[] jobId, byte[] parentTaskId, - int taskIndex); - - private static native byte[] nativeGenerateActorTaskId(byte[] jobId, byte[] parentTaskId, - int taskIndex, byte[] actorId); - - private static native byte[] nativeGenerateNormalTaskId(byte[] jobId, byte[] parentTaskId, - int taskIndex); -} diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 839dff95d..5e197da89 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -20,6 +20,7 @@ import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; import org.ray.runtime.config.RayConfig; import org.ray.runtime.util.FileUtil; import org.ray.runtime.util.ResourceUtil; @@ -44,17 +45,16 @@ public class RunManager { private Random random; - private List processes; + private List> processes; private static final int KILL_PROCESS_WAIT_TIMEOUT_SECONDS = 1; - private final Map tempFiles; + private static final Map tempFiles = new HashMap<>(); public RunManager(RayConfig rayConfig) { this.rayConfig = rayConfig; processes = new ArrayList<>(); random = new Random(); - tempFiles = new HashMap<>(); } public void cleanup() { @@ -63,19 +63,28 @@ public class RunManager { // cannot exit gracefully. for (int i = processes.size() - 1; i >= 0; --i) { - Process p = processes.get(i); - p.destroy(); + Pair pair = processes.get(i); + String name = pair.getLeft(); + Process p = pair.getRight(); - try { - p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); - } catch (InterruptedException e) { - LOGGER.warn("Got InterruptedException while waiting for process {}" + - " to be terminated.", processes.get(i)); - } - - if (p.isAlive()) { - p.destroyForcibly(); + int numAttempts = 0; + while (p.isAlive()) { + if (numAttempts == 0) { + LOGGER.debug("Terminating process {}.", name); + p.destroy(); + } else { + LOGGER.debug("Terminating process {} forcibly.", name); + p.destroyForcibly(); + } + try { + p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOGGER.warn("Got InterruptedException while waiting for process {}" + + " to be terminated.", processes.get(i)); + } + numAttempts++; } + LOGGER.info("Process {} is now terminated.", name); } } @@ -152,7 +161,7 @@ public class RunManager { if (!p.isAlive()) { throw new RuntimeException("Failed to start " + name); } - processes.add(p); + processes.add(Pair.of(name, p)); LOGGER.info("{} process started", name); } @@ -259,7 +268,7 @@ public class RunManager { String.format("--store_socket_name=%s", rayConfig.objectStoreSocketName), String.format("--object_manager_port=%d", 0), // The object manager port. String.format("--node_manager_port=%d", 0), // The node manager port. - String.format("--node_ip_address=%s",rayConfig.nodeIp), + String.format("--node_ip_address=%s", rayConfig.nodeIp), String.format("--redis_address=%s", rayConfig.getRedisIp()), String.format("--redis_port=%d", rayConfig.getRedisPort()), String.format("--num_initial_workers=%d", 0), // number of initial workers diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java index 211411906..698c14973 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java @@ -1,7 +1,7 @@ package org.ray.runtime.runner.worker; import org.ray.api.Ray; -import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.RayNativeRuntime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +20,7 @@ public class DefaultWorker { }); Ray.init(); LOGGER.info("Worker started."); - ((AbstractRayRuntime)Ray.internal()).loop(); + ((RayNativeRuntime)Ray.internal()).run(); } catch (Exception e) { LOGGER.error("Failed to start worker.", e); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 52447cf79..110c178f7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -3,12 +3,16 @@ package org.ray.runtime.task; import java.util.ArrayList; import java.util.List; import org.ray.api.Ray; -import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.object.NativeRayObject; +import org.ray.runtime.object.ObjectStore; import org.ray.runtime.util.Serializer; +/** + * Helper methods to convert arguments from/to objects. + */ public class ArgumentsBuilder { /** @@ -20,16 +24,13 @@ public class ArgumentsBuilder { /** * Convert real function arguments to task spec arguments. */ - public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { - FunctionArg[] ret = new FunctionArg[args.length]; - for (int i = 0; i < ret.length; i++) { - Object arg = args[i]; + public static List wrap(Object[] args, boolean crossLanguage) { + List ret = new ArrayList<>(); + for (Object arg : args) { ObjectId id = null; byte[] data = null; if (arg == null) { data = Serializer.encode(null); - } else if (arg instanceof RayActor) { - data = Serializer.encode(arg); } else if (arg instanceof RayObject) { id = ((RayObject) arg).getId(); } else if (arg instanceof byte[] && crossLanguage) { @@ -40,41 +41,28 @@ public class ArgumentsBuilder { } else { byte[] serialized = Serializer.encode(arg); if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { - id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId(); + id = ((AbstractRayRuntime) Ray.internal()).getObjectStore() + .put(new NativeRayObject(serialized, null)); } else { data = serialized; } } if (id != null) { - ret[i] = FunctionArg.passByReference(id); + ret.add(FunctionArg.passByReference(id)); } else { - ret[i] = FunctionArg.passByValue(data); + ret.add(FunctionArg.passByValue(data)); } } return ret; } /** - * Convert task spec arguments to real function arguments. + * Convert list of NativeRayObject to real function arguments. */ - public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) { - Object[] realArgs = new Object[task.args.length]; - List idsToFetch = new ArrayList<>(); - List indices = new ArrayList<>(); - for (int i = 0; i < task.args.length; i++) { - FunctionArg arg = task.args[i]; - if (arg.id != null) { - // pass by reference - idsToFetch.add(arg.id); - indices.add(i); - } else { - // pass by value - realArgs[i] = Serializer.decode(arg.data, classLoader); - } - } - List objects = Ray.get(idsToFetch); - for (int i = 0; i < objects.size(); i++) { - realArgs[indices.get(i)] = objects.get(i); + public static Object[] unwrap(ObjectStore objectStore, List args) { + Object[] realArgs = new Object[args.size()]; + for (int i = 0; i < args.size(); i++) { + realArgs[i] = objectStore.deserialize(args.get(i), null); } return realArgs; } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java new file mode 100644 index 000000000..3aa5c5067 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -0,0 +1,298 @@ +package org.ray.runtime.task; + +import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; +import org.ray.api.RayActor; +import org.ray.api.id.ActorId; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.CallOptions; +import org.ray.runtime.actor.LocalModeRayActor; +import org.ray.runtime.context.LocalModeWorkerContext; +import org.ray.runtime.RayDevRuntime; +import org.ray.runtime.functionmanager.FunctionDescriptor; +import org.ray.runtime.functionmanager.JavaFunctionDescriptor; +import org.ray.runtime.generated.Common.ActorCreationTaskSpec; +import org.ray.runtime.generated.Common.ActorTaskSpec; +import org.ray.runtime.generated.Common.Language; +import org.ray.runtime.generated.Common.TaskArg; +import org.ray.runtime.generated.Common.TaskSpec; +import org.ray.runtime.generated.Common.TaskType; +import org.ray.runtime.object.NativeRayObject; +import org.ray.runtime.object.LocalModeObjectStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Task submitter for local mode. + */ +public class LocalModeTaskSubmitter implements TaskSubmitter { + + private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeTaskSubmitter.class); + + private final Map> waitingTasks = new HashMap<>(); + private final Object taskAndObjectLock = new Object(); + private final RayDevRuntime runtime; + private final LocalModeObjectStore objectStore; + private final ExecutorService exec; + private final Deque idleTaskExecutors = new ArrayDeque<>(); + private final Map actorTaskExecutors = new HashMap<>(); + private final Object taskExecutorLock = new Object(); + private final ThreadLocal currentTaskExecutor = new ThreadLocal<>(); + + public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore, + int numberThreads) { + this.runtime = runtime; + this.objectStore = objectStore; + // The thread pool that executes tasks in parallel. + exec = Executors.newFixedThreadPool(numberThreads); + } + + public void onObjectPut(ObjectId id) { + Set tasks; + synchronized (taskAndObjectLock) { + tasks = waitingTasks.remove(id); + if (tasks != null) { + for (TaskSpec task : tasks) { + Set unreadyObjects = getUnreadyObjects(task); + if (unreadyObjects.isEmpty()) { + submitTaskSpec(task); + } + } + } + } + } + + /** + * Get the worker of current thread.
NOTE: Cannot be used for multi-threading in worker. + */ + public TaskExecutor getCurrentTaskExecutor() { + return currentTaskExecutor.get(); + } + + /** + * Get a worker from the worker pool to run the given task. + */ + private TaskExecutor getTaskExecutor(TaskSpec task) { + TaskExecutor taskExecutor; + synchronized (taskExecutorLock) { + if (task.getType() == TaskType.ACTOR_TASK) { + taskExecutor = actorTaskExecutors.get(getActorId(task)); + } else if (task.getType() == TaskType.ACTOR_CREATION_TASK) { + taskExecutor = new TaskExecutor(runtime); + actorTaskExecutors.put(getActorId(task), taskExecutor); + } else if (idleTaskExecutors.size() > 0) { + taskExecutor = idleTaskExecutors.pop(); + } else { + taskExecutor = new TaskExecutor(runtime); + } + } + currentTaskExecutor.set(taskExecutor); + return taskExecutor; + } + + /** + * Return the worker to the worker pool. + */ + private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) { + currentTaskExecutor.remove(); + synchronized (taskExecutorLock) { + if (taskSpec.getType() == TaskType.NORMAL_TASK) { + idleTaskExecutors.push(worker); + } + } + } + + private Set getUnreadyObjects(TaskSpec taskSpec) { + Set unreadyObjects = new HashSet<>(); + // Check whether task arguments are ready. + for (TaskArg arg : taskSpec.getArgsList()) { + for (ByteString idByteString : arg.getObjectIdsList()) { + ObjectId id = new ObjectId(idByteString.toByteArray()); + if (!objectStore.isObjectReady(id)) { + unreadyObjects.add(id); + } + } + } + if (taskSpec.getType() == TaskType.ACTOR_TASK) { + ObjectId dummyObjectId = new ObjectId( + taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray()); + if (!objectStore.isObjectReady(dummyObjectId)) { + unreadyObjects.add(dummyObjectId); + } + } + return unreadyObjects; + } + + private TaskSpec.Builder getTaskSpecBuilder(TaskType taskType, + FunctionDescriptor functionDescriptor, List args) { + byte[] taskIdBytes = new byte[TaskId.LENGTH]; + new Random().nextBytes(taskIdBytes); + return TaskSpec.newBuilder() + .setType(taskType) + .setLanguage(Language.JAVA) + .setJobId( + ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes())) + .setTaskId(ByteString.copyFrom(taskIdBytes)) + .addAllFunctionDescriptor(functionDescriptor.toList().stream().map(ByteString::copyFromUtf8) + .collect(Collectors.toList())) + .addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder() + .addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build() + : TaskArg.newBuilder().setData(ByteString.copyFrom(arg.data)).build()) + .collect(Collectors.toList())); + } + + @Override + public List submitTask(FunctionDescriptor functionDescriptor, List args, + int numReturns, CallOptions options) { + Preconditions.checkState(numReturns == 1); + TaskSpec taskSpec = getTaskSpecBuilder(TaskType.NORMAL_TASK, functionDescriptor, args) + .setNumReturns(numReturns) + .build(); + submitTaskSpec(taskSpec); + return getReturnIds(taskSpec); + } + + @Override + public RayActor createActor(FunctionDescriptor functionDescriptor, List args, + ActorCreationOptions options) { + ActorId actorId = ActorId.fromRandom(); + TaskSpec taskSpec = getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args) + .setNumReturns(1) + .setActorCreationTaskSpec(ActorCreationTaskSpec.newBuilder() + .setActorId(ByteString.copyFrom(actorId.toByteBuffer())) + .build()) + .build(); + submitTaskSpec(taskSpec); + return new LocalModeRayActor(actorId, getReturnIds(taskSpec).get(0)); + } + + @Override + public List submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor, + List args, int numReturns, CallOptions options) { + Preconditions.checkState(numReturns == 1); + TaskSpec.Builder builder = getTaskSpecBuilder(TaskType.ACTOR_TASK, functionDescriptor, args); + List returnIds = getReturnIds( + TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1); + TaskSpec taskSpec = builder + .setNumReturns(numReturns + 1) + .setActorTaskSpec( + ActorTaskSpec.newBuilder().setActorId(ByteString.copyFrom(actor.getId().getBytes())) + .setPreviousActorTaskDummyObjectId(ByteString.copyFrom( + ((LocalModeRayActor) actor) + .exchangePreviousActorTaskDummyObjectId(returnIds.get(returnIds.size() - 1)) + .getBytes())) + .build()) + .build(); + submitTaskSpec(taskSpec); + return Collections.singletonList(returnIds.get(0)); + } + + public static ActorId getActorId(TaskSpec taskSpec) { + ByteString actorId = null; + if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { + actorId = taskSpec.getActorCreationTaskSpec().getActorId(); + } else if (taskSpec.getType() == TaskType.ACTOR_TASK) { + actorId = taskSpec.getActorTaskSpec().getActorId(); + } + if (actorId == null) { + return null; + } + return ActorId.fromBytes(actorId.toByteArray()); + } + + private void submitTaskSpec(TaskSpec taskSpec) { + LOGGER.debug("Submitting task: {}.", taskSpec); + synchronized (taskAndObjectLock) { + Set unreadyObjects = getUnreadyObjects(taskSpec); + if (unreadyObjects.isEmpty()) { + // If all dependencies are ready, execute this task. + exec.submit(() -> { + TaskExecutor taskExecutor = getTaskExecutor(taskSpec); + try { + List args = getFunctionArgs(taskSpec).stream() + .map(arg -> arg.id != null ? + objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) + : new NativeRayObject(arg.data, null)) + .collect(Collectors.toList()); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); + List returnObjects = taskExecutor + .execute(getJavaFunctionDescriptor(taskSpec).toList(), args); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); + List returnIds = getReturnIds(taskSpec); + for (int i = 0; i < returnIds.size(); i++) { + NativeRayObject putObject; + if (i >= returnObjects.size()) { + // If the task is an actor task or an actor creation task, + // put the dummy object in object store, so those tasks which depends on it + // can be executed. + putObject = new NativeRayObject(new byte[]{}, new byte[]{}); + } else { + putObject = returnObjects.get(i); + } + objectStore.putRaw(putObject, returnIds.get(i)); + } + } finally { + returnTaskExecutor(taskExecutor, taskSpec); + } + }); + } else { + // If some dependencies aren't ready yet, put this task in waiting list. + for (ObjectId id : unreadyObjects) { + waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(taskSpec); + } + } + } + } + + private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) { + List functionDescriptor = taskSpec.getFunctionDescriptorList(); + return new JavaFunctionDescriptor(functionDescriptor.get(0).toStringUtf8(), + functionDescriptor.get(1).toStringUtf8(), functionDescriptor.get(2).toStringUtf8()); + } + + private static List getFunctionArgs(TaskSpec taskSpec) { + List functionArgs = new ArrayList<>(); + for (int i = 0; i < taskSpec.getArgsCount(); i++) { + TaskArg arg = taskSpec.getArgs(i); + if (arg.getObjectIdsCount() > 0) { + functionArgs.add(FunctionArg + .passByReference(new ObjectId(arg.getObjectIds(0).toByteArray()))); + } else { + functionArgs.add(FunctionArg.passByValue(arg.getData().toByteArray())); + } + } + return functionArgs; + } + + private static List getReturnIds(TaskSpec taskSpec) { + return getReturnIds(TaskId.fromBytes(taskSpec.getTaskId().toByteArray()), + taskSpec.getNumReturns()); + } + + private static List getReturnIds(TaskId taskId, long numReturns) { + List returnIds = new ArrayList<>(); + for (int i = 0; i < numReturns; i++) { + returnIds.add(ObjectId.fromByteBuffer( + (ByteBuffer) ByteBuffer.allocate(ObjectId.LENGTH).put(taskId.getBytes()) + .putInt(TaskId.LENGTH, i + 1).position(0))); + } + return returnIds; + } + +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java new file mode 100644 index 000000000..cae93e788 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java @@ -0,0 +1,64 @@ +package org.ray.runtime.task; + +import com.google.common.base.Preconditions; +import java.util.List; +import java.util.stream.Collectors; +import org.ray.api.RayActor; +import org.ray.api.id.ObjectId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.CallOptions; +import org.ray.runtime.actor.NativeRayActor; +import org.ray.runtime.functionmanager.FunctionDescriptor; + +/** + * Task submitter for cluster mode. This is a wrapper class for core worker task interface. + */ +public class NativeTaskSubmitter implements TaskSubmitter { + + /** + * The native pointer of core worker. + */ + private final long nativeCoreWorkerPointer; + + public NativeTaskSubmitter(long nativeCoreWorkerPointer) { + this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + public List submitTask(FunctionDescriptor functionDescriptor, List args, + int numReturns, CallOptions options) { + List returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args, + numReturns, options); + return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); + } + + @Override + public RayActor createActor(FunctionDescriptor functionDescriptor, List args, + ActorCreationOptions options) { + long nativeActorHandle = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args, + options); + return new NativeRayActor(nativeActorHandle); + } + + @Override + public List submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor, + List args, int numReturns, CallOptions options) { + Preconditions.checkState(actor instanceof NativeRayActor); + List returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer, + ((NativeRayActor) actor).getNativeActorHandle(), functionDescriptor, args, numReturns, + options); + return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); + } + + private static native List nativeSubmitTask(long nativeCoreWorkerPointer, + FunctionDescriptor functionDescriptor, List args, int numReturns, + CallOptions callOptions); + + private static native long nativeCreateActor(long nativeCoreWorkerPointer, + FunctionDescriptor functionDescriptor, List args, + ActorCreationOptions actorCreationOptions); + + private static native List nativeSubmitActorTask(long nativeCoreWorkerPointer, + long nativeActorHandle, FunctionDescriptor functionDescriptor, List args, + int numReturns, CallOptions callOptions); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java similarity index 61% rename from java/runtime/src/main/java/org/ray/runtime/Worker.java rename to java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index e4695add6..2f595bd2e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -1,4 +1,4 @@ -package org.ray.runtime; +package org.ray.runtime.task; import com.google.common.base.Preconditions; import java.util.ArrayList; @@ -8,38 +8,34 @@ import org.ray.api.Checkpointable.Checkpoint; import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; import org.ray.api.id.ActorId; -import org.ray.api.id.ObjectId; +import org.ray.api.id.JobId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; +import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.RayFunction; -import org.ray.runtime.task.ArgumentsBuilder; -import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.IdUtil; +import org.ray.runtime.generated.Common.TaskType; +import org.ray.runtime.object.NativeRayObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * The worker, which pulls tasks from {@link org.ray.runtime.raylet.RayletClient} and executes them - * continuously. + * The task executor, which executes tasks assigned by raylet continuously. */ -public class Worker { +public final class TaskExecutor { - private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class); + private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class); // TODO(hchen): Use the C++ config. private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; - private final AbstractRayRuntime runtime; + protected final AbstractRayRuntime runtime; /** * The current actor object, if this worker is an actor, otherwise null. */ - private Object currentActor = null; - - /** - * Id of the current actor object, if the worker is an actor, otherwise NIL. - */ - private ActorId currentActorId = ActorId.NIL; + protected Object currentActor = null; /** * The exception that failed the actor creation task, if any. @@ -61,53 +57,36 @@ public class Worker { */ private long lastCheckpointTimestamp = 0; - - public Worker(AbstractRayRuntime runtime) { + public TaskExecutor(AbstractRayRuntime runtime) { this.runtime = runtime; } - public ActorId getCurrentActorId() { - return currentActorId; - } + protected List execute(List rayFunctionInfo, + List argsBytes) { + JobId jobId = runtime.getWorkerContext().getCurrentJobId(); + TaskType taskType = runtime.getWorkerContext().getCurrentTaskType(); + TaskId taskId = runtime.getWorkerContext().getCurrentTaskId(); + LOGGER.debug("Executing task {}", taskId); - public void loop() { - while (true) { - LOGGER.info("Fetching new task in thread {}.", Thread.currentThread().getName()); - TaskSpec task = runtime.getRayletClient().getTask(); - execute(task); - } - } - - /** - * Execute a task. - */ - public void execute(TaskSpec spec) { - LOGGER.debug("Executing task {}", spec); - ObjectId returnId = spec.returnIds[0]; + List returnObjects = new ArrayList<>(); ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); try { // Get method RayFunction rayFunction = runtime.getFunctionManager() - .getFunction(spec.jobId, spec.getJavaFunctionDescriptor()); - // Set context - runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader); + .getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo)); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); - - if (spec.isActorCreationTask()) { - currentActorId = spec.taskId.getActorId(); - } + runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader); // Get local actor object and arguments. Object actor = null; - if (spec.isActorTask()) { - Preconditions.checkState(spec.actorId.equals(currentActorId)); + if (taskType == TaskType.ACTOR_TASK) { if (actorCreationException != null) { throw actorCreationException; } actor = currentActor; } - Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader); + Object[] args = ArgumentsBuilder.unwrap(runtime.getObjectStore(), argsBytes); // Execute the task. Object result; if (!rayFunction.isConstructor()) { @@ -116,27 +95,37 @@ public class Worker { result = rayFunction.getConstructor().newInstance(args); } // Set result - if (!spec.isActorCreationTask()) { - if (spec.isActorTask()) { - maybeSaveCheckpoint(actor, spec.actorId); + if (taskType != TaskType.ACTOR_CREATION_TASK) { + if (taskType == TaskType.ACTOR_TASK) { + // TODO (kfstorm): handle checkpoint in core worker. + maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId()); } - - runtime.put(returnId, result); + returnObjects.add(runtime.getObjectStore().serialize(result)); } else { - maybeLoadCheckpoint(result, spec.taskId.getActorId()); + // TODO (kfstorm): handle checkpoint in core worker. + maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId()); currentActor = result; } - LOGGER.debug("Finished executing task {}", spec.taskId); + LOGGER.debug("Finished executing task {}", taskId); } catch (Exception e) { - LOGGER.error("Error executing task " + spec, e); - if (!spec.isActorCreationTask()) { - runtime.put(returnId, new RayTaskException("Error executing task " + spec, e)); + LOGGER.error("Error executing task " + taskId, e); + if (taskType != TaskType.ACTOR_CREATION_TASK) { + returnObjects.add(runtime.getObjectStore() + .serialize(new RayTaskException("Error executing task " + taskId, e))); } else { actorCreationException = e; } } finally { Thread.currentThread().setContextClassLoader(oldLoader); + runtime.getWorkerContext().setCurrentClassLoader(null); } + return returnObjects; + } + + private JavaFunctionDescriptor parseFunctionDescriptor(List rayFunctionInfo) { + Preconditions.checkState(rayFunctionInfo != null && rayFunctionInfo.size() == 3); + return new JavaFunctionDescriptor(rayFunctionInfo.get(0), rayFunctionInfo.get(1), + rayFunctionInfo.get(2)); } private void maybeSaveCheckpoint(Object actor, ActorId actorId) { @@ -155,7 +144,7 @@ public class Worker { } numTasksSinceLastCheckpoint = 0; lastCheckpointTimestamp = System.currentTimeMillis(); - UniqueId checkpointId = runtime.rayletClient.prepareCheckpoint(actorId); + UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId); checkpointIds.add(checkpointId); if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); @@ -192,7 +181,7 @@ public class Worker { Preconditions.checkArgument(checkpointValid, "'loadCheckpoint' must return a checkpoint ID that exists in the " + "'availableCheckpoints' list, or null."); - runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, checkpointId); + runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId); } } } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskLanguage.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskLanguage.java deleted file mode 100644 index a6b4f31d8..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskLanguage.java +++ /dev/null @@ -1,11 +0,0 @@ -package org.ray.runtime.task; - -/** - * Language of a Ray task. - */ -public enum TaskLanguage { - - JAVA, - - PYTHON, -} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java deleted file mode 100644 index 522ddec57..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ /dev/null @@ -1,167 +0,0 @@ -package org.ray.runtime.task; - -import com.google.common.base.Preconditions; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import org.ray.api.id.ActorId; -import org.ray.api.id.JobId; -import org.ray.api.id.TaskId; -import org.ray.api.id.ObjectId; -import org.ray.api.id.UniqueId; -import org.ray.runtime.functionmanager.FunctionDescriptor; -import org.ray.runtime.functionmanager.JavaFunctionDescriptor; -import org.ray.runtime.functionmanager.PyFunctionDescriptor; - -/** - * Represents necessary information of a task for scheduling and executing. - */ -public class TaskSpec { - - // ID of the job that created this task. - public final JobId jobId; - - // Task ID of the task. - public final TaskId taskId; - - // Task ID of the parent task. - public final TaskId parentTaskId; - - // A count of the number of tasks submitted by the parent task before this one. - public final int parentCounter; - - // Id for createActor a target actor - public final ActorId actorCreationId; - - public final int maxActorReconstructions; - - // Actor ID of the task. This is the actor that this task is executed on - // or NIL_ACTOR_ID if the task is just a normal task. - public final ActorId actorId; - - // ID per actor client for session consistency - public final UniqueId actorHandleId; - - // Number of tasks that have been submitted to this actor so far. - public final int actorCounter; - - // Object id returned by the previous task submitted to the same actor. - public final ObjectId previousActorTaskDummyObjectId; - - // Task arguments. - public final UniqueId[] newActorHandles; - - // Task arguments. - public final FunctionArg[] args; - - // number of return objects. - public final int numReturns; - - // Return ids. - public final ObjectId[] returnIds; - - // The task's resource demands. - public final Map resources; - - // Language of this task. - public final TaskLanguage language; - - public final List dynamicWorkerOptions; - - // Descriptor of the remote function. - // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language - // is Python, the type is PyFunctionDescriptor. - private final FunctionDescriptor functionDescriptor; - - public boolean isActorTask() { - return !actorId.isNil(); - } - - public boolean isActorCreationTask() { - return !actorCreationId.isNil(); - } - - public TaskSpec( - JobId jobId, - TaskId taskId, - TaskId parentTaskId, - int parentCounter, - ActorId actorCreationId, - int maxActorReconstructions, - ActorId actorId, - UniqueId actorHandleId, - int actorCounter, - ObjectId previousActorTaskDummyObjectId, - UniqueId[] newActorHandles, - FunctionArg[] args, - int numReturns, - Map resources, - TaskLanguage language, - FunctionDescriptor functionDescriptor, - List dynamicWorkerOptions) { - this.jobId = jobId; - this.taskId = taskId; - this.parentTaskId = parentTaskId; - this.parentCounter = parentCounter; - this.actorCreationId = actorCreationId; - this.maxActorReconstructions = maxActorReconstructions; - this.actorId = actorId; - this.actorHandleId = actorHandleId; - this.actorCounter = actorCounter; - this.previousActorTaskDummyObjectId = previousActorTaskDummyObjectId; - this.newActorHandles = newActorHandles; - this.args = args; - this.numReturns = numReturns; - this.dynamicWorkerOptions = dynamicWorkerOptions; - - returnIds = new ObjectId[numReturns]; - for (int i = 0; i < numReturns; ++i) { - returnIds[i] = ObjectId.forReturn(taskId, i + 1); - } - this.resources = resources; - this.language = language; - if (language == TaskLanguage.JAVA) { - Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor, - "Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass()); - } else if (language == TaskLanguage.PYTHON) { - Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor, - "Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass()); - } else { - Preconditions.checkArgument(false, "Unknown task language: {}.", language); - } - this.functionDescriptor = functionDescriptor; - } - - public JavaFunctionDescriptor getJavaFunctionDescriptor() { - Preconditions.checkState(language == TaskLanguage.JAVA); - return (JavaFunctionDescriptor) functionDescriptor; - } - - public PyFunctionDescriptor getPyFunctionDescriptor() { - Preconditions.checkState(language == TaskLanguage.PYTHON); - return (PyFunctionDescriptor) functionDescriptor; - } - - @Override - public String toString() { - return "TaskSpec{" + - "jobId=" + jobId + - ", taskId=" + taskId + - ", parentTaskId=" + parentTaskId + - ", parentCounter=" + parentCounter + - ", actorCreationId=" + actorCreationId + - ", maxActorReconstructions=" + maxActorReconstructions + - ", actorId=" + actorId + - ", actorHandleId=" + actorHandleId + - ", actorCounter=" + actorCounter + - ", previousActorTaskDummyObjectId=" + previousActorTaskDummyObjectId + - ", newActorHandles=" + Arrays.toString(newActorHandles) + - ", args=" + Arrays.toString(args) + - ", numReturns=" + numReturns + - ", resources=" + resources + - ", language=" + language + - ", functionDescriptor=" + functionDescriptor + - ", dynamicWorkerOptions=" + dynamicWorkerOptions + - '}'; - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSubmitter.java new file mode 100644 index 000000000..d7f825616 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSubmitter.java @@ -0,0 +1,47 @@ +package org.ray.runtime.task; + +import java.util.List; +import org.ray.api.RayActor; +import org.ray.api.id.ObjectId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.CallOptions; +import org.ray.runtime.functionmanager.FunctionDescriptor; + +/** + * A set of methods to submit tasks and create actors. + */ +public interface TaskSubmitter { + + /** + * Submit a normal task. + * @param functionDescriptor The remote function to execute. + * @param args Arguments of this task. + * @param numReturns Return object count. + * @param options Options for this task. + * @return Ids of the return objects. + */ + List submitTask(FunctionDescriptor functionDescriptor, List args, + int numReturns, CallOptions options); + + /** + * Create an actor. + * @param functionDescriptor The remote function that generates the actor object. + * @param args Arguments of this task. + * @param options Options for this actor creation task. + * @return Handle to the actor. + */ + RayActor createActor(FunctionDescriptor functionDescriptor, List args, + ActorCreationOptions options); + + /** + * Submit an actor task. + * @param actor Handle to the actor. + * @param functionDescriptor The remote function to execute. + * @param args Arguments of this task. + * @param numReturns Return object count. + * @param options Options for this task. + * @return Ids of the return objects. + */ + List submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor, + List args, int numReturns, CallOptions options); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index 93674db84..925603bab 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -1,28 +1,13 @@ package org.ray.runtime.util; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.List; import org.ray.api.id.BaseId; -import org.ray.api.id.ObjectId; -import org.ray.api.id.ActorId; /** - * Helper method for different Ids. - * Note: any changes to these methods must be synced with C++ helper functions - * in src/ray/common/id.h + * Helper method for different Ids. Note: any changes to these methods must be synced with C++ + * helper functions in src/ray/common/id.h */ public class IdUtil { - public static byte[][] getIdBytes(List objectIds) { - int size = objectIds.size(); - byte[][] ids = new byte[size][]; - for (int i = 0; i < size; i++) { - ids[i] = objectIds.get(i).getBytes(); - } - return ids; - } - /** * Compute the murmur hash code of this ID. */ @@ -43,14 +28,14 @@ public class IdUtil { for (int i = 0; i < length8; i++) { final int i8 = i * 8; - long k = ((long)data[i8] & 0xff) - + (((long)data[i8 + 1] & 0xff) << 8) - + (((long)data[i8 + 2] & 0xff) << 16) - + (((long)data[i8 + 3] & 0xff) << 24) - + (((long)data[i8 + 4] & 0xff) << 32) - + (((long)data[i8 + 5] & 0xff) << 40) - + (((long)data[i8 + 6] & 0xff) << 48) - + (((long)data[i8 + 7] & 0xff) << 56); + long k = ((long) data[i8] & 0xff) + + (((long) data[i8 + 1] & 0xff) << 8) + + (((long) data[i8 + 2] & 0xff) << 16) + + (((long) data[i8 + 3] & 0xff) << 24) + + (((long) data[i8 + 4] & 0xff) << 32) + + (((long) data[i8 + 5] & 0xff) << 40) + + (((long) data[i8 + 6] & 0xff) << 48) + + (((long) data[i8 + 7] & 0xff) << 56); k *= m; k ^= k >>> r; @@ -90,16 +75,4 @@ public class IdUtil { return h; } - - /* - * A helper function to compute actor creation dummy object id according - * the given actor id. - */ - public static ObjectId computeActorCreationDummyObjectId(ActorId actorId) { - byte[] bytes = new byte[ObjectId.LENGTH]; - System.arraycopy(actorId.getBytes(), 0, bytes, 0, ActorId.LENGTH); - Arrays.fill(bytes, ActorId.LENGTH, bytes.length, (byte) 0xFF); - return ObjectId.fromByteBuffer(ByteBuffer.wrap(bytes)); - } - } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java b/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java index ab5080aaf..e29140411 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java @@ -1,7 +1,8 @@ package org.ray.runtime.util; import org.nustaq.serialization.FSTConfiguration; -import org.ray.runtime.RayActorImpl; +import org.ray.runtime.actor.NativeRayActor; +import org.ray.runtime.actor.NativeRayActorSerializer; /** * Java object serialization TODO: use others (e.g. Arrow) for higher performance @@ -10,7 +11,7 @@ public class Serializer { private static final ThreadLocal conf = ThreadLocal.withInitial(() -> { FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration(); - conf.registerSerializer(RayActorImpl.class, new RayActorSerializer(), true); + conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true); return conf; }); diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index 41440b50d..0834106f9 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -4,7 +4,6 @@ import java.io.File; import java.nio.file.Files; import java.nio.file.Paths; import java.util.Map; -import java.util.Random; import javax.tools.JavaCompiler; import javax.tools.ToolProvider; import org.apache.commons.io.FileUtils; @@ -14,7 +13,6 @@ import org.ray.api.annotation.RayRemote; import org.ray.api.function.RayFunc0; import org.ray.api.function.RayFunc1; import org.ray.api.id.JobId; -import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionManager.JobFunctionTable; import org.testng.Assert; import org.testng.annotations.BeforeClass; diff --git a/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java b/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java index 90c6bff2a..d3604e487 100644 --- a/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java +++ b/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java @@ -1,5 +1,8 @@ package org.ray.streaming.schedule.impl; +import org.ray.api.id.ActorId; +import org.ray.api.id.ObjectId; +import org.ray.runtime.actor.LocalModeRayActor; import org.ray.streaming.api.partition.impl.RoundRobinPartition; import org.ray.streaming.core.graph.ExecutionEdge; import org.ray.streaming.core.graph.ExecutionGraph; @@ -12,7 +15,6 @@ import org.ray.streaming.schedule.ITaskAssign; import java.util.ArrayList; import java.util.List; import org.ray.api.RayActor; -import org.ray.runtime.RayActorImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -29,7 +31,7 @@ public class TaskAssignImplTest { List> workers = new ArrayList<>(); for(int i = 0; i < plan.getPlanVertexList().size(); i++) { - workers.add(new RayActorImpl<>()); + workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom())); } ITaskAssign taskAssign = new TaskAssignImpl(); diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 3bba9adf9..784c82c4c 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -1,6 +1,8 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; import org.ray.api.Ray; import org.ray.api.RayActor; @@ -10,8 +12,8 @@ import org.ray.api.annotation.RayRemote; import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.RayActorImpl; -import org.ray.runtime.objectstore.NativeRayObject; +import org.ray.runtime.actor.NativeRayActor; +import org.ray.runtime.object.NativeRayObject; import org.testng.Assert; import org.testng.annotations.Test; @@ -72,6 +74,12 @@ public class ActorTest extends BaseTest { return res.get(); } + @RayRemote + public static int testActorAsFieldOfParameter(List> actor, int delta) { + RayObject res = Ray.call(Counter::increase, actor.get(0), delta); + return res.get(); + } + @Test public void testPassActorAsParameter() { RayActor actor = Ray.createActor(Counter::new, 0); @@ -79,13 +87,17 @@ public class ActorTest extends BaseTest { Ray.call(ActorTest::testActorAsFirstParameter, actor, 1).get()); Assert.assertEquals(Integer.valueOf(11), Ray.call(ActorTest::testActorAsSecondParameter, 10, actor).get()); + Assert.assertEquals(Integer.valueOf(111), + Ray.call(ActorTest::testActorAsFieldOfParameter, Collections.singletonList(actor), 100) + .get()); } @Test public void testForkingActorHandle() { + TestUtils.skipTestUnderSingleProcess(); RayActor counter = Ray.createActor(Counter::new, 100); Assert.assertEquals(Integer.valueOf(101), Ray.call(Counter::increase, counter, 1).get()); - RayActor counter2 = ((RayActorImpl) counter).fork(); + RayActor counter2 = ((NativeRayActor) counter).fork(); Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increase, counter2, 2).get()); } @@ -100,9 +112,8 @@ public class ActorTest extends BaseTest { Ray.internal().free(ImmutableList.of(value.getId()), false, false); // Wait until the object is deleted, because the above free operation is async. while (true) { - NativeRayObject result = ((AbstractRayRuntime) - Ray.internal()).getObjectStoreProxy().getObjectInterface() - .get(ImmutableList.of(value.getId()), 0).get(0); + NativeRayObject result = ((AbstractRayRuntime) Ray.internal()).getObjectStore() + .getRaw(ImmutableList.of(value.getId()), 0).get(0); if (result == null) { break; } diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index c3d37e78a..b73ddd75c 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -7,7 +7,7 @@ import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.exception.RayException; import org.ray.api.id.ObjectId; -import org.ray.runtime.RayObjectImpl; +import org.ray.runtime.object.RayObjectImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 13e9930fc..e0d18d9a5 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -1,10 +1,12 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; +import java.util.Arrays; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; +import org.ray.api.id.TaskId; import org.ray.runtime.AbstractRayRuntime; import org.testng.Assert; import org.testng.annotations.Test; @@ -24,8 +26,8 @@ public class PlasmaFreeTest extends BaseTest { Ray.internal().free(ImmutableList.of(helloId.getId()), true, false); final boolean result = TestUtils.waitForCondition(() -> - ((AbstractRayRuntime) Ray.internal()).getObjectStoreProxy().getObjectInterface() - .get(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50); + ((AbstractRayRuntime) Ray.internal()).getObjectStore() + .getRaw(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50); Assert.assertTrue(result); } @@ -36,9 +38,10 @@ public class PlasmaFreeTest extends BaseTest { Assert.assertEquals("hello", helloId.get()); Ray.internal().free(ImmutableList.of(helloId.getId()), true, true); + TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH)); final boolean result = TestUtils.waitForCondition( - () -> !(((AbstractRayRuntime)Ray.internal()).getGcsClient()) - .rayletTaskExistsInGcs(helloId.getId().getTaskId()), 50); + () -> !(((AbstractRayRuntime) Ray.internal()).getGcsClient()) + .rayletTaskExistsInGcs(taskId), 50); Assert.assertTrue(result); } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java index 08790f204..52083b4f1 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java @@ -1,10 +1,12 @@ package org.ray.api.test; +import java.util.Collections; import org.ray.api.Ray; import org.ray.api.TestUtils; import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.objectstore.ObjectStoreProxy; +import org.ray.runtime.object.NativeRayObject; +import org.ray.runtime.object.ObjectStore; import org.testng.Assert; import org.testng.annotations.Test; @@ -15,11 +17,15 @@ public class PlasmaStoreTest extends BaseTest { TestUtils.skipTestUnderSingleProcess(); ObjectId objectId = ObjectId.fromRandom(); AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); - ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy(); - objectInterface.put(objectId, 1); - Assert.assertEquals(objectInterface.get(objectId), (Integer) 1); - objectInterface.put(objectId, 2); + ObjectStore objectStore = runtime.getObjectStore(); + objectStore.putRaw(new NativeRayObject(new byte[]{1}, null), objectId); + Assert.assertEquals( + objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0], + (byte) 1); + objectStore.putRaw(new NativeRayObject(new byte[]{2}, null), objectId); // Putting 2 objects with duplicate ID should fail but ignored. - Assert.assertEquals(objectInterface.get(objectId), (Integer) 1); + Assert.assertEquals( + objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0], + (byte) 1); } } diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java index 0fdcff03c..729815c23 100644 --- a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -1,23 +1,24 @@ package org.ray.api.test; +import org.ray.api.Ray; import org.ray.api.RayPyActor; -import org.ray.api.id.ActorId; -import org.ray.api.id.JobId; -import org.ray.api.id.UniqueId; -import org.ray.runtime.RayPyActorImpl; -import org.ray.runtime.util.Serializer; +import org.ray.api.id.ObjectId; +import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.object.NativeRayObject; +import org.ray.runtime.object.ObjectStore; import org.testng.Assert; import org.testng.annotations.Test; -public class RaySerializerTest { +public class RaySerializerTest extends BaseMultiLanguageTest { @Test public void testSerializePyActor() { - final ActorId pyActorId = ActorId.generateActorId(JobId.fromInt(1)); - RayPyActor pyActor = new RayPyActorImpl(pyActorId, "test", "RaySerializerTest"); - byte[] bytes = Serializer.encode(pyActor); - RayPyActor result = Serializer.decode(bytes); - Assert.assertEquals(result.getId(), pyActorId); + RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest"); + ObjectStore objectStore = ((AbstractRayRuntime) Ray.internal()).getObjectStore(); + NativeRayObject nativeRayObject = objectStore.serialize(pyActor); + RayPyActor result = (RayPyActor) objectStore + .deserialize(nativeRayObject, ObjectId.fromRandom()); + Assert.assertEquals(result.getId(), pyActor.getId()); Assert.assertEquals(result.getModuleName(), "test"); Assert.assertEquals(result.getClassName(), "RaySerializerTest"); } diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index c2e5aee9d..af0af3129 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -50,36 +50,10 @@ public class UniqueIdTest { Assert.assertTrue(id6.isNil()); } - @Test - public void testComputeReturnId() { - // Mock a taskId, and the lowest 4 bytes should be 0. - TaskId taskId = TaskId.fromHexString("123456789ABCDE123456789ABCDE"); - - ObjectId returnId = ObjectId.forReturn(taskId, 1); - Assert.assertEquals("123456789abcde123456789abcde00c001000000", returnId.toString()); - Assert.assertEquals(returnId.getTaskId(), taskId); - - returnId = ObjectId.forReturn(taskId, 0x01020304); - Assert.assertEquals("123456789abcde123456789abcde00c004030201", returnId.toString()); - } - - @Test - public void testComputePutId() { - // Mock a taskId, the lowest 4 bytes should be 0. - TaskId taskId = TaskId.fromHexString("123456789ABCDE123456789ABCDE"); - - ObjectId putId = ObjectId.forPut(taskId, 1); - Assert.assertEquals("123456789abcde123456789abcde008001000000".toLowerCase(), putId.toString()); - - putId = ObjectId.forPut(taskId, 0x01020304); - Assert.assertEquals("123456789abcde123456789abcde008004030201".toLowerCase(), putId.toString()); - } - @Test void testMurmurHash() { UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); Assert.assertEquals(remainder, 787616861); } - } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 78a659a2d..09b8ee61b 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -64,16 +64,6 @@ class TaskArg { const std::shared_ptr data_; }; -/// Information of a task -struct TaskInfo { - /// The ID of task. - const TaskID task_id; - /// The job ID. - const JobID job_id; - /// The type of task. - const TaskType task_type; -}; - enum class StoreProviderType { LOCAL_PLASMA, PLASMA, MEMORY }; enum class TaskTransportType { RAYLET, DIRECT_ACTOR }; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 331cae850..896f8a70e 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -6,7 +6,10 @@ namespace ray { /// per-thread context for core worker. struct WorkerThreadContext { WorkerThreadContext() - : current_task_id_(TaskID::ForFakeTask()), task_index_(0), put_index_(0) {} + : current_task_id_(TaskID::ForFakeTask()), + current_actor_id_(ActorID::Nil()), + task_index_(0), + put_index_(0) {} int GetNextTaskIndex() { return ++task_index_; } @@ -18,6 +21,8 @@ struct WorkerThreadContext { return current_task_; } + const ActorID &GetCurrentActorID() const { return current_actor_id_; } + void SetCurrentTaskId(const TaskID &task_id) { current_task_id_ = task_id; task_index_ = 0; @@ -27,12 +32,22 @@ struct WorkerThreadContext { void SetCurrentTask(const TaskSpecification &task_spec) { SetCurrentTaskId(task_spec.TaskId()); current_task_ = std::make_shared(task_spec); + if (task_spec.IsActorCreationTask()) { + RAY_CHECK(current_actor_id_.IsNil()); + current_actor_id_ = task_spec.ActorCreationId(); + } + if (task_spec.IsActorTask()) { + RAY_CHECK(current_actor_id_ == task_spec.ActorId()); + } } private: /// The task ID for current task. TaskID current_task_id_; + /// ID of current actor. + ActorID current_actor_id_; + /// The current task. std::shared_ptr current_task_; @@ -81,6 +96,10 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { return GetThreadContext().GetCurrentTask(); } +const ActorID &WorkerContext::GetCurrentActorID() const { + return GetThreadContext().GetCurrentActorID(); +} + WorkerThreadContext &WorkerContext::GetThreadContext() { if (thread_context_ == nullptr) { thread_context_ = std::unique_ptr(new WorkerThreadContext()); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 8405501d3..77e9e2814 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -24,6 +24,8 @@ class WorkerContext { std::shared_ptr GetCurrentTask() const; + const ActorID &GetCurrentActorID() const; + int GetNextTaskIndex(); int GetNextPutIndex(); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 71df4451b..fb01789c7 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -47,6 +47,12 @@ CoreWorker::~CoreWorker() { gcs_client_->Disconnect(); io_service_.stop(); io_thread_.join(); + if (task_execution_interface_) { + task_execution_interface_->Stop(); + } + if (raylet_client_) { + RAY_IGNORE_EXPR(raylet_client_->Disconnect()); + } } void CoreWorker::StartIOService() { io_service_.run(); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index a5baa76d0..218881002 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -38,6 +38,10 @@ class CoreWorker { /// Language of this worker. Language GetLanguage() const { return language_; } + WorkerContext &GetWorkerContext() { return worker_context_; } + + RayletClient &GetRayletClient() { return *raylet_client_; } + /// Return the `CoreWorkerTaskInterface` that contains the methods related to task /// submisson. CoreWorkerTaskInterface &Tasks() { return *task_interface_; } diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 6c66f8f2f..a7bd918ac 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -3,6 +3,9 @@ jclass java_boolean_class; jmethodID java_boolean_init; +jclass java_double_class; +jmethodID java_double_double_value; + jclass java_list_class; jmethodID java_list_size; jmethodID java_list_get; @@ -12,18 +15,62 @@ jclass java_array_list_class; jmethodID java_array_list_init; jmethodID java_array_list_init_with_capacity; +jclass java_map_class; +jmethodID java_map_entry_set; + +jclass java_set_class; +jmethodID java_set_iterator; + +jclass java_iterator_class; +jmethodID java_iterator_has_next; +jmethodID java_iterator_next; + +jclass java_map_entry_class; +jmethodID java_map_entry_get_key; +jmethodID java_map_entry_get_value; + jclass java_ray_exception_class; +jclass java_base_id_class; +jmethodID java_base_id_get_bytes; + +jclass java_function_descriptor_class; +jmethodID java_function_descriptor_get_language; +jmethodID java_function_descriptor_to_list; + +jclass java_language_class; +jmethodID java_language_get_number; + +jclass java_function_arg_class; +jfieldID java_function_arg_id; +jfieldID java_function_arg_data; + +jclass java_base_task_options_class; +jfieldID java_base_task_options_resources; + +jclass java_actor_creation_options_class; +jfieldID java_actor_creation_options_max_reconstructions; +jfieldID java_actor_creation_options_jvm_options; + +jclass java_gcs_client_options_class; +jfieldID java_gcs_client_options_ip; +jfieldID java_gcs_client_options_port; +jfieldID java_gcs_client_options_password; + jclass java_native_ray_object_class; jmethodID java_native_ray_object_init; jfieldID java_native_ray_object_data; jfieldID java_native_ray_object_metadata; -jint JNI_VERSION = JNI_VERSION_1_8; +jclass java_task_executor_class; +jmethodID java_task_executor_execute; + +JavaVM *jvm; inline jclass LoadClass(JNIEnv *env, const char *class_name) { jclass tempLocalClassRef = env->FindClass(class_name); jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef); + RAY_CHECK(ret) << "Can't load Java class " << class_name; env->DeleteLocalRef(tempLocalClassRef); return ret; } @@ -31,13 +78,18 @@ inline jclass LoadClass(JNIEnv *env, const char *class_name) { /// Load and cache frequently-used Java classes and methods jint JNI_OnLoad(JavaVM *vm, void *reserved) { JNIEnv *env; - if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + if (vm->GetEnv(reinterpret_cast(&env), CURRENT_JNI_VERSION) != JNI_OK) { return JNI_ERR; } + jvm = vm; + java_boolean_class = LoadClass(env, "java/lang/Boolean"); java_boolean_init = env->GetMethodID(java_boolean_class, "", "(Z)V"); + java_double_class = LoadClass(env, "java/lang/Double"); + java_double_double_value = env->GetMethodID(java_double_class, "doubleValue", "()D"); + java_list_class = LoadClass(env, "java/util/List"); java_list_size = env->GetMethodID(java_list_class, "size", "()I"); java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); @@ -48,10 +100,65 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_array_list_init_with_capacity = env->GetMethodID(java_array_list_class, "", "(I)V"); + java_map_class = LoadClass(env, "java/util/Map"); + java_map_entry_set = env->GetMethodID(java_map_class, "entrySet", "()Ljava/util/Set;"); + + java_set_class = LoadClass(env, "java/util/Set"); + java_set_iterator = + env->GetMethodID(java_set_class, "iterator", "()Ljava/util/Iterator;"); + + java_iterator_class = LoadClass(env, "java/util/Iterator"); + java_iterator_has_next = env->GetMethodID(java_iterator_class, "hasNext", "()Z"); + java_iterator_next = + env->GetMethodID(java_iterator_class, "next", "()Ljava/lang/Object;"); + + java_map_entry_class = LoadClass(env, "java/util/Map$Entry"); + java_map_entry_get_key = + env->GetMethodID(java_map_entry_class, "getKey", "()Ljava/lang/Object;"); + java_map_entry_get_value = + env->GetMethodID(java_map_entry_class, "getValue", "()Ljava/lang/Object;"); + java_ray_exception_class = LoadClass(env, "org/ray/api/exception/RayException"); - java_native_ray_object_class = - LoadClass(env, "org/ray/runtime/objectstore/NativeRayObject"); + java_base_id_class = LoadClass(env, "org/ray/api/id/BaseId"); + java_base_id_get_bytes = env->GetMethodID(java_base_id_class, "getBytes", "()[B"); + + java_function_descriptor_class = + LoadClass(env, "org/ray/runtime/functionmanager/FunctionDescriptor"); + java_function_descriptor_get_language = + env->GetMethodID(java_function_descriptor_class, "getLanguage", + "()Lorg/ray/runtime/generated/Common$Language;"); + java_function_descriptor_to_list = + env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;"); + + java_language_class = LoadClass(env, "org/ray/runtime/generated/Common$Language"); + java_language_get_number = env->GetMethodID(java_language_class, "getNumber", "()I"); + + java_function_arg_class = LoadClass(env, "org/ray/runtime/task/FunctionArg"); + java_function_arg_id = + env->GetFieldID(java_function_arg_class, "id", "Lorg/ray/api/id/ObjectId;"); + java_function_arg_data = env->GetFieldID(java_function_arg_class, "data", "[B"); + + java_base_task_options_class = LoadClass(env, "org/ray/api/options/BaseTaskOptions"); + java_base_task_options_resources = + env->GetFieldID(java_base_task_options_class, "resources", "Ljava/util/Map;"); + + java_actor_creation_options_class = + LoadClass(env, "org/ray/api/options/ActorCreationOptions"); + java_actor_creation_options_max_reconstructions = + env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I"); + java_actor_creation_options_jvm_options = env->GetFieldID( + java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); + + java_gcs_client_options_class = LoadClass(env, "org/ray/runtime/gcs/GcsClientOptions"); + java_gcs_client_options_ip = + env->GetFieldID(java_gcs_client_options_class, "ip", "Ljava/lang/String;"); + java_gcs_client_options_port = + env->GetFieldID(java_gcs_client_options_class, "port", "I"); + java_gcs_client_options_password = + env->GetFieldID(java_gcs_client_options_class, "password", "Ljava/lang/String;"); + + java_native_ray_object_class = LoadClass(env, "org/ray/runtime/object/NativeRayObject"); java_native_ray_object_init = env->GetMethodID(java_native_ray_object_class, "", "([B[B)V"); java_native_ray_object_data = @@ -59,17 +166,34 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_native_ray_object_metadata = env->GetFieldID(java_native_ray_object_class, "metadata", "[B"); - return JNI_VERSION; + java_task_executor_class = LoadClass(env, "org/ray/runtime/task/TaskExecutor"); + java_task_executor_execute = + env->GetMethodID(java_task_executor_class, "execute", + "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); + + return CURRENT_JNI_VERSION; } /// Unload java classes void JNI_OnUnload(JavaVM *vm, void *reserved) { JNIEnv *env; - vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + vm->GetEnv(reinterpret_cast(&env), CURRENT_JNI_VERSION); env->DeleteGlobalRef(java_boolean_class); + env->DeleteGlobalRef(java_double_class); env->DeleteGlobalRef(java_list_class); env->DeleteGlobalRef(java_array_list_class); + env->DeleteGlobalRef(java_map_class); + env->DeleteGlobalRef(java_set_class); + env->DeleteGlobalRef(java_iterator_class); + env->DeleteGlobalRef(java_map_entry_class); env->DeleteGlobalRef(java_ray_exception_class); + env->DeleteGlobalRef(java_base_id_class); + env->DeleteGlobalRef(java_function_descriptor_class); + env->DeleteGlobalRef(java_language_class); + env->DeleteGlobalRef(java_function_arg_class); + env->DeleteGlobalRef(java_base_task_options_class); + env->DeleteGlobalRef(java_actor_creation_options_class); env->DeleteGlobalRef(java_native_ray_object_class); + env->DeleteGlobalRef(java_task_executor_class); } diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index d0f4ca8a5..396b5a841 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -1,5 +1,5 @@ -#ifndef RAY_COMMON_JAVA_JNI_HELPER_H -#define RAY_COMMON_JAVA_JNI_HELPER_H +#ifndef RAY_COMMON_JAVA_JNI_UTILS_H +#define RAY_COMMON_JAVA_JNI_UTILS_H #include #include "ray/common/buffer.h" @@ -12,6 +12,11 @@ extern jclass java_boolean_class; /// Constructor of Boolean class extern jmethodID java_boolean_init; +/// Double class +extern jclass java_double_class; +/// doubleValue method of Double class +extern jmethodID java_double_double_value; + /// List class extern jclass java_list_class; /// size method of List class @@ -28,9 +33,78 @@ extern jmethodID java_array_list_init; /// Constructor of ArrayList class with single parameter capacity extern jmethodID java_array_list_init_with_capacity; +/// Map interface +extern jclass java_map_class; +/// entrySet method of Map interface +extern jmethodID java_map_entry_set; + +/// Set interface +extern jclass java_set_class; +/// iterator method of Set interface +extern jmethodID java_set_iterator; + +/// Iterator interface +extern jclass java_iterator_class; +/// hasNext method of Iterator interface +extern jmethodID java_iterator_has_next; +/// next method of Iterator interface +extern jmethodID java_iterator_next; + +/// Map.Entry interface +extern jclass java_map_entry_class; +/// getKey method of Map.Entry interface +extern jmethodID java_map_entry_get_key; +/// getValue method of Map.Entry interface +extern jmethodID java_map_entry_get_value; + /// RayException class extern jclass java_ray_exception_class; +/// BaseId class +extern jclass java_base_id_class; +/// getBytes method of BaseId class +extern jmethodID java_base_id_get_bytes; + +/// FunctionDescriptor interface +extern jclass java_function_descriptor_class; +/// getLanguage method of FunctionDescriptor interface +extern jmethodID java_function_descriptor_get_language; +/// toList method of FunctionDescriptor interface +extern jmethodID java_function_descriptor_to_list; + +/// Language class +extern jclass java_language_class; +/// getNumber of Language class +extern jmethodID java_language_get_number; + +/// NativeTaskArg class +extern jclass java_function_arg_class; +/// id field of NativeTaskArg class +extern jfieldID java_function_arg_id; +/// data field of NativeTaskArg class +extern jfieldID java_function_arg_data; + +/// BaseTaskOptions class +extern jclass java_base_task_options_class; +/// resources field of BaseTaskOptions class +extern jfieldID java_base_task_options_resources; + +/// ActorCreationOptions class +extern jclass java_actor_creation_options_class; +/// maxReconstructions field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_max_reconstructions; +/// jvmOptions field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_jvm_options; + +/// GcsClientOptions class +extern jclass java_gcs_client_options_class; +/// ip field of GcsClientOptions class +extern jfieldID java_gcs_client_options_ip; +/// port field of GcsClientOptions class +extern jfieldID java_gcs_client_options_port; +/// password field of GcsClientOptions class +extern jfieldID java_gcs_client_options_password; + /// NativeRayObject class extern jclass java_native_ray_object_class; /// Constructor of NativeRayObject class @@ -40,6 +114,15 @@ extern jfieldID java_native_ray_object_data; /// metadata field of NativeRayObject class extern jfieldID java_native_ray_object_metadata; +/// TaskExecutor class +extern jclass java_task_executor_class; +/// execute method of TaskExecutor class +extern jmethodID java_task_executor_execute; + +#define CURRENT_JNI_VERSION JNI_VERSION_1_8 + +extern JavaVM *jvm; + /// Throws a Java RayException if the status is not OK. #define THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, ret) \ { \ @@ -49,6 +132,32 @@ extern jfieldID java_native_ray_object_metadata; } \ } +/// Represents a byte buffer of Java byte array. +/// The destructor will automatically call ReleaseByteArrayElements. +/// NOTE: Instances of this class cannot be used across threads. +class JavaByteArrayBuffer : public ray::Buffer { + public: + JavaByteArrayBuffer(JNIEnv *env, jbyteArray java_byte_array) + : env_(env), java_byte_array_(java_byte_array) { + native_bytes_ = env_->GetByteArrayElements(java_byte_array_, nullptr); + } + + uint8_t *Data() const override { return reinterpret_cast(native_bytes_); } + + size_t Size() const override { return env_->GetArrayLength(java_byte_array_); } + + bool OwnsData() const override { return true; } + + ~JavaByteArrayBuffer() { + env_->ReleaseByteArrayElements(java_byte_array_, native_bytes_, JNI_ABORT); + } + + private: + JNIEnv *env_; + jbyteArray java_byte_array_; + jbyte *native_bytes_; +}; + /// Convert a Java byte array to a C++ UniqueID. template inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { @@ -95,6 +204,15 @@ inline void JavaListToNativeVector( } } +/// Convert a Java List to C++ std::vector. +inline void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, + std::vector *native_vector) { + JavaListToNativeVector( + env, java_list, native_vector, [](JNIEnv *env, jobject jstr) { + return JavaStringToNativeString(env, static_cast(jstr)); + }); +} + /// Convert a C++ std::vector to a Java List. template inline jobject NativeVectorToJavaList( @@ -109,6 +227,22 @@ inline jobject NativeVectorToJavaList( return java_list; } +/// Convert a C++ std::vector to a Java List +inline jobject NativeStringVectorToJavaStringList( + JNIEnv *env, const std::vector &native_vector) { + return NativeVectorToJavaList( + env, native_vector, + [](JNIEnv *env, const std::string &str) { return env->NewStringUTF(str.c_str()); }); +} + +template +inline jobject NativeIdVectorToJavaByteArrayList(JNIEnv *env, + const std::vector &native_vector) { + return NativeVectorToJavaList(env, native_vector, [](JNIEnv *env, const ID &id) { + return IdToJavaByteArray(env, id); + }); +} + /// Convert a C++ ray::Buffer to a Java byte array. inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, const std::shared_ptr buffer) { @@ -123,50 +257,40 @@ inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, return java_byte_array; } -/// A helper method to help access a Java NativeRayObject instance and ensure memory -/// safety. -/// -/// \param[in] java_obj The Java NativeRayObject object. -/// \param[in] reader The callback function to access a C++ ray::RayObject instance. -/// \return The return value of callback function. -template -inline ReturnT ReadJavaNativeRayObject( - JNIEnv *env, const jobject &java_obj, - std::function &)> reader) { +/// Convert a Java byte[] as a C++ std::shared_ptr. +inline std::shared_ptr JavaByteArrayToNativeBuffer( + JNIEnv *env, const jbyteArray &javaByteArray) { + if (!javaByteArray) { + return nullptr; + } + return std::make_shared(env, javaByteArray); +} + +/// Convert a Java NativeRayObject to a C++ ray::RayObject. +/// NOTE: the returned ray::RayObject cannot be used across threads. +inline std::shared_ptr JavaNativeRayObjectToNativeRayObject( + JNIEnv *env, const jobject &java_obj) { if (!java_obj) { - return reader(nullptr); + return nullptr; } auto java_data = (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_data); auto java_metadata = (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_metadata); - auto data_size = env->GetArrayLength(java_data); - jbyte *data = data_size > 0 ? env->GetByteArrayElements(java_data, nullptr) : nullptr; - auto metadata_size = java_metadata ? env->GetArrayLength(java_metadata) : 0; - jbyte *metadata = - metadata_size > 0 ? env->GetByteArrayElements(java_metadata, nullptr) : nullptr; - auto data_buffer = std::make_shared( - reinterpret_cast(data), data_size); - auto metadata_buffer = java_metadata - ? std::make_shared( - reinterpret_cast(metadata), metadata_size) - : nullptr; - - auto native_obj = std::make_shared(data_buffer, metadata_buffer); - auto result = reader(native_obj); - - if (data) { - env->ReleaseByteArrayElements(java_data, data, JNI_ABORT); + std::shared_ptr data_buffer = JavaByteArrayToNativeBuffer(env, java_data); + std::shared_ptr metadata_buffer = + JavaByteArrayToNativeBuffer(env, java_metadata); + if (!data_buffer) { + data_buffer = std::make_shared(nullptr, 0); } - if (metadata) { - env->ReleaseByteArrayElements(java_metadata, metadata, JNI_ABORT); + if (!metadata_buffer) { + metadata_buffer = std::make_shared(nullptr, 0); } - - return result; + return std::make_shared(data_buffer, metadata_buffer); } /// Convert a C++ ray::RayObject to a Java NativeRayObject. -inline jobject ToJavaNativeRayObject(JNIEnv *env, - const std::shared_ptr &rayObject) { +inline jobject NativeRayObjectToJavaNativeRayObject( + JNIEnv *env, const std::shared_ptr &rayObject) { if (!rayObject) { return nullptr; } @@ -177,4 +301,4 @@ inline jobject ToJavaNativeRayObject(JNIEnv *env, return java_obj; } -#endif // RAY_COMMON_JAVA_JNI_HELPER_H +#endif // RAY_COMMON_JAVA_JNI_UTILS_H diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc new file mode 100644 index 000000000..c1c545e8b --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -0,0 +1,134 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h" +#include +#include +#include "ray/common/id.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" + +thread_local JNIEnv *local_env = nullptr; +thread_local jobject local_java_task_executor = nullptr; + +inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, + jobject gcs_client_options) { + std::string ip = JavaStringToNativeString( + env, (jstring)env->GetObjectField(gcs_client_options, java_gcs_client_options_ip)); + int port = env->GetIntField(gcs_client_options, java_gcs_client_options_port); + std::string password = JavaStringToNativeString( + env, + (jstring)env->GetObjectField(gcs_client_options, java_gcs_client_options_password)); + return ray::gcs::GcsClientOptions(ip, port, password, /*is_test_client=*/false); +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeInitCoreWorker + * Signature: + * (ILjava/lang/String;Ljava/lang/String;[BLorg/ray/runtime/gcs/GcsClientOptions;)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker( + JNIEnv *env, jclass, jint workerMode, jstring storeSocket, jstring rayletSocket, + jbyteArray jobId, jobject gcsClientOptions) { + auto native_store_socket = JavaStringToNativeString(env, storeSocket); + auto native_raylet_socket = JavaStringToNativeString(env, rayletSocket); + auto job_id = JavaByteArrayToId(env, jobId); + auto gcs_client_options = ToGcsClientOptions(env, gcsClientOptions); + + auto executor_func = [](const ray::RayFunction &ray_function, + const std::vector> &args, + int num_returns, + std::vector> *results) { + JNIEnv *env = local_env; + RAY_CHECK(env); + RAY_CHECK(local_java_task_executor); + // convert RayFunction + jobject ray_function_array_list = + NativeStringVectorToJavaStringList(env, ray_function.function_descriptor); + // convert args + // TODO (kfstorm): Avoid copying binary data from Java to C++ + jobject args_array_list = NativeVectorToJavaList>( + env, args, NativeRayObjectToJavaNativeRayObject); + + // invoke Java method + jobject java_return_objects = + env->CallObjectMethod(local_java_task_executor, java_task_executor_execute, + ray_function_array_list, args_array_list); + std::vector> return_objects; + JavaListToNativeVector>( + env, java_return_objects, &return_objects, + [](JNIEnv *env, jobject java_native_ray_object) { + return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object); + }); + for (auto &obj : return_objects) { + results->push_back(obj); + } + return ray::Status::OK(); + }; + + try { + auto core_worker = new ray::CoreWorker( + static_cast(workerMode), ::Language::JAVA, native_store_socket, + native_raylet_socket, job_id, gcs_client_options, executor_func); + return reinterpret_cast(core_worker); + } catch (const std::exception &e) { + std::ostringstream oss; + oss << "Failed to construct core worker: " << e.what(); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, ray::Status::Invalid(oss.str()), 0); + return 0; // To make compiler no complain + } +} + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeRunTaskExecutor + * Signature: (JLorg/ray/runtime/task/TaskExecutor;)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor( + JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jobject javaTaskExecutor) { + local_env = env; + local_java_task_executor = javaTaskExecutor; + auto core_worker = reinterpret_cast(nativeCoreWorkerPointer); + core_worker->Execution().Run(); + local_env = nullptr; + local_java_task_executor = nullptr; +} + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeDestroyCoreWorker + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker( + JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) { + delete reinterpret_cast(nativeCoreWorkerPointer); +} + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetup + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *env, + jclass, + jstring logDir) { + std::string log_dir = JavaStringToNativeString(env, logDir); + ray::RayLog::StartRayLog("java_worker", ray::RayLogLevel::INFO, log_dir); + // TODO (kfstorm): If we add InstallFailureSignalHandler here, Java test may crash. +} + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeShutdownHook + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, + jclass) { + ray::RayLog::ShutDownRayLog(); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h new file mode 100644 index 000000000..c71fec982 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h @@ -0,0 +1,54 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_RayNativeRuntime */ + +#ifndef _Included_org_ray_runtime_RayNativeRuntime +#define _Included_org_ray_runtime_RayNativeRuntime +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeInitCoreWorker + * Signature: + * (ILjava/lang/String;Ljava/lang/String;[BLorg/ray/runtime/gcs/GcsClientOptions;)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker( + JNIEnv *, jclass, jint, jstring, jstring, jbyteArray, jobject); + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeRunTaskExecutor + * Signature: (JLorg/ray/runtime/task/TaskExecutor;)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor( + JNIEnv *, jclass, jlong, jobject); + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeDestroyCoreWorker + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetup + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *, jclass, + jstring); + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeShutdownHook + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, + jclass); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc deleted file mode 100644 index 2c91dcdaa..000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc +++ /dev/null @@ -1,134 +0,0 @@ -#include "ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h" -#include -#include "ray/common/id.h" -#include "ray/core_worker/context.h" -#include "ray/core_worker/lib/java/jni_utils.h" - -inline ray::WorkerContext *GetWorkerContextFromPointer( - jlong nativeWorkerContextFromPointer) { - return reinterpret_cast(nativeWorkerContextFromPointer); -} - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeCreateWorkerContext - * Signature: (I[B)J - */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext( - JNIEnv *env, jclass, jint workerType, jbyteArray jobId) { - return reinterpret_cast( - new ray::WorkerContext(static_cast(workerType), - JavaByteArrayToId(env, jobId))); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentTaskId - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - auto task_id = - GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTaskID(); - return IdToJavaByteArray(env, task_id); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeSetCurrentTask - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer, jbyteArray taskSpec) { - jbyte *data = env->GetByteArrayElements(taskSpec, NULL); - jsize size = env->GetArrayLength(taskSpec); - ray::rpc::TaskSpec task_spec_message; - task_spec_message.ParseFromArray(data, size); - env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); - - ray::TaskSpecification spec(task_spec_message); - GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->SetCurrentTask(spec); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - auto spec = - GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTask(); - if (!spec) { - return nullptr; - } - - auto task_message = spec->Serialize(); - jbyteArray result = env->NewByteArray(task_message.size()); - env->SetByteArrayRegion( - result, 0, task_message.size(), - reinterpret_cast(const_cast(task_message.data()))); - return result; -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentJobId - * Signature: (J)Ljava/nio/ByteBuffer; - */ -JNIEXPORT jobject JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - const auto &job_id = - GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentJobID(); - return IdToJavaByteBuffer(env, job_id); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentWorkerId - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - auto worker_id = - GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetWorkerID(); - return IdToJavaByteArray(env, worker_id); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetNextTaskIndex - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextTaskIndex(); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetNextPutIndex - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextPutIndex(); -} - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy( - JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { - delete GetWorkerContextFromPointer(nativeWorkerContextFromPointer); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h deleted file mode 100644 index df9c60a56..000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h +++ /dev/null @@ -1,87 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_ray_runtime_WorkerContext */ - -#ifndef _Included_org_ray_runtime_WorkerContext -#define _Included_org_ray_runtime_WorkerContext -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeCreateWorkerContext - * Signature: (I[B)J - */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext( - JNIEnv *, jclass, jint, jbyteArray); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentTaskId - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeSetCurrentTask - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask( - JNIEnv *, jclass, jlong, jbyteArray); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentJobId - * Signature: (J)Ljava/nio/ByteBuffer; - */ -JNIEXPORT jobject JNICALL -Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetCurrentWorkerId - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetNextTaskIndex - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(JNIEnv *, - jclass, - jlong); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeGetNextPutIndex - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(JNIEnv *, - jclass, - jlong); - -/* - * Class: org_ray_runtime_WorkerContext - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(JNIEnv *, jclass, - jlong); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc new file mode 100644 index 000000000..a63e7efa0 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -0,0 +1,113 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/core_worker/task_interface.h" + +inline ray::ActorHandle &GetActorHandle(jlong nativeActorHandle) { + return *(reinterpret_cast(nativeActorHandle)); +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeFork + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeFork( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + auto new_actor_handle = GetActorHandle(nativeActorHandle).Fork(); + return reinterpret_cast(new ray::ActorHandle(new_actor_handle)); +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetActorId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorId( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + return IdToJavaByteArray(env, + GetActorHandle(nativeActorHandle).ActorID()); +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetActorHandleId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorHandleId( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + return IdToJavaByteArray( + env, GetActorHandle(nativeActorHandle).ActorHandleID()); +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetLanguage + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + return (jint)GetActorHandle(nativeActorHandle).ActorLanguage(); +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetActorCreationTaskFunctionDescriptor + * Signature: (J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + return NativeStringVectorToJavaStringList( + env, GetActorHandle(nativeActorHandle).ActorCreationTaskFunctionDescriptor()); +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeSerialize + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + std::string output; + GetActorHandle(nativeActorHandle).Serialize(&output); + jbyteArray bytes = env->NewByteArray(output.size()); + env->SetByteArrayRegion(bytes, 0, output.size(), + reinterpret_cast(output.c_str())); + return bytes; +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeDeserialize + * Signature: ([B)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize( + JNIEnv *env, jclass o, jbyteArray data) { + auto buffer = JavaByteArrayToNativeBuffer(env, data); + RAY_CHECK(buffer->Size() > 0); + auto binary = std::string(reinterpret_cast(buffer->Data()), buffer->Size()); + return reinterpret_cast( + new ray::ActorHandle(ray::ActorHandle::Deserialize(binary))); +} + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeFree + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeFree( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + delete &GetActorHandle(nativeActorHandle); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h new file mode 100644 index 000000000..4de114c7a --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h @@ -0,0 +1,80 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_actor_NativeRayActor */ + +#ifndef _Included_org_ray_runtime_actor_NativeRayActor +#define _Included_org_ray_runtime_actor_NativeRayActor +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeFork + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeFork(JNIEnv *, + jclass, + jlong); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetActorId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorId(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetActorHandleId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorHandleId(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetLanguage + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeGetActorCreationTaskFunctionDescriptor + * Signature: (J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor( + JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeSerialize + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeDeserialize + * Signature: ([B)J + */ +JNIEXPORT jlong JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(JNIEnv *, jclass, jbyteArray); + +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeFree + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeFree(JNIEnv *, + jclass, + jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc new file mode 100644 index 000000000..b7e791044 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc @@ -0,0 +1,83 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" + +inline ray::WorkerContext &GetWorkerContextFromPointer(jlong nativeCoreWorkerPointer) { + return reinterpret_cast(nativeCoreWorkerPointer)->GetWorkerContext(); +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentTaskType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + auto task_spec = GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTask(); + RAY_CHECK(task_spec) << "Current task is not set."; + return static_cast(task_spec->GetMessage().type()); +} + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentTaskId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + const ray::TaskID &task_id = + GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTaskID(); + return IdToJavaByteBuffer(env, task_id); +} + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentJobId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + const auto &job_id = + GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentJobID(); + return IdToJavaByteBuffer(env, job_id); +} + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentWorkerId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + const auto &worker_id = + GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetWorkerID(); + return IdToJavaByteBuffer(env, worker_id); +} + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentActorId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + const auto &actor_id = + GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentActorID(); + return IdToJavaByteBuffer(env, actor_id); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h new file mode 100644 index 000000000..fe3725484 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h @@ -0,0 +1,58 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_context_NativeWorkerContext */ + +#ifndef _Included_org_ray_runtime_context_NativeWorkerContext +#define _Included_org_ray_runtime_context_NativeWorkerContext +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentTaskType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *, + jclass, jlong); + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentTaskId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, + jlong); + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentJobId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, + jlong); + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentWorkerId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *, + jclass, jlong); + +/* + * Class: org_ray_runtime_context_NativeWorkerContext + * Method: nativeGetCurrentActorId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass, + jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc new file mode 100644 index 000000000..94d71223f --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc @@ -0,0 +1,119 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/core_worker/object_interface.h" + +inline ray::CoreWorkerObjectInterface &GetObjectInterfaceFromPointer( + jlong nativeCoreWorkerPointer) { + return reinterpret_cast(nativeCoreWorkerPointer)->Objects(); +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativePut + * Signature: (JLorg/ray/runtime/object/NativeRayObject;)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject obj) { + auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); + RAY_CHECK(ray_object != nullptr); + ray::ObjectID object_id; + auto status = + GetObjectInterfaceFromPointer(nativeCoreWorkerPointer).Put(*ray_object, &object_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return IdToJavaByteArray(env, object_id); +} + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativePut + * Signature: (J[BLorg/ray/runtime/object/NativeRayObject;)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray objectId, + jobject obj) { + auto object_id = JavaByteArrayToId(env, objectId); + auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); + RAY_CHECK(ray_object != nullptr); + auto status = + GetObjectInterfaceFromPointer(nativeCoreWorkerPointer).Put(*ray_object, object_id); + if (status.IsIOError() && + status.message() == "object already exists in the plasma store") { + // Ignore duplicated put on the same object ID. + return; + } + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativeGet + * Signature: (JLjava/util/List;J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject ids, jlong timeoutMs) { + std::vector object_ids; + JavaListToNativeVector( + env, ids, &object_ids, [](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); + std::vector> results; + auto status = GetObjectInterfaceFromPointer(nativeCoreWorkerPointer) + .Get(object_ids, (int64_t)timeoutMs, &results); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return NativeVectorToJavaList>( + env, results, NativeRayObjectToJavaNativeRayObject); +} + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativeWait + * Signature: (JLjava/util/List;IJ)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds, + jint numObjects, jlong timeoutMs) { + std::vector object_ids; + JavaListToNativeVector( + env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); + std::vector results; + auto status = GetObjectInterfaceFromPointer(nativeCoreWorkerPointer) + .Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { + return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item); + }); +} + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativeDelete + * Signature: (JLjava/util/List;ZZ)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds, + jboolean localOnly, jboolean deleteCreatingTasks) { + std::vector object_ids; + JavaListToNativeVector( + env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); + auto status = GetObjectInterfaceFromPointer(nativeCoreWorkerPointer) + .Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h new file mode 100644 index 000000000..c41d48434 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h @@ -0,0 +1,55 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_object_NativeObjectStore */ + +#ifndef _Included_org_ray_runtime_object_NativeObjectStore +#define _Included_org_ray_runtime_object_NativeObjectStore +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativePut + * Signature: (JLorg/ray/runtime/object/NativeRayObject;)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *, jclass, jlong, jobject); + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativePut + * Signature: (J[BLorg/ray/runtime/object/NativeRayObject;)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *, jclass, jlong, jbyteArray, jobject); + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativeGet + * Signature: (JLjava/util/List;J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet( + JNIEnv *, jclass, jlong, jobject, jlong); + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativeWait + * Signature: (JLjava/util/List;IJ)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait( + JNIEnv *, jclass, jlong, jobject, jint, jlong); + +/* + * Class: org_ray_runtime_object_NativeObjectStore + * Method: nativeDelete + * Signature: (JLjava/util/List;ZZ)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete( + JNIEnv *, jclass, jlong, jobject, jboolean, jboolean); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc deleted file mode 100644 index 5c77e7b07..000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc +++ /dev/null @@ -1,151 +0,0 @@ -#include "ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h" -#include -#include "ray/common/id.h" -#include "ray/core_worker/common.h" -#include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/core_worker/object_interface.h" - -using ray::rpc::RayletClient; - -inline ray::CoreWorkerObjectInterface *GetObjectInterfaceFromPointer( - jlong nativeObjectInterfacePointer) { - return reinterpret_cast(nativeObjectInterfacePointer); -} - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeCreateObjectInterface - * Signature: (JJLjava/lang/String;)J - */ -JNIEXPORT jlong JNICALL -Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface( - JNIEnv *env, jclass, jlong nativeWorkerContext, jlong nativeRayletClient, - jstring storeSocketName) { - return reinterpret_cast(new ray::CoreWorkerObjectInterface( - *reinterpret_cast(nativeWorkerContext), - *reinterpret_cast *>(nativeRayletClient), - JavaStringToNativeString(env, storeSocketName))); -} - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativePut - * Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2( - JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject obj) { - ray::Status status; - ray::ObjectID object_id = ReadJavaNativeRayObject( - env, obj, - [nativeObjectInterfacePointer, - &status](const std::shared_ptr &rayObject) { - RAY_CHECK(rayObject != nullptr); - ray::ObjectID object_id; - status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) - ->Put(*rayObject, &object_id); - return object_id; - }); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - return IdToJavaByteArray(env, object_id); -} - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativePut - * Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2( - JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jbyteArray objectId, - jobject obj) { - auto object_id = JavaByteArrayToId(env, objectId); - auto status = ReadJavaNativeRayObject( - env, obj, - [nativeObjectInterfacePointer, - &object_id](const std::shared_ptr &rayObject) { - RAY_CHECK(rayObject != nullptr); - return GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) - ->Put(*rayObject, object_id); - }); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeGet - * Signature: (JLjava/util/List;J)Ljava/util/List; - */ -JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet( - JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject ids, - jlong timeoutMs) { - std::vector object_ids; - JavaListToNativeVector( - env, ids, &object_ids, [](JNIEnv *env, jobject id) { - return JavaByteArrayToId(env, static_cast(id)); - }); - std::vector> results; - auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) - ->Get(object_ids, (int64_t)timeoutMs, &results); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - return NativeVectorToJavaList>(env, results, - ToJavaNativeRayObject); -} - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeWait - * Signature: (JLjava/util/List;IJ)Ljava/util/List; - */ -JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait( - JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds, - jint numObjects, jlong timeoutMs) { - std::vector object_ids; - JavaListToNativeVector( - env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { - return JavaByteArrayToId(env, static_cast(id)); - }); - std::vector results; - auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) - ->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { - return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item); - }); -} - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeDelete - * Signature: (JLjava/util/List;ZZ)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete( - JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds, - jboolean localOnly, jboolean deleteCreatingTasks) { - std::vector object_ids; - JavaListToNativeVector( - env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { - return JavaByteArrayToId(env, static_cast(id)); - }); - auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) - ->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy( - JNIEnv *env, jclass, jlong nativeObjectInterfacePointer) { - delete GetObjectInterfaceFromPointer(nativeObjectInterfacePointer); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h deleted file mode 100644 index 0ea41535e..000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h +++ /dev/null @@ -1,72 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_ray_runtime_objectstore_ObjectInterfaceImpl */ - -#ifndef _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl -#define _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeCreateObjectInterface - * Signature: (JJLjava/lang/String;)J - */ -JNIEXPORT jlong JNICALL -Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface( - JNIEnv *, jclass, jlong, jlong, jstring); - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativePut - * Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2( - JNIEnv *, jclass, jlong, jobject); - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativePut - * Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2( - JNIEnv *, jclass, jlong, jbyteArray, jobject); - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeGet - * Signature: (JLjava/util/List;J)Ljava/util/List; - */ -JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet( - JNIEnv *, jclass, jlong, jobject, jlong); - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeWait - * Signature: (JLjava/util/List;IJ)Ljava/util/List; - */ -JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait( - JNIEnv *, jclass, jlong, jobject, jint, jlong); - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeDelete - * Signature: (JLjava/util/List;ZZ)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete( - JNIEnv *, jclass, jlong, jobject, jboolean, jboolean); - -/* - * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy( - JNIEnv *, jclass, jlong); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc new file mode 100644 index 000000000..56f0e94ec --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc @@ -0,0 +1,74 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/rpc/raylet/raylet_client.h" + +inline ray::RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) { + return reinterpret_cast(nativeCoreWorkerPointer)->GetRayletClient(); +} + +#ifdef __cplusplus +extern "C" { +#endif + +using ray::ClientID; + +/* + * Class: org_ray_runtime_raylet_NativeRayletClient + * Method: nativePrepareCheckpoint + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId) { + const auto actor_id = JavaByteArrayToId(env, actorId); + ActorCheckpointID checkpoint_id; + auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) + .PrepareActorCheckpoint(actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + jbyteArray result = env->NewByteArray(checkpoint_id.Size()); + env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), + reinterpret_cast(checkpoint_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_NativeRayletClient + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId, + jbyteArray checkpointId) { + const auto actor_id = JavaByteArrayToId(env, actorId); + const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); + auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) + .NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_raylet_NativeRayletClient + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName, + jdouble capacity, jbyteArray nodeId) { + const auto node_id = JavaByteArrayToId(env, nodeId); + const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); + + auto status = + GetRayletClientFromPointer(nativeCoreWorkerPointer) + .SetResource(native_resource_name, static_cast(capacity), node_id); + env->ReleaseStringUTFChars(resourceName, native_resource_name); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h new file mode 100644 index 000000000..0b54300de --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h @@ -0,0 +1,39 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_raylet_NativeRayletClient */ + +#ifndef _Included_org_ray_runtime_raylet_NativeRayletClient +#define _Included_org_ray_runtime_raylet_NativeRayletClient +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_raylet_NativeRayletClient + * Method: nativePrepareCheckpoint + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint(JNIEnv *, jclass, + jlong, jbyteArray); + +/* + * Class: org_ray_runtime_raylet_NativeRayletClient + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); + +/* + * Class: org_ray_runtime_raylet_NativeRayletClient + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource( + JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc new file mode 100644 index 000000000..ca219ddae --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -0,0 +1,173 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/core_worker/task_interface.h" + +inline ray::CoreWorkerTaskInterface &GetTaskInterfaceFromPointer( + jlong nativeCoreWorkerPointer) { + return reinterpret_cast(nativeCoreWorkerPointer)->Tasks(); +} + +inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { + std::vector function_descriptor; + JavaStringListToNativeStringVector( + env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list), + &function_descriptor); + jobject java_language = + env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language); + int language = env->CallIntMethod(java_language, java_language_get_number); + ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor}; + return ray_function; +} + +inline std::vector ToTaskArgs(JNIEnv *env, jobject args) { + std::vector task_args; + JavaListToNativeVector( + env, args, &task_args, [](JNIEnv *env, jobject arg) { + auto java_id = env->GetObjectField(arg, java_function_arg_id); + if (java_id) { + auto java_id_bytes = static_cast( + env->CallObjectMethod(java_id, java_base_id_get_bytes)); + return ray::TaskArg::PassByReference( + JavaByteArrayToId(env, java_id_bytes)); + } + auto java_data = + static_cast(env->GetObjectField(arg, java_function_arg_data)); + RAY_CHECK(java_data) << "Both id and data of FunctionArg are null."; + return ray::TaskArg::PassByValue(JavaByteArrayToNativeBuffer(env, java_data)); + }); + return task_args; +} + +inline std::unordered_map ToResources(JNIEnv *env, + jobject java_resources) { + std::unordered_map resources; + if (java_resources) { + jobject entry_set = env->CallObjectMethod(java_resources, java_map_entry_set); + jobject iterator = env->CallObjectMethod(entry_set, java_set_iterator); + while (env->CallBooleanMethod(iterator, java_iterator_has_next)) { + jobject map_entry = env->CallObjectMethod(iterator, java_iterator_next); + std::string key = JavaStringToNativeString( + env, (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key)); + double value = env->CallDoubleMethod( + env->CallObjectMethod(map_entry, java_map_entry_get_value), + java_double_double_value); + resources.emplace(key, value); + } + } + return resources; +} + +inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptions) { + std::unordered_map resources; + if (callOptions) { + jobject java_resources = + env->GetObjectField(callOptions, java_base_task_options_resources); + resources = ToResources(env, java_resources); + } + + ray::TaskOptions task_options{numReturns, resources}; + return task_options; +} + +inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, + jobject actorCreationOptions) { + uint64_t max_reconstructions = 0; + std::unordered_map resources; + std::vector dynamic_worker_options; + if (actorCreationOptions) { + max_reconstructions = static_cast(env->GetIntField( + actorCreationOptions, java_actor_creation_options_max_reconstructions)); + jobject java_resources = + env->GetObjectField(actorCreationOptions, java_base_task_options_resources); + resources = ToResources(env, java_resources); + std::string jvm_options = JavaStringToNativeString( + env, (jstring)env->GetObjectField(actorCreationOptions, + java_actor_creation_options_jvm_options)); + dynamic_worker_options.emplace_back(jvm_options); + } + + ray::ActorCreationOptions action_creation_options{ + static_cast(max_reconstructions), false, resources, + dynamic_worker_options}; + return action_creation_options; +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_task_NativeTaskSubmitter + * Method: nativeSubmitTask + * Signature: + * (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( + JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor, + jobject args, jint numReturns, jobject callOptions) { + auto ray_function = ToRayFunction(env, functionDescriptor); + auto task_args = ToTaskArgs(env, args); + auto task_options = ToTaskOptions(env, numReturns, callOptions); + + std::vector return_ids; + auto status = GetTaskInterfaceFromPointer(nativeCoreWorkerPointer) + .SubmitTask(ray_function, task_args, task_options, &return_ids); + + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + + return NativeIdVectorToJavaByteArrayList(env, return_ids); +} + +/* + * Class: org_ray_runtime_task_NativeTaskSubmitter + * Method: nativeCreateActor + * Signature: + * (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( + JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor, + jobject args, jobject actorCreationOptions) { + auto ray_function = ToRayFunction(env, functionDescriptor); + auto task_args = ToTaskArgs(env, args); + auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions); + + std::unique_ptr actor_handle; + auto status = + GetTaskInterfaceFromPointer(nativeCoreWorkerPointer) + .CreateActor(ray_function, task_args, actor_creation_options, &actor_handle); + + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, 0); + return reinterpret_cast(actor_handle.release()); +} + +/* + * Class: org_ray_runtime_task_NativeTaskSubmitter + * Method: nativeSubmitActorTask + * Signature: + * (JJLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( + JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jlong nativeActorHandle, + jobject functionDescriptor, jobject args, jint numReturns, jobject callOptions) { + auto &actor_handle = *(reinterpret_cast(nativeActorHandle)); + auto ray_function = ToRayFunction(env, functionDescriptor); + auto task_args = ToTaskArgs(env, args); + auto task_options = ToTaskOptions(env, numReturns, callOptions); + + std::vector return_ids; + auto status = GetTaskInterfaceFromPointer(nativeCoreWorkerPointer) + .SubmitActorTask(actor_handle, ray_function, task_args, task_options, + &return_ids); + + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return NativeIdVectorToJavaByteArrayList(env, return_ids); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h new file mode 100644 index 000000000..908b5eee4 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h @@ -0,0 +1,43 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_task_NativeTaskSubmitter */ + +#ifndef _Included_org_ray_runtime_task_NativeTaskSubmitter +#define _Included_org_ray_runtime_task_NativeTaskSubmitter +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_task_NativeTaskSubmitter + * Method: nativeSubmitTask + * Signature: + * (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( + JNIEnv *, jclass, jlong, jobject, jobject, jint, jobject); + +/* + * Class: org_ray_runtime_task_NativeTaskSubmitter + * Method: nativeCreateActor + * Signature: + * (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( + JNIEnv *, jclass, jlong, jobject, jobject, jobject); + +/* + * Class: org_ray_runtime_task_NativeTaskSubmitter + * Method: nativeSubmitActorTask + * Signature: + * (JJLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass, + jlong, jlong, jobject, + jobject, jint, + jobject); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index c143c98a7..f397ab314 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -13,7 +13,8 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( object_interface_(object_interface), execution_callback_(executor), worker_server_("Worker", 0 /* let grpc choose port */), - main_work_(main_service_) { + main_service_(std::make_shared()), + main_work_(*main_service_) { RAY_CHECK(execution_callback_ != nullptr); auto func = std::bind(&CoreWorkerTaskExecutionInterface::ExecuteTask, this, @@ -21,11 +22,11 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( task_receivers_.emplace( TaskTransportType::RAYLET, std::unique_ptr(new CoreWorkerRayletTaskReceiver( - raylet_client, object_interface_, main_service_, worker_server_, func))); + raylet_client, object_interface_, *main_service_, worker_server_, func))); task_receivers_.emplace( TaskTransportType::DIRECT_ACTOR, std::unique_ptr( - new CoreWorkerDirectActorTaskReceiver(object_interface_, main_service_, + new CoreWorkerDirectActorTaskReceiver(object_interface_, *main_service_, worker_server_, func))); // Start RPC server after all the task receivers are properly initialized. @@ -35,6 +36,8 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( Status CoreWorkerTaskExecutionInterface::ExecuteTask( const TaskSpecification &task_spec, std::vector> *results) { + RAY_LOG(DEBUG) << "Executing task " << task_spec.TaskId(); + worker_context_.SetCurrentTask(task_spec); RayFunction func{task_spec.GetLanguage(), task_spec.FunctionDescriptor()}; @@ -42,17 +45,6 @@ Status CoreWorkerTaskExecutionInterface::ExecuteTask( std::vector> args; RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args)); - TaskType task_type; - if (task_spec.IsActorCreationTask()) { - task_type = TaskType::ACTOR_CREATION_TASK; - } else if (task_spec.IsActorTask()) { - task_type = TaskType::ACTOR_TASK; - } else { - task_type = TaskType::NORMAL_TASK; - } - - TaskInfo task_info{task_spec.TaskId(), task_spec.JobId(), task_type}; - auto num_returns = task_spec.NumReturns(); if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) { RAY_CHECK(num_returns > 0); @@ -60,7 +52,7 @@ Status CoreWorkerTaskExecutionInterface::ExecuteTask( num_returns--; } - auto status = execution_callback_(func, args, task_info, num_returns, results); + auto status = execution_callback_(func, args, num_returns, results); // TODO(zhijunfu): // 1. Check and handle failure. // 2. Save or load checkpoint. @@ -69,10 +61,15 @@ Status CoreWorkerTaskExecutionInterface::ExecuteTask( void CoreWorkerTaskExecutionInterface::Run() { // Run main IO service. - main_service_.run(); + main_service_->run(); +} - // should never reach here. - RAY_LOG(FATAL) << "should never reach here after running main io service"; +void CoreWorkerTaskExecutionInterface::Stop() { + // Stop main IO service. + std::shared_ptr main_service = main_service_; + // Delay the execution of io_service::stop() to avoid deadlock if + // CoreWorkerTaskExecutionInterface::Stop is called inside a task. + main_service_->post([main_service]() { main_service->stop(); }); } Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor( diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h index c8b164642..6e25680a1 100644 --- a/src/ray/core_worker/task_execution.h +++ b/src/ray/core_worker/task_execution.h @@ -27,23 +27,26 @@ class CoreWorkerTaskExecutionInterface { /// /// \param ray_function[in] Information about the function to execute. /// \param args[in] Arguments of the task. - /// \param task_info[in] Information of the task to execute. /// \param results[out] Results of the task execution. /// \return Status. using TaskExecutor = std::function> &args, const TaskInfo &task_info, - int num_returns, std::vector> *results)>; + const std::vector> &args, int num_returns, + std::vector> *results)>; CoreWorkerTaskExecutionInterface(WorkerContext &worker_context, std::unique_ptr &raylet_client, CoreWorkerObjectInterface &object_interface, const TaskExecutor &executor); - /// Start receving and executes tasks in a infinite loop. + /// Start receiving and executing tasks. /// \return void. void Run(); + /// Stop receiving and executing tasks. + /// \return void. + void Stop(); + private: /// Build arguments for task executor. This would loop through all the arguments /// in task spec, and for each of them that's passed by reference (ObjectID), @@ -80,7 +83,7 @@ class CoreWorkerTaskExecutionInterface { rpc::GrpcServer worker_server_; /// Event loop where tasks are processed. - boost::asio::io_service main_service_; + std::shared_ptr main_service_; /// The asio work to keep main_service_ alive. boost::asio::io_service::work main_work_; diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index 9820bf385..e8e347fd1 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -171,7 +171,7 @@ Status CoreWorkerTaskInterface::CreateActor( actor_creation_options.resources, actor_creation_options.resources, TaskTransportType::RAYLET, &return_ids); builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions, - {}); + actor_creation_options.dynamic_worker_options); *actor_handle = std::unique_ptr(new ActorHandle( actor_id, ActorHandleID::Nil(), function.language, diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index d017f8694..bca4011d0 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -37,10 +37,12 @@ struct TaskOptions { struct ActorCreationOptions { ActorCreationOptions() {} ActorCreationOptions(uint64_t max_reconstructions, bool is_direct_call, - const std::unordered_map &resources) + const std::unordered_map &resources, + const std::vector &dynamic_worker_options) : max_reconstructions(max_reconstructions), is_direct_call(is_direct_call), - resources(resources) {} + resources(resources), + dynamic_worker_options(dynamic_worker_options) {} /// Maximum number of times that the actor should be reconstructed when it dies /// unexpectedly. It must be non-negative. If it's 0, the actor won't be reconstructed. @@ -50,6 +52,9 @@ struct ActorCreationOptions { const bool is_direct_call = false; /// Resources required by the whole lifetime of this actor. const std::unordered_map resources; + /// The dynamic options used in the worker command when starting a worker process for + /// an actor creation task. + const std::vector dynamic_worker_options; }; /// A handle to an actor. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 7b18aac3a..d18fc0969 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -62,7 +62,7 @@ std::unique_ptr CreateActorHelper( std::vector args; args.emplace_back(TaskArg::PassByValue(buffer)); - ActorCreationOptions actor_options{max_reconstructions, is_direct_call, resources}; + ActorCreationOptions actor_options{max_reconstructions, is_direct_call, resources, {}}; // Create an actor. RAY_CHECK_OK(worker.Tasks().CreateActor(func, args, actor_options, &actor_handle)); @@ -586,7 +586,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { args.emplace_back(TaskArg::PassByValue(buffer)); std::unordered_map resources; - ActorCreationOptions actor_options{0, /* is_direct_call */ true, resources}; + ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}}; const auto job_id = NextJobId(); ActorHandle actor_handle(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), ActorHandleID::Nil(), function.language, true, @@ -647,7 +647,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { args.emplace_back(TaskArg::PassByValue(buffer)); std::unordered_map resources; - ActorCreationOptions actor_options{0, /* is_direct_call */ true, resources}; + ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}}; // Create an actor. RAY_CHECK_OK(driver.Tasks().CreateActor(func, args, actor_options, &actor_handle)); // wait for actor creation finish. diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index d4f7a7abd..1db55b3fc 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -24,7 +24,7 @@ class MockWorker { const gcs::GcsClientOptions &gcs_options) : worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, JobID::FromInt(1), gcs_options, - std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5)) {} + std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4)) {} void Run() { // Start executing tasks. @@ -33,8 +33,7 @@ class MockWorker { private: Status ExecuteTask(const RayFunction &ray_function, - const std::vector> &args, - const TaskInfo &task_info, int num_returns, + const std::vector> &args, int num_returns, std::vector> *results) { // Note that this doesn't include dummy object id. RAY_CHECK(num_returns >= 0); diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 4dbb70bee..c6bd9ab8e 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -29,6 +29,7 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( rpc::SendReplyCallback send_reply_callback) { const Task task(request.task()); const auto &task_spec = task.GetTaskSpecification(); + RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); std::vector> results; auto status = task_handler_(task_spec, &results); @@ -39,12 +40,23 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( num_returns--; } + RAY_LOG(DEBUG) << "Assigned task " << task_spec.TaskId() + << " finished execution. num_returns: " << num_returns; RAY_CHECK(results.size() == num_returns); for (size_t i = 0; i < num_returns; i++) { ObjectID id = ObjectID::ForTaskReturn( task_spec.TaskId(), /*index=*/i + 1, /*transport_type=*/static_cast(TaskTransportType::RAYLET)); - RAY_CHECK_OK(object_interface_.Put(*results[i], id)); + Status status = object_interface_.Put(*results[i], id); + if (!status.ok()) { + // TODO (kfstorm): RAY_LOG(FATAL) except the error is about the object to put + // already exists. + RAY_LOG(WARNING) << "Task " << task_spec.TaskId() << " failed to put object " << id + << " in store: " << status.message(); + } else { + RAY_LOG(DEBUG) << "Task " << task_spec.TaskId() << " put object " << id + << " in store."; + } } // Notify raylet that current task is done via a `TaskDone` message. This is to diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index ae6cb6088..e45395a58 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -188,6 +188,17 @@ Status AuthenticateRedis(redisAsyncContext *context, const std::string &password return Status::OK(); } +void RedisAsyncContextDisconnectCallback(const redisAsyncContext *context, int status) { + RAY_LOG(WARNING) << "Redis async context disconnected. Status: " << status; + reinterpret_cast(context->data) + ->AsyncDisconnectCallback(context, status); +} + +void SetDisconnectCallback(RedisContext *redis_context, redisAsyncContext *context) { + context->data = redis_context; + redisAsyncSetDisconnectCallback(context, RedisAsyncContextDisconnectCallback); +} + template Status ConnectWithRetries(const std::string &address, int port, const RedisConnectFunction &connect_function, @@ -216,6 +227,10 @@ Status ConnectWithRetries(const std::string &address, int port, Status RedisContext::Connect(const std::string &address, int port, bool sharding, const std::string &password = "") { + RAY_CHECK(!context_); + RAY_CHECK(!async_context_); + RAY_CHECK(!subscribe_context_); + RAY_CHECK_OK(ConnectWithRetries(address, port, redisConnect, &context_)); RAY_CHECK_OK(AuthenticateRedis(context_, password)); @@ -226,10 +241,12 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding // Connect to async context RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &async_context_)); + SetDisconnectCallback(this, async_context_); RAY_CHECK_OK(AuthenticateRedis(async_context_, password)); // Connect to subscribe context RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &subscribe_context_)); + SetDisconnectCallback(this, subscribe_context_); RAY_CHECK_OK(AuthenticateRedis(subscribe_context_, password)); return Status::OK(); @@ -245,6 +262,7 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) { } Status RedisContext::RunArgvAsync(const std::vector &args) { + RAY_CHECK(async_context_); // Build the arguments. std::vector argv; std::vector argc; @@ -268,6 +286,7 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, int64_t *out_callback_index) { RAY_CHECK(pubsub_channel != TablePubsub::NO_PUBLISH) << "Client requested subscribe on a table that does not support pubsub"; + RAY_CHECK(subscribe_context_); int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, true); RAY_CHECK(out_callback_index != nullptr); @@ -294,6 +313,15 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, return Status::OK(); } +void RedisContext::AsyncDisconnectCallback(const redisAsyncContext *context, int status) { + if (context == async_context_) { + async_context_ = nullptr; + } + if (context == subscribe_context_) { + subscribe_context_ = nullptr; + } +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 093aab245..0c81bc7f8 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -149,9 +149,25 @@ class RedisContext { /// \return Status. Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index); - redisContext *sync_context() { return context_; } - redisAsyncContext *async_context() { return async_context_; } - redisAsyncContext *subscribe_context() { return subscribe_context_; }; + + /// Called when an instance of redisAsyncContext is disconnected. + /// + /// \param context the redisAsyncContext instances + /// \param status The status code of disconnection + void AsyncDisconnectCallback(const redisAsyncContext *context, int status); + + redisContext *sync_context() { + RAY_CHECK(context_); + return context_; + } + redisAsyncContext *async_context() { + RAY_CHECK(async_context_); + return async_context_; + } + redisAsyncContext *subscribe_context() { + RAY_CHECK(subscribe_context_); + return subscribe_context_; + }; private: redisContext *context_; @@ -164,6 +180,7 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, const vo size_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { + RAY_CHECK(async_context_); int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { if (log_length >= 0) { diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 6c6b01b5e..46f740f40 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -95,7 +95,7 @@ message ActorCreationTaskSpec { // The max number of times this actor should be recontructed. // If this number of 0 or negative, the actor won't be reconstructed on failure. uint64 max_actor_reconstructions = 3; - // The dynamic options used in the worker command when starting the worker process for + // The dynamic options used in the worker command when starting a worker process for // an actor creation task. If the list isn't empty, the options will be used to replace // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the // worker command. diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc deleted file mode 100644 index 1cf9e1620..000000000 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ /dev/null @@ -1,296 +0,0 @@ -#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h" - -#include - -#include "ray/common/id.h" -#include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/rpc/raylet/raylet_client.h" -#include "ray/util/logging.h" - -#ifdef __cplusplus -extern "C" { -#endif - -using ray::ClientID; -using ray::WorkerID; -using ray::rpc::RayletClient; - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeInit - * Signature: (Ljava/lang/String;[BZ[B)J - */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( - JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, - jbyteArray jobId) { - const auto worker_id = JavaByteArrayToId(env, workerId); - const auto job_id = JavaByteArrayToId(env, jobId); - const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto raylet_client = new std::unique_ptr( - new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA)); - env->ReleaseStringUTFChars(sockName, nativeString); - return reinterpret_cast(raylet_client); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSubmitTask - * Signature: (J[BLjava/nio/ByteBuffer;II)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( - JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) { - auto &raylet_client = *reinterpret_cast *>(client); - - jbyte *data = env->GetByteArrayElements(taskSpec, NULL); - jsize size = env->GetArrayLength(taskSpec); - ray::rpc::TaskSpec task_spec_message; - task_spec_message.ParseFromArray(data, size); - env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); - - ray::TaskSpecification task_spec(task_spec_message); - auto status = raylet_client->SubmitTask(task_spec); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGetTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask( - JNIEnv *env, jclass, jlong client) { - auto &raylet_client = *reinterpret_cast *>(client); - - std::unique_ptr spec; - auto status = raylet_client->GetTask(&spec); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - - // Serialize the task spec and copy to Java byte array. - auto task_data = spec->Serialize(); - - jbyteArray result = env->NewByteArray(task_data.size()); - if (result == nullptr) { - return nullptr; /* out of memory error thrown */ - } - - env->SetByteArrayRegion(result, 0, task_data.size(), - reinterpret_cast(task_data.data())); - - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy( - JNIEnv *env, jclass, jlong client) { - auto raylet_client = reinterpret_cast *>(client); - auto status = (*raylet_client)->Disconnect(); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); - delete raylet_client; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeWaitObject - * Signature: (J[[BIIZ[B)[Z - */ -JNIEXPORT jbooleanArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns, - jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - const auto current_task_id = JavaByteArrayToId(env, currentTaskId); - - auto &raylet_client = *reinterpret_cast *>(client); - - // Invoke wait. - WaitResultPair result; - auto status = - raylet_client->Wait(object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), current_task_id, &result); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - - // Convert result to java object. - jboolean put_value = true; - jbooleanArray resultArray = env->NewBooleanArray(object_ids.size()); - for (uint i = 0; i < result.first.size(); ++i) { - for (uint j = 0; j < object_ids.size(); ++j) { - if (result.first[i] == object_ids[j]) { - env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); - break; - } - } - } - - put_value = false; - for (uint i = 0; i < result.second.size(); ++i) { - for (uint j = 0; j < object_ids.size(); ++j) { - if (result.second[i] == object_ids[j]) { - env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); - break; - } - } - } - return resultArray; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorCreationTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - - const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter); - const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id); - jbyteArray result = env->NewByteArray(actor_creation_task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(), - reinterpret_cast(actor_creation_task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorTaskId - * Signature: ([B[BI[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter, jbyteArray actorId) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - const auto actor_id = JavaByteArrayToId(env, actorId); - const TaskID actor_task_id = - ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id); - - jbyteArray result = env->NewByteArray(actor_task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, actor_task_id.Size(), - reinterpret_cast(actor_task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateNormalTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - const TaskID task_id = - ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter); - - jbyteArray result = env->NewByteArray(task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, task_id.Size(), - reinterpret_cast(task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFreePlasmaObjects - * Signature: (J[[BZZ)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly, - jboolean deleteCreatingTasks) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - auto &raylet_client = *reinterpret_cast *>(client); - auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativePrepareCheckpoint - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass, - jlong client, - jbyteArray actorId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto actor_id = JavaByteArrayToId(env, actorId); - ActorCheckpointID checkpoint_id; - auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - jbyteArray result = env->NewByteArray(checkpoint_id.Size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), - reinterpret_cast(checkpoint_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto actor_id = JavaByteArrayToId(env, actorId); - const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); - auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( - JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, - jbyteArray nodeId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto node_id = JavaByteArrayToId(env, nodeId); - const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - - auto status = raylet_client->SetResource(native_resource_name, - static_cast(capacity), node_id); - env->ReleaseStringUTFChars(resourceName, native_resource_name); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h deleted file mode 100644 index 8b8237e29..000000000 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ /dev/null @@ -1,121 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_ray_runtime_raylet_RayletClientImpl */ - -#ifndef _Included_org_ray_runtime_raylet_RayletClientImpl -#define _Included_org_ray_runtime_raylet_RayletClientImpl -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeInit - * Signature: (Ljava/lang/String;[BZ[B)J - */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( - JNIEnv *, jclass, jstring, jbyteArray, jboolean, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSubmitTask - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( - JNIEnv *, jclass, jlong, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGetTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeWaitObject - * Signature: (J[[BIIZ[B)[Z - */ -JNIEXPORT jbooleanArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(JNIEnv *, jclass, jlong, - jobjectArray, jint, jint, - jboolean, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorCreationTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( - JNIEnv *, jclass, jbyteArray, jbyteArray, jint); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorTaskId - * Signature: ([B[BI[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId(JNIEnv *, jclass, - jbyteArray, - jbyteArray, jint, - jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateNormalTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId(JNIEnv *, jclass, - jbyteArray, - jbyteArray, jint); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFreePlasmaObjects - * Signature: (J[[BZZ)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(JNIEnv *, jclass, - jlong, jobjectArray, - jboolean, jboolean); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativePrepareCheckpoint - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *, jclass, - jlong, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( - JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 4b2e21f77..82ce87f44 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -797,11 +797,12 @@ void NodeManager::HandleRegisterClientRequest( << ". Is worker: " << is_worker << ". Worker pid " << request.worker_pid(); + Status status; if (is_worker) { // Register the new worker. bool use_push_task = worker->UsePush(); - worker_pool_.RegisterWorker(worker_id, std::move(worker)); - if (use_push_task) { + status = worker_pool_.RegisterWorker(worker_id, std::move(worker)); + if (status.ok() && use_push_task) { // only call `HandleWorkerAvailable` when push mode is used. HandleWorkerAvailable(worker_id); } @@ -811,13 +812,15 @@ void NodeManager::HandleRegisterClientRequest( auto job_id = JobID::FromBinary(request.job_id()); worker->AssignTaskId(driver_task_id); worker->AssignJobId(job_id); - worker_pool_.RegisterDriver(worker_id, std::move(worker)); - local_queues_.AddDriverTaskId(driver_task_id); - RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( - job_id, /*is_dead=*/false, std::time(nullptr), - initial_config_.node_manager_address, request.worker_pid())); + status = worker_pool_.RegisterDriver(worker_id, std::move(worker)); + if (status.ok()) { + local_queues_.AddDriverTaskId(driver_task_id); + RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( + job_id, /*is_dead=*/false, std::time(nullptr), + initial_config_.node_manager_address, request.worker_pid())); + } } - send_reply_callback(Status::OK(), nullptr, nullptr); + send_reply_callback(status, nullptr, nullptr); } void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local, diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index d1d215db6..49624d4bd 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -141,6 +141,13 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) // and assigning new task will be done when raylet receives // `TaskDone` message. }); + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId() + << " to worker " << worker_id_; + } else { + RAY_LOG(DEBUG) << "Assigned task " << task.GetTaskSpecification().TaskId() + << " to worker " << worker_id_; + } } else { // Use pull mode. This corresponds to existing python/java workers that haven't been // migrated to core worker architecture. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 3d6e4cec8..4d49d31dc 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -166,30 +166,33 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_ar return 0; } -void WorkerPool::RegisterWorker(const WorkerID &worker_id, - const std::shared_ptr &worker) { +Status WorkerPool::RegisterWorker(const WorkerID &worker_id, + const std::shared_ptr &worker) { const auto pid = worker->Pid(); const auto port = worker->Port(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port; auto &state = GetStateForLanguage(worker->GetLanguage()); - state.registered_workers.emplace(worker_id, std::move(worker)); auto it = state.starting_worker_processes.find(pid); if (it == state.starting_worker_processes.end()) { RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid; - return; + return Status::Invalid("Unknown worker"); } it->second--; if (it->second == 0) { state.starting_worker_processes.erase(it); } + + state.registered_workers.emplace(worker_id, std::move(worker)); + return Status::OK(); } -void WorkerPool::RegisterDriver(const WorkerID &driver_id, - const std::shared_ptr &driver) { +Status WorkerPool::RegisterDriver(const WorkerID &driver_id, + const std::shared_ptr &driver) { RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); auto &state = GetStateForLanguage(driver->GetLanguage()); state.registered_drivers.emplace(driver_id, std::move(driver)); + return Status::OK(); } std::shared_ptr WorkerPool::GetRegisteredWorker(const WorkerID &worker_id) const { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 0d5d094a9..44d8c5714 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -51,13 +51,15 @@ class WorkerPool { /// pool after it becomes idle (e.g., requests a work assignment). /// /// \param The Worker to be registered. - void RegisterWorker(const WorkerID &worker_id, const std::shared_ptr &worker); + /// \return If the registration is successful. + Status RegisterWorker(const WorkerID &worker_id, const std::shared_ptr &worker); /// Register a new driver. /// Driver is a treated as a special worker, so use WorkerID as key here. /// /// \param The driver to be registered. - void RegisterDriver(const WorkerID &driver_id, const std::shared_ptr &worker); + /// \return If the registration is successful. + Status RegisterDriver(const WorkerID &driver_id, const std::shared_ptr &worker); /// Get the client connection's registered worker. /// diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 1cd7031b3..3a4c4ad14 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -123,7 +123,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1); // Check that we cannot lookup the worker before it's registered. ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr); - worker_pool_.RegisterWorker(worker_id, worker); + RAY_CHECK_OK(worker_pool_.RegisterWorker(worker_id, worker)); // Check that we can lookup the worker after it's registered. ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), worker); } diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 3c5ff7f4e..70a5543d1 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -138,14 +138,15 @@ void RayLog::StartRayLog(const std::string &app_name, RayLogLevel severity_thres log_dir_ = log_dir; #ifdef RAY_USE_GLOG google::InitGoogleLogging(app_name_.c_str()); - google::SetStderrLogging(GetMappedSeverity(RayLogLevel::ERROR)); - for (int i = static_cast(severity_threshold_); - i <= static_cast(RayLogLevel::FATAL); ++i) { - int level = GetMappedSeverity(static_cast(i)); - google::base::SetLogger(level, &stdout_logger_singleton); - } - // Enable log file if log_dir_ is not empty. - if (!log_dir_.empty()) { + if (log_dir_.empty()) { + google::SetStderrLogging(GetMappedSeverity(RayLogLevel::ERROR)); + for (int i = static_cast(severity_threshold_); + i <= static_cast(RayLogLevel::FATAL); ++i) { + int level = GetMappedSeverity(static_cast(i)); + google::base::SetLogger(level, &stdout_logger_singleton); + } + } else { + // Enable log file if log_dir_ is not empty. auto dir_ends_with_slash = log_dir_; if (log_dir_[log_dir_.length() - 1] != '/') { dir_ends_with_slash += "/"; @@ -161,11 +162,8 @@ void RayLog::StartRayLog(const std::string &app_name, RayLogLevel severity_thres } } google::SetLogFilenameExtension(app_name_without_path.c_str()); - for (int i = static_cast(severity_threshold_); - i <= static_cast(RayLogLevel::FATAL); ++i) { - int level = GetMappedSeverity(static_cast(i)); - google::SetLogDestination(level, dir_ends_with_slash.c_str()); - } + int level = GetMappedSeverity(severity_threshold_); + google::SetLogDestination(level, dir_ends_with_slash.c_str()); } #endif }