diff --git a/java/api/src/main/java/org/ray/api/RayCall.java b/java/api/src/main/java/org/ray/api/RayCall.java index 967830199..96735fa70 100644 --- a/java/api/src/main/java/org/ray/api/RayCall.java +++ b/java/api/src/main/java/org/ray/api/RayCall.java @@ -2,7 +2,6 @@ package org.ray.api; -import org.ray.api.function.RayFunc; import org.ray.api.function.RayFunc0; import org.ray.api.function.RayFunc1; import org.ray.api.function.RayFunc2; @@ -11,7 +10,6 @@ import org.ray.api.function.RayFunc4; import org.ray.api.function.RayFunc5; import org.ray.api.function.RayFunc6; import org.ray.api.options.ActorCreationOptions; -import org.ray.api.options.BaseTaskOptions; import org.ray.api.options.CallOptions; /** @@ -2312,4 +2310,143 @@ class RayCall { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; return Ray.internal().createActor(f, args, options); } + // =========================== + // Cross-language methods. + // =========================== + public static RayObject callPy(String moduleName, String functionName) { + Object[] args = new Object[]{}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, CallOptions options) { + Object[] args = new Object[]{}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0) { + Object[] args = new Object[]{obj0}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, CallOptions options) { + Object[] args = new Object[]{obj0}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1) { + Object[] args = new Object[]{obj0, obj1}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, CallOptions options) { + Object[] args = new Object[]{obj0, obj1}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2) { + Object[] args = new Object[]{obj0, obj1, obj2}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, CallOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, CallOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, CallOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; + return Ray.internal().callPy(moduleName, functionName, args, null); + } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, CallOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; + return Ray.internal().callPy(moduleName, functionName, args, options); + } + public static RayObject callPy(RayPyActor pyActor, String functionName) { + Object[] args = new Object[]{}; + return Ray.internal().callPy(pyActor, functionName, args); + } + public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0) { + Object[] args = new Object[]{obj0}; + return Ray.internal().callPy(pyActor, functionName, args); + } + public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1) { + Object[] args = new Object[]{obj0, obj1}; + return Ray.internal().callPy(pyActor, functionName, args); + } + public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2) { + Object[] args = new Object[]{obj0, obj1, obj2}; + return Ray.internal().callPy(pyActor, functionName, args); + } + public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3}; + return Ray.internal().callPy(pyActor, functionName, args); + } + public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; + return Ray.internal().callPy(pyActor, functionName, args); + } + public static RayPyActor createPyActor(String moduleName, String className) { + Object[] args = new Object[]{}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, ActorCreationOptions options) { + Object[] args = new Object[]{}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0) { + Object[] args = new Object[]{obj0}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, ActorCreationOptions options) { + Object[] args = new Object[]{obj0}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1) { + Object[] args = new Object[]{obj0, obj1}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, ActorCreationOptions options) { + Object[] args = new Object[]{obj0, obj1}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2) { + Object[] args = new Object[]{obj0, obj1, obj2}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, ActorCreationOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, ActorCreationOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, ActorCreationOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; + return Ray.internal().createPyActor(moduleName, className, args, null); + } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, ActorCreationOptions options) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; + return Ray.internal().createPyActor(moduleName, className, args, options); + } } diff --git a/java/api/src/main/java/org/ray/api/RayPyActor.java b/java/api/src/main/java/org/ray/api/RayPyActor.java new file mode 100644 index 000000000..4f32bc4f4 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/RayPyActor.java @@ -0,0 +1,18 @@ +package org.ray.api; + +/** + * Handle of a Python actor. + */ +public interface RayPyActor extends RayActor { + + /** + * @return Module name of the Python actor class. + */ + String getModuleName(); + + /** + * @return Name of the Python actor class. + */ + String getClassName(); +} + diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 905bf1f14..905958ddf 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -3,6 +3,7 @@ package org.ray.api.runtime; import java.util.List; import org.ray.api.RayActor; import org.ray.api.RayObject; +import org.ray.api.RayPyActor; import org.ray.api.RuntimeContext; import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; @@ -45,8 +46,8 @@ public interface RayRuntime { List get(List objectIds); /** - * Wait for a list of RayObjects to be locally available, - * until specified number of objects are ready, or specified timeout has passed. + * Wait for a list of RayObjects to be locally available, until specified number of objects are + * ready, or specified timeout has passed. * * @param waitList A list of RayObject to wait for. * @param numReturns The number of objects that should be returned. @@ -96,4 +97,37 @@ public interface RayRuntime { ActorCreationOptions options); RuntimeContext getRuntimeContext(); + + /** + * Invoke a remote Python function. + * + * @param moduleName Module name of the Python function. + * @param functionName Name of the Python function. + * @param args Arguments of the function. + * @param options The options for this call. + * @return The result object. + */ + RayObject callPy(String moduleName, String functionName, Object[] args, CallOptions options); + + /** + * Invoke a remote Python function on an actor. + * + * @param pyActor A handle to the actor. + * @param functionName Name of the actor method. + * @param args Arguments of the function. + * @return The result object. + */ + RayObject callPy(RayPyActor pyActor, String functionName, Object[] args); + + /** + * Create a Python actor on a remote node. + * + * @param moduleName Module name of the Python actor class. + * @param className Name of the Python actor class. + * @param args Arguments of the actor constructor. + * @param options The options for creating actor. + * @return A handle to the actor. + */ + RayPyActor createPyActor(String moduleName, String className, Object[] args, + ActorCreationOptions options); } diff --git a/java/pom.xml b/java/pom.xml index 9cd9f158a..1194c3926 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -235,7 +235,6 @@ true true warning - xml ${project.build.directory}/checkstyle-errors.xml false diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 2411b9267..4ae1ee606 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -1,5 +1,6 @@ package org.ray.runtime; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Arrays; @@ -10,6 +11,7 @@ import java.util.Map; import java.util.stream.Collectors; import org.ray.api.RayActor; import org.ray.api.RayObject; +import org.ray.api.RayPyActor; import org.ray.api.RuntimeContext; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; @@ -20,12 +22,14 @@ import org.ray.api.options.BaseTaskOptions; import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; import org.ray.runtime.config.RayConfig; +import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.FunctionManager; -import org.ray.runtime.functionmanager.RayFunction; +import org.ray.runtime.functionmanager.PyFunctionDescriptor; import org.ray.runtime.objectstore.ObjectStoreProxy; import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult; import org.ray.runtime.raylet.RayletClient; import org.ray.runtime.task.ArgumentsBuilder; +import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.UniqueIdUtil; @@ -69,7 +73,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { functionManager = new FunctionManager(rayConfig.driverResourcePath); worker = new Worker(this); workerContext = new WorkerContext(rayConfig.workerMode, - rayConfig.driverId, rayConfig.runMode); + rayConfig.driverId, rayConfig.runMode); runtimeContext = new RuntimeContextImpl(this); } @@ -229,7 +233,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public RayObject call(RayFunc func, Object[] args, CallOptions options) { - TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options); + TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, options); rayletClient.submitTask(spec); return new RayObjectImpl(spec.returnIds[0]); } @@ -242,7 +246,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { RayActorImpl actorImpl = (RayActorImpl) actor; TaskSpec spec; synchronized (actor) { - spec = createTaskSpec(func, actorImpl, args, false, null); + spec = createTaskSpec(func, null, actorImpl, args, false, null); spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor()); actorImpl.setTaskCursor(spec.returnIds[1]); actorImpl.clearNewActorHandles(); @@ -255,7 +259,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { @SuppressWarnings("unchecked") public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { - TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, + TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL, args, true, options); RayActorImpl actor = new RayActorImpl(spec.returnIds[0]); actor.increaseTaskCounter(); @@ -264,17 +268,71 @@ public abstract class AbstractRayRuntime implements RayRuntime { return (RayActor) actor; } + private void checkPyArguments(Object[] args) { + for (Object arg : args) { + Preconditions.checkArgument( + (arg instanceof RayPyActor) || (arg instanceof byte[]), + "Python argument can only be a RayPyActor or a byte array, not {}.", + arg.getClass().getName()); + } + } + + @Override + public RayObject callPy(String moduleName, String functionName, Object[] args, + CallOptions options) { + checkPyArguments(args); + PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName); + TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, options); + rayletClient.submitTask(spec); + return new RayObjectImpl(spec.returnIds[0]); + } + + @Override + public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) { + checkPyArguments(args); + PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(), + pyActor.getClassName(), functionName); + RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor; + TaskSpec spec; + synchronized (pyActor) { + spec = createTaskSpec(null, desc, actorImpl, args, false, null); + spec.getExecutionDependencies().add(actorImpl.getTaskCursor()); + actorImpl.setTaskCursor(spec.returnIds[1]); + actorImpl.clearNewActorHandles(); + } + rayletClient.submitTask(spec); + return new RayObjectImpl(spec.returnIds[0]); + } + + @Override + public RayPyActor createPyActor(String moduleName, String className, Object[] args, + ActorCreationOptions options) { + checkPyArguments(args); + PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__"); + TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, options); + RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className); + actor.increaseTaskCounter(); + actor.setTaskCursor(spec.returnIds[0]); + rayletClient.submitTask(spec); + return actor; + } + /** * Create the task specification. * * @param func The target remote function. + * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a + * Python task. * @param actor The actor handle. If the task is not an actor task, actor id must be NIL. * @param args The arguments for the remote function. * @param isActorCreationTask Whether this task is an actor creation task. * @return A TaskSpec object. */ - private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, + private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor, + RayActorImpl actor, Object[] args, boolean isActorCreationTask, BaseTaskOptions taskOptions) { + Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); + UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); int numReturns = actor.getId().isNil() ? 1 : 2; @@ -302,7 +360,16 @@ public abstract class AbstractRayRuntime implements RayRuntime { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; } - RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func); + TaskLanguage language; + FunctionDescriptor functionDescriptor; + if (func != null) { + language = TaskLanguage.JAVA; + functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func) + .getFunctionDescriptor(); + } else { + language = TaskLanguage.PYTHON; + functionDescriptor = pyFunctionDescriptor; + } return new TaskSpec( workerContext.getCurrentDriverId(), @@ -315,10 +382,11 @@ public abstract class AbstractRayRuntime implements RayRuntime { actor.getHandleId(), actor.increaseTaskCounter(), actor.getNewActorHandles().toArray(new UniqueId[0]), - ArgumentsBuilder.wrap(args), + ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON), returnIds, resources, - rayFunction.getFunctionDescriptor() + language, + functionDescriptor ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java index 2d86449c5..7899869ae 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java @@ -10,26 +10,32 @@ import org.ray.api.RayActor; import org.ray.api.id.UniqueId; import org.ray.runtime.util.Sha1Digestor; -public final class RayActorImpl implements RayActor, Externalizable { +public class RayActorImpl implements RayActor, Externalizable { public static final RayActorImpl NIL = new RayActorImpl(); - private UniqueId id; - private UniqueId handleId; + /** + * Id of this actor. + */ + protected UniqueId id; + /** + * Handle id of this actor. + */ + protected UniqueId handleId; /** * The number of tasks that have been invoked on this actor. */ - private int taskCounter; + protected int taskCounter; /** * The unique id of the last return of the last task. * It's used as a dependency for the next task. */ - private UniqueId taskCursor; + protected UniqueId taskCursor; /** * The number of times that this actor handle has been forked. * It's used to make sure ids of actor handles are unique. */ - private int numForks; + protected int numForks; /** * The new actor handles that were created from this handle @@ -37,7 +43,7 @@ public final class RayActorImpl implements RayActor, Externalizable { * used to garbage-collect dummy objects that are no longer * necessary in the backend. */ - private List newActorHandles; + protected List newActorHandles; public RayActorImpl() { this(UniqueId.NIL, UniqueId.NIL); diff --git a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java new file mode 100644 index 000000000..2938478d2 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java @@ -0,0 +1,69 @@ +package org.ray.runtime; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.ray.api.RayPyActor; +import org.ray.api.id.UniqueId; + +public class RayPyActorImpl extends RayActorImpl implements RayPyActor { + + public static final RayPyActorImpl NIL = new RayPyActorImpl(UniqueId.NIL, null, null); + + /** + * Module name of the Python actor class. + */ + private String moduleName; + + /** + * Name of the Python actor class. + */ + private String className; + + private RayPyActorImpl() {} + + public RayPyActorImpl(UniqueId id, String moduleName, String className) { + super(id); + this.moduleName = moduleName; + this.className = className; + } + + @Override + public String getModuleName() { + return moduleName; + } + + @Override + public String getClassName() { + return className; + } + + public RayPyActorImpl fork() { + RayPyActorImpl ret = new RayPyActorImpl(); + ret.id = this.id; + ret.taskCounter = 0; + ret.numForks = 0; + ret.taskCursor = this.taskCursor; + ret.moduleName = this.moduleName; + ret.className = this.className; + ret.handleId = this.computeNextActorHandleId(); + newActorHandles.add(ret.handleId); + return ret; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + super.writeExternal(out); + out.writeObject(this.moduleName); + out.writeObject(this.className); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + super.readExternal(in); + this.moduleName = (String) in.readObject(); + this.className = (String) in.readObject(); + } + +} + diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index ef319ea20..c8f5aaa34 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -85,7 +85,7 @@ public class Worker { try { // Get method RayFunction rayFunction = runtime.getFunctionManager() - .getFunction(spec.driverId, spec.functionDescriptor); + .getFunction(spec.driverId, spec.getJavaFunctionDescriptor()); // Set context runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java index 70be2f3e9..3d0b36b35 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java @@ -1,52 +1,11 @@ package org.ray.runtime.functionmanager; -import com.google.common.base.Objects; - /** - * Represents the function's metadata. + * Base interface of a Ray task's function descriptor. + * + * A function descriptor is a list of strings that can uniquely describe a function. It's used to + * load a function in workers. */ -public final class FunctionDescriptor { +public interface FunctionDescriptor { - /** - * Function's class name. - */ - public final String className; - /** - * Function's name. - */ - public final String name; - /** - * Function's type descriptor. - */ - public final String typeDescriptor; - - public FunctionDescriptor(String className, String name, String typeDescriptor) { - this.className = className; - this.name = name; - this.typeDescriptor = typeDescriptor; - } - - @Override - public String toString() { - return className + "." + name; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - FunctionDescriptor that = (FunctionDescriptor) o; - return Objects.equal(className, that.className) && - Objects.equal(name, that.name) && - Objects.equal(typeDescriptor, that.typeDescriptor); - } - - @Override - public int hashCode() { - return Objects.hashCode(className, name, typeDescriptor); - } } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java index d7698c22a..7c267d0ea 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -30,10 +30,10 @@ public class FunctionManager { static final String CONSTRUCTOR_NAME = ""; /** - * Cache from a RayFunc object to its corresponding FunctionDescriptor. Because + * Cache from a RayFunc object to its corresponding JavaFunctionDescriptor. Because * `LambdaUtils.getSerializedLambda` is expensive. */ - private static final ThreadLocal, FunctionDescriptor>> + private static final ThreadLocal, JavaFunctionDescriptor>> RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new); /** @@ -64,13 +64,13 @@ public class FunctionManager { * @return A RayFunction object. */ public RayFunction getFunction(UniqueId driverId, RayFunc func) { - FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); + JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); if (functionDescriptor == null) { SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func); final String className = serializedLambda.getImplClass().replace('/', '.'); final String methodName = serializedLambda.getImplMethodName(); final String typeDescriptor = serializedLambda.getImplMethodSignature(); - functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor); + functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor); RAY_FUNC_CACHE.get().put(func.getClass(),functionDescriptor); } return getFunction(driverId, functionDescriptor); @@ -83,7 +83,7 @@ public class FunctionManager { * @param functionDescriptor The function descriptor. * @return A RayFunction object. */ - public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) { + public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) { DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId); if (driverFunctionTable == null) { String resourcePath = driverResourcePath + "/" + driverId.toString() + "/"; @@ -122,7 +122,7 @@ public class FunctionManager { this.functions = new HashMap<>(); } - RayFunction getFunction(FunctionDescriptor descriptor) { + RayFunction getFunction(JavaFunctionDescriptor descriptor) { Map, RayFunction> classFunctions = functions.get(descriptor.className); if (classFunctions == null) { classFunctions = loadFunctionsForClass(descriptor.className); @@ -150,7 +150,7 @@ public class FunctionManager { e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e); final String typeDescriptor = type.getDescriptor(); RayFunction rayFunction = new RayFunction(e, classLoader, - new FunctionDescriptor(className, methodName, typeDescriptor)); + new JavaFunctionDescriptor(className, methodName, typeDescriptor)); map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction); } } catch (Exception e) { diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java new file mode 100644 index 000000000..aac416fa5 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java @@ -0,0 +1,52 @@ +package org.ray.runtime.functionmanager; + +import com.google.common.base.Objects; + +/** + * Represents metadata of Java function. + */ +public final class JavaFunctionDescriptor implements FunctionDescriptor { + + /** + * Function's class name. + */ + public final String className; + /** + * Function's name. + */ + public final String name; + /** + * Function's type descriptor. + */ + public final String typeDescriptor; + + public JavaFunctionDescriptor(String className, String name, String typeDescriptor) { + this.className = className; + this.name = name; + this.typeDescriptor = typeDescriptor; + } + + @Override + public String toString() { + return className + "." + name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JavaFunctionDescriptor that = (JavaFunctionDescriptor) o; + return Objects.equal(className, that.className) && + Objects.equal(name, that.name) && + Objects.equal(typeDescriptor, that.typeDescriptor); + } + + @Override + public int hashCode() { + return Objects.hashCode(className, name, typeDescriptor); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java new file mode 100644 index 000000000..1fe13f0fb --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java @@ -0,0 +1,25 @@ +package org.ray.runtime.functionmanager; + +/** + * Represents metadata of a Python function. + */ +public class PyFunctionDescriptor implements FunctionDescriptor { + + public String moduleName; + + public String className; + + public String functionName; + + public PyFunctionDescriptor(String moduleName, String className, String functionName) { + this.moduleName = moduleName; + this.className = className; + this.functionName = functionName; + } + + @Override + public String toString() { + return moduleName + "." + className + "." + functionName; + } +} + diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java index 2f39ec3dc..ac2f77e05 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java @@ -23,10 +23,10 @@ public class RayFunction { /** * Function's metadata. */ - public final FunctionDescriptor functionDescriptor; + public final JavaFunctionDescriptor functionDescriptor; public RayFunction(Executable executable, ClassLoader classLoader, - FunctionDescriptor functionDescriptor) { + JavaFunctionDescriptor functionDescriptor) { this.executable = executable; this.classLoader = classLoader; this.functionDescriptor = functionDescriptor; @@ -53,7 +53,7 @@ public class RayFunction { return (Method) executable; } - public FunctionDescriptor getFunctionDescriptor() { + public JavaFunctionDescriptor getFunctionDescriptor() { return functionDescriptor; } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index d1d9102f7..64b9e2b73 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -36,6 +36,8 @@ public class ObjectStoreProxy { private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); + private static final byte[] RAW_TYPE_META = "RAW".getBytes(); + private final AbstractRayRuntime runtime; private static ThreadLocal objectStore; @@ -83,9 +85,8 @@ public class ObjectStoreProxy { GetResult result; if (meta != null) { - // If meta is not null, deserialize the exception. - RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i)); - result = new GetResult<>(true, null, exception); + // If meta is not null, deserialize the object from meta. + result = deserializeFromMeta(meta, data, ids.get(i)); } else if (data != null) { // If data is not null, deserialize the Java object. Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader()); @@ -112,13 +113,16 @@ public class ObjectStoreProxy { return results; } - private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) { - if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { - return RayWorkerException.INSTANCE; + @SuppressWarnings("unchecked") + private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) { + if (Arrays.equals(meta, RAW_TYPE_META)) { + return (GetResult) new GetResult<>(true, data, null); + } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { + return new GetResult<>(true, null, RayWorkerException.INSTANCE); } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { - return RayActorException.INSTANCE; + return new GetResult<>(true, null, RayActorException.INSTANCE); } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { - return new UnreconstructableException(objectId); + return new GetResult<>(true, null, new UnreconstructableException(objectId)); } throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); } @@ -131,7 +135,13 @@ public class ObjectStoreProxy { */ public void put(UniqueId id, Object object) { try { - objectStore.get().put(id.getBytes(), Serializer.encode(object), null); + if (object instanceof byte[]) { + // If the object is a byte array, skip serializing it and use a special metadata to + // indicate it's raw binary. So that this object can also be read by Python. + objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META); + } else { + objectStore.get().put(id.getBytes(), Serializer.encode(object), null); + } } catch (DuplicateObjectException e) { LOGGER.warn(e.getMessage()); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 60eaf2d23..8fb93a4a4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -12,12 +12,13 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; -import org.ray.runtime.functionmanager.FunctionDescriptor; +import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.generated.Arg; import org.ray.runtime.generated.Language; import org.ray.runtime.generated.ResourcePair; import org.ray.runtime.generated.TaskInfo; import org.ray.runtime.task.FunctionArg; +import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; @@ -183,13 +184,14 @@ public class RayletClientImpl implements RayletClient { resources.put(info.requiredResources(i).key(), info.requiredResources(i).value()); } // Deserialize function descriptor + Preconditions.checkArgument(info.language() == Language.JAVA); Preconditions.checkArgument(info.functionDescriptorLength() == 3); - FunctionDescriptor functionDescriptor = new FunctionDescriptor( + JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, returnIds, resources, functionDescriptor); + args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -250,12 +252,29 @@ public class RayletClientImpl implements RayletClient { int requiredPlacementResourcesOffset = fbb.createVectorOfTables(requiredPlacementResourcesOffsets); - int[] functionDescriptorOffsets = new int[]{ - fbb.createString(task.functionDescriptor.className), - fbb.createString(task.functionDescriptor.name), - fbb.createString(task.functionDescriptor.typeDescriptor) - }; - int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); + int language; + int functionDescriptorOffset; + + if (task.language == TaskLanguage.JAVA) { + // This is a Java task. + language = Language.JAVA; + int[] functionDescriptorOffsets = new int[]{ + fbb.createString(task.getJavaFunctionDescriptor().className), + fbb.createString(task.getJavaFunctionDescriptor().name), + fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor) + }; + functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); + } else { + // This is a Python task. + language = Language.PYTHON; + int[] functionDescriptorOffsets = new int[]{ + fbb.createString(task.getPyFunctionDescriptor().moduleName), + fbb.createString(task.getPyFunctionDescriptor().className), + fbb.createString(task.getPyFunctionDescriptor().functionName), + fbb.createString("") + }; + functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); + } int root = TaskInfo.createTaskInfo( fbb, @@ -274,7 +293,7 @@ public class RayletClientImpl implements RayletClient { returnsOffset, requiredResourcesOffset, requiredPlacementResourcesOffset, - Language.JAVA, + language, functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 83714a6de..1da6dec31 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -12,16 +12,15 @@ import org.ray.runtime.util.Serializer; public class ArgumentsBuilder { /** - * If the the size of an argument's serialized data is smaller than this number, - * the argument will be passed by value. Otherwise it'll be passed by reference. + * If the the size of an argument's serialized data is smaller than this number, the argument will + * be passed by value. Otherwise it'll be passed by reference. */ private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024; - /** * Convert real function arguments to task spec arguments. */ - public static FunctionArg[] wrap(Object[] args) { + public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { FunctionArg[] ret = new FunctionArg[args.length]; for (int i = 0; i < ret.length; i++) { Object arg = args[i]; @@ -33,10 +32,15 @@ public class ArgumentsBuilder { data = Serializer.encode(arg); } else if (arg instanceof RayObject) { id = ((RayObject) arg).getId(); + } else if (arg instanceof byte[] && crossLanguage) { + // If the argument is a byte array and will be used by a different language, + // do not inline this argument. Because the other language doesn't know how + // to deserialize it. + id = Ray.put(arg).getId(); } else { byte[] serialized = Serializer.encode(arg); if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { - id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId(); + id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId(); } else { data = serialized; } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskLanguage.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskLanguage.java new file mode 100644 index 000000000..a6b4f31d8 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskLanguage.java @@ -0,0 +1,11 @@ +package org.ray.runtime.task; + +/** + * Language of a Ray task. + */ +public enum TaskLanguage { + + JAVA, + + PYTHON, +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 1e205b99b..d8f715ce6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -1,12 +1,14 @@ package org.ray.runtime.task; +import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; -import org.ray.runtime.util.ResourceUtil; +import org.ray.runtime.functionmanager.JavaFunctionDescriptor; +import org.ray.runtime.functionmanager.PyFunctionDescriptor; /** * Represents necessary information of a task for scheduling and executing. @@ -52,9 +54,13 @@ public class TaskSpec { // The task's resource demands. public final Map resources; - // Function descriptor is a list of strings that can uniquely identify a function. - // It will be sent to worker and used to load the target callable function. - public final FunctionDescriptor functionDescriptor; + // Language of this task. + public final TaskLanguage language; + + // Descriptor of the remote function. + // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language + // is Python, the type is PyFunctionDescriptor. + private final FunctionDescriptor functionDescriptor; private List executionDependencies; @@ -66,10 +72,22 @@ public class TaskSpec { return !actorCreationId.isNil(); } - public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter, - UniqueId actorCreationId, int maxActorReconstructions, UniqueId actorId, - UniqueId actorHandleId, int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args, - UniqueId[] returnIds, Map resources, FunctionDescriptor functionDescriptor) { + public TaskSpec( + UniqueId driverId, + UniqueId taskId, + UniqueId parentTaskId, + int parentCounter, + UniqueId actorCreationId, + int maxActorReconstructions, + UniqueId actorId, + UniqueId actorHandleId, + int actorCounter, + UniqueId[] newActorHandles, + FunctionArg[] args, + UniqueId[] returnIds, + Map resources, + TaskLanguage language, + FunctionDescriptor functionDescriptor) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; @@ -83,10 +101,30 @@ public class TaskSpec { this.args = args; this.returnIds = returnIds; this.resources = resources; + this.language = language; + if (language == TaskLanguage.JAVA) { + Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor, + "Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass()); + } else if (language == TaskLanguage.PYTHON) { + Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor, + "Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass()); + } else { + Preconditions.checkArgument(false, "Unknown task language: {}.", language); + } this.functionDescriptor = functionDescriptor; this.executionDependencies = new ArrayList<>(); } + public JavaFunctionDescriptor getJavaFunctionDescriptor() { + Preconditions.checkState(language == TaskLanguage.JAVA); + return (JavaFunctionDescriptor) functionDescriptor; + } + + public PyFunctionDescriptor getPyFunctionDescriptor() { + Preconditions.checkState(language == TaskLanguage.PYTHON); + return (PyFunctionDescriptor) functionDescriptor; + } + public List getExecutionDependencies() { return executionDependencies; } @@ -99,13 +137,17 @@ public class TaskSpec { ", parentTaskId=" + parentTaskId + ", parentCounter=" + parentCounter + ", actorCreationId=" + actorCreationId + + ", maxActorReconstructions=" + maxActorReconstructions + ", actorId=" + actorId + ", actorHandleId=" + actorHandleId + ", actorCounter=" + actorCounter + + ", newActorHandles=" + Arrays.toString(newActorHandles) + ", args=" + Arrays.toString(args) + ", returnIds=" + Arrays.toString(returnIds) + - ", resources=" + ResourceUtil.getResourcesStringFromMap(resources) + + ", resources=" + resources + + ", language=" + language + ", functionDescriptor=" + functionDescriptor + + ", executionDependencies=" + executionDependencies + '}'; } } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java index 82fdf6b7f..764887160 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java @@ -21,7 +21,6 @@ public class RayCallGenerator extends BaseGenerator { newLine(""); newLine("package org.ray.api;"); newLine(""); - newLine("import org.ray.api.function.RayFunc;"); newLine("import org.ray.api.function.RayFunc0;"); newLine("import org.ray.api.function.RayFunc1;"); newLine("import org.ray.api.function.RayFunc2;"); @@ -30,7 +29,6 @@ public class RayCallGenerator extends BaseGenerator { newLine("import org.ray.api.function.RayFunc5;"); newLine("import org.ray.api.function.RayFunc6;"); newLine("import org.ray.api.options.ActorCreationOptions;"); - newLine("import org.ray.api.options.BaseTaskOptions;"); newLine("import org.ray.api.options.CallOptions;"); newLine(""); @@ -46,6 +44,7 @@ public class RayCallGenerator extends BaseGenerator { buildCalls(i, false, false, false); buildCalls(i, false, false, true); } + newLine(1, "// ==========================================="); newLine(1, "// Methods for remote actor method invocation."); newLine(1, "// ==========================================="); @@ -59,6 +58,21 @@ public class RayCallGenerator extends BaseGenerator { buildCalls(i, false, true, false); buildCalls(i, false, true, true); } + + newLine(1, "// ==========================="); + newLine(1, "// Cross-language methods."); + newLine(1, "// ==========================="); + for (int i = 0; i <= MAX_PARAMETERS; i++) { + buildPyCalls(i, false, false, false); + buildPyCalls(i, false, false, true); + } + for (int i = 0; i <= MAX_PARAMETERS - 1; i++) { + buildPyCalls(i, true, false, false); + } + for (int i = 0; i <= MAX_PARAMETERS; i++) { + buildPyCalls(i, false, true, false); + buildPyCalls(i,false, true, true); + } newLine("}"); return sb.toString(); } @@ -117,18 +131,86 @@ public class RayCallGenerator extends BaseGenerator { String funcName = !forActorCreation ? "call" : "createActor"; String funcArgs = !forActor ? "f, args" : "f, actor, args"; for (String param : generateParameters(0, numParameters)) { - // method signature + // Method signature. newLine(1, String.format( "public static <%s> %s %s(%s%s) {", genericTypes, returnType, funcName, paramPrefix + param, optionsParam )); - // method body + // Method body. newLine(2, String.format("Object[] args = new Object[]{%s};", argList)); newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg)); newLine(1, "}"); } } + /** + * Build the `Ray.callPy` or `Ray.createPyActor` methods. + * @param forActor build actor api when true, otherwise build task api. + * @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`. + */ + private void buildPyCalls(int numParameters, boolean forActor, + boolean forActorCreation, boolean hasOptionsParam) { + String argList = ""; + String paramList = ""; + for (int i = 0; i < numParameters; i++) { + paramList += "Object obj" + i + ", "; + argList += "obj" + i + ", "; + } + if (argList.endsWith(", ")) { + argList = argList.substring(0, argList.length() - 2); + } + if (paramList.endsWith(", ")) { + paramList = paramList.substring(0, paramList.length() - 2); + } + + String paramPrefix = ""; + String funcArgs = ""; + if (forActorCreation) { + paramPrefix += "String moduleName, String className"; + funcArgs += "moduleName, className"; + } else if (forActor) { + paramPrefix += "RayPyActor pyActor, String functionName"; + funcArgs += "pyActor, functionName"; + } else { + paramPrefix += "String moduleName, String functionName"; + funcArgs += "moduleName, functionName"; + } + if (numParameters > 0) { + paramPrefix += ", "; + } + + String optionsParam; + if (hasOptionsParam) { + optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options"; + } else { + optionsParam = ""; + } + + String optionsArg; + if (forActor) { + optionsArg = ""; + } else { + if (hasOptionsParam) { + optionsArg = ", options"; + } else { + optionsArg = ", null"; + } + } + + String returnType = !forActorCreation ? "RayObject" : "RayPyActor"; + String funcName = !forActorCreation ? "callPy" : "createPyActor"; + funcArgs += ", args"; + // Method signature. + newLine(1, String.format( + "public static %s %s(%s%s) {", + returnType, funcName, paramPrefix + paramList, optionsParam + )); + // Method body. + newLine(2, String.format("Object[] args = new Object[]{%s};", argList)); + newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg)); + newLine(1, "}"); + } + private List generateParameters(int from, int to) { List res = new ArrayList<>(); dfs(from, from, to, "", res); @@ -155,3 +237,4 @@ public class RayCallGenerator extends BaseGenerator { FileUtil.overrideFile(path, new RayCallGenerator().build()); } } + diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index e0307635a..7bc1864d3 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -40,20 +40,20 @@ public class FunctionManagerTest { private static RayFunc0 fooFunc; private static RayFunc1 barFunc; private static RayFunc0 barConstructor; - private static FunctionDescriptor fooDescriptor; - private static FunctionDescriptor barDescriptor; - private static FunctionDescriptor barConstructorDescriptor; + private static JavaFunctionDescriptor fooDescriptor; + private static JavaFunctionDescriptor barDescriptor; + private static JavaFunctionDescriptor barConstructorDescriptor; @BeforeClass public static void beforeClass() { fooFunc = FunctionManagerTest::foo; barConstructor = Bar::new; barFunc = Bar::bar; - fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo", + fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo", "()Ljava/lang/Object;"); - barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar", + barDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), "bar", "()Ljava/lang/Object;"); - barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(), + barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), FunctionManager.CONSTRUCTOR_NAME, "()V"); } @@ -132,7 +132,7 @@ public class FunctionManagerTest { Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING); final FunctionManager functionManager = new FunctionManager(resourcePath); - FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02", + JavaFunctionDescriptor sayHelloDescriptor = new JavaFunctionDescriptor("org.ray.exercise.Exercise02", "sayHello", "()Ljava/lang/String;"); RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor); Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor); diff --git a/java/test/pom.xml b/java/test/pom.xml index 448364641..8c59026db 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -49,9 +49,12 @@ org.apache.maven.plugins maven-surefire-plugin - 2.21.0 + 3.0.0-M3 + false + false ${basedir}/src/main/java/ + ${basedir}/src/main/resources/ ${project.build.directory}/classes/ diff --git a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java new file mode 100644 index 000000000..939372b96 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java @@ -0,0 +1,112 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.io.File; +import java.lang.ProcessBuilder.Redirect; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.TimeUnit; +import org.ray.api.Ray; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; + +public abstract class BaseMultiLanguageTest { + + private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class); + + private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket"; + private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket"; + + /** + * Execute an external command. + * + * @return Whether the command succeeded. + */ + private boolean executeCommand(List command, int waitTimeoutSeconds, + Map env) { + try { + LOGGER.info("Executing command: {}", String.join(" ", command)); + ProcessBuilder processBuilder = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT) + .redirectError(Redirect.INHERIT); + for (Entry entry : env.entrySet()) { + processBuilder.environment().put(entry.getKey(), entry.getValue()); + } + Process process = processBuilder.start(); + process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS); + return process.exitValue() == 0; + } catch (Exception e) { + throw new RuntimeException("Error executing command " + String.join(" ", command), e); + } + } + + @BeforeClass + public void setUp() { + if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) { + LOGGER.info("Skip Multi-language tests because environment variable " + + "ENABLE_MULTI_LANGUAGE_TESTS isn't set"); + throw new SkipException("Skip test."); + } + + // Delete existing socket files. + for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) { + File file = new File(socket); + if (file.exists()) { + file.delete(); + } + } + + // Start ray cluster. + String workerOptions = + " -classpath " + System.getProperty("java.class.path"); + final List startCommand = ImmutableList.of( + "ray", + "start", + "--head", + "--redis-port=6379", + String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME), + String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME), + "--load-code-from-local", + "--include-java", + "--java-worker-options=" + workerOptions + ); + if (!executeCommand(startCommand, 10, getRayStartEnv())) { + throw new RuntimeException("Couldn't start ray cluster."); + } + + // Connect to the cluster. + System.setProperty("ray.redis.address", "127.0.0.1:6379"); + System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME); + System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); + Ray.init(); + } + + /** + * @return The environment variables needed for the `ray start` command. + */ + protected Map getRayStartEnv() { + return ImmutableMap.of(); + } + + @AfterClass + public void tearDown() { + // Disconnect to the cluster. + Ray.shutdown(); + System.clearProperty("ray.redis.address"); + System.clearProperty("ray.object-store.socket-name"); + System.clearProperty("ray.raylet.socket-name"); + + // Stop ray cluster. + final List stopCommand = ImmutableList.of( + "ray", + "stop" + ); + if (!executeCommand(stopCommand, 10, ImmutableMap.of())) { + throw new RuntimeException("Couldn't stop ray cluster"); + } + } +} diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java new file mode 100644 index 000000000..2f75c7b54 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -0,0 +1,54 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableMap; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import org.apache.commons.io.FileUtils; +import org.ray.api.Ray; +import org.ray.api.RayObject; +import org.ray.api.RayPyActor; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { + + private static final String PYTHON_MODULE = "test_cross_language_invocation"; + + @Override + protected Map getRayStartEnv() { + // Delete and re-create the temp dir. + File tempDir = new File( + System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test"); + FileUtils.deleteQuietly(tempDir); + tempDir.mkdirs(); + tempDir.deleteOnExit(); + + // Write the test Python file to the temp dir. + InputStream in = CrossLanguageInvocationTest.class + .getResourceAsStream("/" + PYTHON_MODULE + ".py"); + File pythonFile = new File( + tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py"); + try { + FileUtils.copyInputStreamToFile(in, pythonFile); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath()); + } + + @Test + public void testCallingPythonFunction() { + RayObject res = Ray.callPy(PYTHON_MODULE, "py_func", "hello".getBytes()); + Assert.assertEquals(res.get(), "Response from Python: hello".getBytes()); + } + + @Test + public void testCallingPythonActor() { + RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes()); + RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); + Assert.assertEquals(res.get(), "2".getBytes()); + } +} diff --git a/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java b/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java index ad3033681..043eb0aae 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java @@ -1,113 +1,18 @@ package org.ray.api.test; -import com.google.common.collect.ImmutableList; -import java.io.File; -import java.lang.reflect.Method; -import java.util.List; -import java.util.concurrent.TimeUnit; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.testng.Assert; -import org.testng.SkipException; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -/** - * Test starting a ray cluster with multi-language support. - */ -public class MultiLanguageClusterTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class); - - private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket"; - private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket"; +public class MultiLanguageClusterTest extends BaseMultiLanguageTest { @RayRemote public static String echo(String word) { return word; } - /** - * Execute an external command. - * - * @return Whether the command succeeded. - */ - private boolean executeCommand(List command, int waitTimeoutSeconds) { - try { - LOGGER.info("Executing command: {}", String.join(" ", command)); - Process process = new ProcessBuilder(command).inheritIO().start(); - process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS); - return process.exitValue() == 0; - } catch (Exception e) { - throw new RuntimeException("Error executing command " + String.join(" ", command), e); - } - } - - @BeforeMethod - public void setUp(Method method) { - String testName = method.getName(); - if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) { - LOGGER.info("Skip " + testName + - " because env variable ENABLE_MULTI_LANGUAGE_TESTS isn't set"); - throw new SkipException("Skip test."); - } - - // Delete existing socket files. - for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) { - File file = new File(socket); - if (file.exists()) { - file.delete(); - } - } - - // Start ray cluster. - String testDir = System.getProperty("user.dir"); - String workerOptions = - " -classpath " + String.format("%s/../../build/java/*:%s/target/*", testDir, testDir); - final List startCommand = ImmutableList.of( - "ray", - "start", - "--head", - "--redis-port=6379", - String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME), - String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME), - "--load-code-from-local", - "--include-java", - "--java-worker-options=" + workerOptions - ); - if (!executeCommand(startCommand, 10)) { - throw new RuntimeException("Couldn't start ray cluster."); - } - - // Connect to the cluster. - System.setProperty("ray.redis.address", "127.0.0.1:6379"); - System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME); - System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); - Ray.init(); - } - - @AfterMethod - public void tearDown() { - // Disconnect to the cluster. - Ray.shutdown(); - System.clearProperty("ray.redis.address"); - System.clearProperty("ray.object-store.socket-name"); - System.clearProperty("ray.raylet.socket-name"); - - // Stop ray cluster. - final List stopCommand = ImmutableList.of( - "ray", - "stop" - ); - if (!executeCommand(stopCommand, 10)) { - throw new RuntimeException("Couldn't stop ray cluster"); - } - } - @Test public void testMultiLanguageCluster() { RayObject obj = Ray.call(MultiLanguageClusterTest::echo, "hello"); diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py new file mode 100644 index 000000000..7e78b2eba --- /dev/null +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -0,0 +1,20 @@ +# This file is used by CrossLanguageInvocationTest.java to test cross-language +# invocation. +import ray +import six + + +@ray.remote +def py_func(value): + assert isinstance(value, bytes) + return b"Response from Python: " + value + + +@ray.remote +class Counter(object): + def __init__(self, value): + self.value = int(value) + + def increase(self, delta): + self.value += int(delta) + return str(self.value).encode("utf-8") if six.PY3 else str(self.value) diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 04527f2eb..a54a3679c 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -111,3 +111,6 @@ PROCESS_TYPE_REDIS_SERVER = "redis_server" PROCESS_TYPE_WEB_UI = "web_ui" LOG_MONITOR_MAX_OPEN_FILES = 200 + +# A constant used as object metadata to indicate the object is raw binary. +RAW_BUFFER_METADATA = b"RAW" diff --git a/python/ray/worker.py b/python/ray/worker.py index 24278a88b..aa65def06 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -287,12 +287,22 @@ class Worker(object): "type {}.".format(type(value))) counter += 1 try: - self.plasma_client.put( - value, - object_id=pyarrow.plasma.ObjectID(object_id.binary()), - memcopy_threads=self.memcopy_threads, - serialization_context=self.get_serialization_context( - self.task_driver_id)) + if isinstance(value, bytes): + # If the object is a byte array, skip serializing it and + # use a special metadata to indicate it's raw binary. So + # that this object can also be read by Java. + self.plasma_client.put_raw_buffer( + value, + object_id=pyarrow.plasma.ObjectID(object_id.binary()), + metadata=ray_constants.RAW_BUFFER_METADATA, + memcopy_threads=self.memcopy_threads) + else: + self.plasma_client.put( + value, + object_id=pyarrow.plasma.ObjectID(object_id.binary()), + memcopy_threads=self.memcopy_threads, + serialization_context=self.get_serialization_context( + self.task_driver_id)) break except pyarrow.SerializationCallbackError as e: try: @@ -437,7 +447,10 @@ class Worker(object): def _deserialize_object_from_arrow(self, data, metadata, object_id, serialization_context): if metadata: - # If metadata is not empty, return an exception object based on + # Check if the object should be returned as raw bytes. + if metadata == ray_constants.RAW_BUFFER_METADATA: + return data.to_pybytes() + # Otherwise, return an exception object based on # the error type. error_type = int(metadata) if error_type == ErrorType.WORKER_DIED: diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 5595c3657..bc0ecfb07 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -96,8 +96,6 @@ table TaskInfo { // uniquely describe a function. // For a Python function, it should be: [module_name, class_name, function_name] // For a Java function, it should be: [class_name, method_name, type_descriptor] - // TODO(hchen): after changing Python worker to use function_descriptor, - // function_id can be removed. function_descriptor: [string]; }