From 7697ea2be26bfc1d6ffbf2f10efb6a166fde37c1 Mon Sep 17 00:00:00 2001 From: fyrestone Date: Tue, 17 Mar 2020 14:52:43 +0800 Subject: [PATCH] Java call Python actor method use actor.call (#7614) --- .../main/java/org/ray/api/PyActorCall.java | 41 +++++++++++++ .../src/main/java/org/ray/api/RayCall.java | 52 ++++++++-------- .../src/main/java/org/ray/api/RayPyActor.java | 2 +- .../java/org/ray/api/runtime/RayRuntime.java | 2 +- .../org/ray/runtime/AbstractRayRuntime.java | 2 +- .../runtime/RayMultiWorkerNativeRuntime.java | 4 +- .../util/generator/RayCallGenerator.java | 61 +++++++++++++++---- .../api/test/CrossLanguageInvocationTest.java | 4 +- python/ray/scripts/scripts.py | 2 +- .../runtime/schedule/JobSchedulerImpl.java | 3 +- 10 files changed, 127 insertions(+), 46 deletions(-) create mode 100644 java/api/src/main/java/org/ray/api/PyActorCall.java diff --git a/java/api/src/main/java/org/ray/api/PyActorCall.java b/java/api/src/main/java/org/ray/api/PyActorCall.java new file mode 100644 index 000000000..a462a208b --- /dev/null +++ b/java/api/src/main/java/org/ray/api/PyActorCall.java @@ -0,0 +1,41 @@ +// Generated by `RayCallGenerator.java`. DO NOT EDIT. + +package org.ray.api; + +/** + * This class provides type-safe interfaces for remote actor calls. + **/ +@SuppressWarnings({"rawtypes", "unchecked"}) +interface PyActorCall { + + default RayObject call(String functionName) { + Object[] args = new Object[]{}; + return Ray.internal().callPyActor((RayPyActor)this, functionName, args); + } + + default RayObject call(String functionName, Object obj0) { + Object[] args = new Object[]{obj0}; + return Ray.internal().callPyActor((RayPyActor)this, functionName, args); + } + + default RayObject call(String functionName, Object obj0, Object obj1) { + Object[] args = new Object[]{obj0, obj1}; + return Ray.internal().callPyActor((RayPyActor)this, functionName, args); + } + + default RayObject call(String functionName, Object obj0, Object obj1, Object obj2) { + Object[] args = new Object[]{obj0, obj1, obj2}; + return Ray.internal().callPyActor((RayPyActor)this, functionName, args); + } + + default RayObject call(String functionName, Object obj0, Object obj1, Object obj2, Object obj3) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3}; + return Ray.internal().callPyActor((RayPyActor)this, functionName, args); + } + + default RayObject call(String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { + Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; + return Ray.internal().callPyActor((RayPyActor)this, functionName, args); + } + +} diff --git a/java/api/src/main/java/org/ray/api/RayCall.java b/java/api/src/main/java/org/ray/api/RayCall.java index 4bd537d0e..62259fba1 100644 --- a/java/api/src/main/java/org/ray/api/RayCall.java +++ b/java/api/src/main/java/org/ray/api/RayCall.java @@ -3847,136 +3847,140 @@ class RayCall { Object[] args = new Object[]{}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, CallOptions options) { Object[] args = new Object[]{}; return Ray.internal().callPy(moduleName, functionName, args, options); } + public static RayObject callPy(String moduleName, String functionName, Object obj0) { Object[] args = new Object[]{obj0}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, CallOptions options) { Object[] args = new Object[]{obj0}; return Ray.internal().callPy(moduleName, functionName, args, options); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1) { Object[] args = new Object[]{obj0, obj1}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, CallOptions options) { Object[] args = new Object[]{obj0, obj1}; return Ray.internal().callPy(moduleName, functionName, args, options); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2) { Object[] args = new Object[]{obj0, obj1, obj2}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, CallOptions options) { Object[] args = new Object[]{obj0, obj1, obj2}; return Ray.internal().callPy(moduleName, functionName, args, options); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) { Object[] args = new Object[]{obj0, obj1, obj2, obj3}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, CallOptions options) { Object[] args = new Object[]{obj0, obj1, obj2, obj3}; return Ray.internal().callPy(moduleName, functionName, args, options); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, CallOptions options) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; return Ray.internal().callPy(moduleName, functionName, args, options); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; return Ray.internal().callPy(moduleName, functionName, args, null); } + public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, CallOptions options) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; return Ray.internal().callPy(moduleName, functionName, args, options); } - public static RayObject callPy(RayPyActor pyActor, String functionName) { - Object[] args = new Object[]{}; - return Ray.internal().callPy(pyActor, functionName, args); - } - public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0) { - Object[] args = new Object[]{obj0}; - return Ray.internal().callPy(pyActor, functionName, args); - } - public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1) { - Object[] args = new Object[]{obj0, obj1}; - return Ray.internal().callPy(pyActor, functionName, args); - } - public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2) { - Object[] args = new Object[]{obj0, obj1, obj2}; - return Ray.internal().callPy(pyActor, functionName, args); - } - public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) { - Object[] args = new Object[]{obj0, obj1, obj2, obj3}; - return Ray.internal().callPy(pyActor, functionName, args); - } - public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { - Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; - return Ray.internal().callPy(pyActor, functionName, args); - } + public static RayPyActor createPyActor(String moduleName, String className) { Object[] args = new Object[]{}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, ActorCreationOptions options) { Object[] args = new Object[]{}; return Ray.internal().createPyActor(moduleName, className, args, options); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0) { Object[] args = new Object[]{obj0}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, ActorCreationOptions options) { Object[] args = new Object[]{obj0}; return Ray.internal().createPyActor(moduleName, className, args, options); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1) { Object[] args = new Object[]{obj0, obj1}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, ActorCreationOptions options) { Object[] args = new Object[]{obj0, obj1}; return Ray.internal().createPyActor(moduleName, className, args, options); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2) { Object[] args = new Object[]{obj0, obj1, obj2}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, ActorCreationOptions options) { Object[] args = new Object[]{obj0, obj1, obj2}; return Ray.internal().createPyActor(moduleName, className, args, options); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3) { Object[] args = new Object[]{obj0, obj1, obj2, obj3}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, ActorCreationOptions options) { Object[] args = new Object[]{obj0, obj1, obj2, obj3}; return Ray.internal().createPyActor(moduleName, className, args, options); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, ActorCreationOptions options) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4}; return Ray.internal().createPyActor(moduleName, className, args, options); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; return Ray.internal().createPyActor(moduleName, className, args, null); } + public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, ActorCreationOptions options) { Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5}; return Ray.internal().createPyActor(moduleName, className, args, options); } + } diff --git a/java/api/src/main/java/org/ray/api/RayPyActor.java b/java/api/src/main/java/org/ray/api/RayPyActor.java index 4f32bc4f4..cde6c8df7 100644 --- a/java/api/src/main/java/org/ray/api/RayPyActor.java +++ b/java/api/src/main/java/org/ray/api/RayPyActor.java @@ -3,7 +3,7 @@ package org.ray.api; /** * Handle of a Python actor. */ -public interface RayPyActor extends RayActor { +public interface RayPyActor extends RayActor, PyActorCall { /** * @return Module name of the Python actor class. diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 61c3b01e9..7854821ab 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -137,7 +137,7 @@ public interface RayRuntime { * @param args Arguments of the function. * @return The result object. */ - RayObject callPy(RayPyActor pyActor, String functionName, Object[] args); + RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args); /** * Create a Python actor on a remote node. diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 84990d5b4..1fd8372e8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -135,7 +135,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { } @Override - public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) { + public RayObject callPyActor(RayPyActor pyActor, String functionName, Object... args) { checkPyArguments(args); PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(), pyActor.getClassName(), functionName); diff --git a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java index 4fc5e8752..176201232 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java @@ -171,8 +171,8 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime { } @Override - public RayObject callPy(RayPyActor pyActor, String functionName, Object[] args) { - return getCurrentRuntime().callPy(pyActor, functionName, args); + public RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args) { + return getCurrentRuntime().callPyActor(pyActor, functionName, args); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java index 21330bc8c..c2bfcffcd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java @@ -63,10 +63,6 @@ public class RayCallGenerator extends BaseGenerator { buildPyCalls(i, false, false, false); buildPyCalls(i, false, false, true); } - // TODO(hchen): move Python actor call API to `RayPyActor` class. - for (int i = 0; i <= MAX_PARAMETERS - 1; i++) { - buildPyCalls(i, true, false, false); - } for (int i = 0; i <= MAX_PARAMETERS; i++) { buildPyCalls(i, false, true, false); buildPyCalls(i, false, true, true); @@ -102,7 +98,29 @@ public class RayCallGenerator extends BaseGenerator { buildCalls(i, true, false, true, false); buildCalls(i, true, false, false, false); } + newLine("}"); + return sb.toString(); + } + /** + * @return Whole file content of `PyActorCall.java`. + */ + private String generatePyActorCallDotJava() { + sb = new StringBuilder(); + + newLine("// Generated by `RayCallGenerator.java`. DO NOT EDIT."); + newLine(""); + newLine("package org.ray.api;"); + newLine(""); + newLine("/**"); + newLine(" * This class provides type-safe interfaces for remote actor calls."); + newLine(" **/"); + newLine("@SuppressWarnings({\"rawtypes\", \"unchecked\"})"); + newLine("interface PyActorCall {"); + newLine(""); + for (int i = 0; i <= MAX_PARAMETERS - 1; i++) { + buildPyCalls(i, true, false, false); + } newLine("}"); return sb.toString(); } @@ -213,13 +231,20 @@ public class RayCallGenerator extends BaseGenerator { } /** - * Build the `Ray.callPy` or `Ray.createPyActor` methods. + * Build `Ray.callPy`, `Ray.createPyActor` and `actor.call` methods with + * the given number of parameters. * - * @param forActor Build actor api when true, otherwise build task api. + * @param numParameters the number of parameters + * @param forActor Build `actor.call` when true, otherwise build `Ray.callPy`. * @param forActorCreation Build `Ray.createPyActor` when true, otherwise build `Ray.callPy`. + * @param hasOptionsParam Add ActorCreationOptions if forActorCreation is true; + * Add CallOptions if forActorCreation is false; + * No additional param if hasOptionsParam is false. */ private void buildPyCalls(int numParameters, boolean forActor, boolean forActorCreation, boolean hasOptionsParam) { + String modifiers = forActor ? "default" : "public static"; + String argList = ""; String paramList = ""; for (int i = 0; i < numParameters; i++) { @@ -239,8 +264,8 @@ public class RayCallGenerator extends BaseGenerator { paramPrefix += "String moduleName, String className"; funcArgs += "moduleName, className"; } else if (forActor) { - paramPrefix += "RayPyActor pyActor, String functionName"; - funcArgs += "pyActor, functionName"; + paramPrefix += "String functionName"; + funcArgs += "functionName"; } else { paramPrefix += "String moduleName, String functionName"; funcArgs += "moduleName, functionName"; @@ -268,17 +293,26 @@ public class RayCallGenerator extends BaseGenerator { } String returnType = !forActorCreation ? "RayObject" : "RayPyActor"; - String funcName = !forActorCreation ? "callPy" : "createPyActor"; + String funcName = forActorCreation ? "createPyActor" : forActor ? "call" : "callPy"; + String internalCallFunc = forActorCreation ? "createPyActor" : + forActor ? "callPyActor" : "callPy"; funcArgs += ", args"; // Method signature. newLine(1, String.format( - "public static %s %s(%s%s) {", + "%s %s %s(%s%s) {", modifiers, returnType, funcName, paramPrefix + paramList, optionsParam )); // Method body. newLine(2, String.format("Object[] args = new Object[]{%s};", argList)); - newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg)); + if (forActor) { + newLine(2, String.format("return Ray.internal().%s((RayPyActor)this, %s%s);", + internalCallFunc, funcArgs, optionsArg)); + } else { + newLine(2, String.format("return Ray.internal().%s(%s%s);", + internalCallFunc, funcArgs, optionsArg)); + } newLine(1, "}"); + newLine(""); } private List generateParameters(int numParams) { @@ -307,6 +341,9 @@ public class RayCallGenerator extends BaseGenerator { + "/api/src/main/java/org/ray/api/ActorCall.java"; FileUtils.write(new File(path), new RayCallGenerator().generateActorCallDotJava(), Charset.defaultCharset()); + path = System.getProperty("user.dir") + + "/api/src/main/java/org/ray/api/PyActorCall.java"; + FileUtils.write(new File(path), new RayCallGenerator().generatePyActorCallDotJava(), + Charset.defaultCharset()); } } - diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index fe4b95dd3..cdaa80348 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -61,7 +61,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { @Test public void testCallingPythonActor() { RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes()); - RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); + RayObject res = actor.call("increase", "1".getBytes()); Assert.assertEquals(res.get(), "2".getBytes()); } @@ -108,7 +108,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { public static byte[] callPythonActorHandle(byte[] value) { // This function will be called from test_cross_language_invocation.py NativeRayPyActor actor = (NativeRayPyActor)NativeRayActor.fromBytes(value); - RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); + RayObject res = actor.call("increase", "1".getBytes()); Assert.assertEquals(res.get(), "3".getBytes()); return (byte[])res.get(); } diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 28b9c3498..48880e10a 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -268,7 +268,7 @@ def dashboard(cluster_config_file, cluster_name, port): @click.option( "--internal-config", default=None, - type=str, + type=json.loads, help="Do NOT use this. This is for debugging/development purposes ONLY.") @click.option( "--load-code-from-local", diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java index df239c1dd..08b717277 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java @@ -67,8 +67,7 @@ public class JobSchedulerImpl implements JobScheduler { case PYTHON: byte[] workerContextBytes = buildPythonWorkerContext( taskId, executionGraphPb, jobConfig); - waits.add(Ray.callPy((RayPyActor) worker, - "init", workerContextBytes)); + waits.add(((RayPyActor)worker).call("init", workerContextBytes)); break; default: throw new UnsupportedOperationException(