Java call Python actor method use actor.call (#7614)

This commit is contained in:
fyrestone
2020-03-17 14:52:43 +08:00
committed by GitHub
parent ffa9df4683
commit 7697ea2be2
10 changed files with 127 additions and 46 deletions
@@ -0,0 +1,41 @@
// Generated by `RayCallGenerator.java`. DO NOT EDIT.
package org.ray.api;
/**
* This class provides type-safe interfaces for remote actor calls.
**/
@SuppressWarnings({"rawtypes", "unchecked"})
interface PyActorCall {
default RayObject call(String functionName) {
Object[] args = new Object[]{};
return Ray.internal().callPyActor((RayPyActor)this, functionName, args);
}
default RayObject call(String functionName, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPyActor((RayPyActor)this, functionName, args);
}
default RayObject call(String functionName, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPyActor((RayPyActor)this, functionName, args);
}
default RayObject call(String functionName, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPyActor((RayPyActor)this, functionName, args);
}
default RayObject call(String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPyActor((RayPyActor)this, functionName, args);
}
default RayObject call(String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPyActor((RayPyActor)this, functionName, args);
}
}
+28 -24
View File
@@ -3847,136 +3847,140 @@ class RayCall {
Object[] args = new Object[]{};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, CallOptions options) {
Object[] args = new Object[]{};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, CallOptions options) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, CallOptions options) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(RayPyActor pyActor, String functionName) {
Object[] args = new Object[]{};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayPyActor createPyActor(String moduleName, String className) {
Object[] args = new Object[]{};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, ActorCreationOptions options) {
Object[] args = new Object[]{};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, ActorCreationOptions options) {
Object[] args = new Object[]{obj0};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
}
@@ -3,7 +3,7 @@ package org.ray.api;
/**
* Handle of a Python actor.
*/
public interface RayPyActor extends RayActor {
public interface RayPyActor extends RayActor, PyActorCall {
/**
* @return Module name of the Python actor class.
@@ -137,7 +137,7 @@ public interface RayRuntime {
* @param args Arguments of the function.
* @return The result object.
*/
RayObject callPy(RayPyActor pyActor, String functionName, Object[] args);
RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args);
/**
* Create a Python actor on a remote node.
@@ -135,7 +135,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), functionName);
@@ -171,8 +171,8 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
}
@Override
public RayObject callPy(RayPyActor pyActor, String functionName, Object[] args) {
return getCurrentRuntime().callPy(pyActor, functionName, args);
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args) {
return getCurrentRuntime().callPyActor(pyActor, functionName, args);
}
@Override
@@ -63,10 +63,6 @@ public class RayCallGenerator extends BaseGenerator {
buildPyCalls(i, false, false, false);
buildPyCalls(i, false, false, true);
}
// TODO(hchen): move Python actor call API to `RayPyActor` class.
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildPyCalls(i, true, false, false);
}
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, true, false);
buildPyCalls(i, false, true, true);
@@ -102,7 +98,29 @@ public class RayCallGenerator extends BaseGenerator {
buildCalls(i, true, false, true, false);
buildCalls(i, true, false, false, false);
}
newLine("}");
return sb.toString();
}
/**
* @return Whole file content of `PyActorCall.java`.
*/
private String generatePyActorCallDotJava() {
sb = new StringBuilder();
newLine("// Generated by `RayCallGenerator.java`. DO NOT EDIT.");
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("/**");
newLine(" * This class provides type-safe interfaces for remote actor calls.");
newLine(" **/");
newLine("@SuppressWarnings({\"rawtypes\", \"unchecked\"})");
newLine("interface PyActorCall {");
newLine("");
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildPyCalls(i, true, false, false);
}
newLine("}");
return sb.toString();
}
@@ -213,13 +231,20 @@ public class RayCallGenerator extends BaseGenerator {
}
/**
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
* Build `Ray.callPy`, `Ray.createPyActor` and `actor.call` methods with
* the given number of parameters.
*
* @param forActor Build actor api when true, otherwise build task api.
* @param numParameters the number of parameters
* @param forActor Build `actor.call` when true, otherwise build `Ray.callPy`.
* @param forActorCreation Build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
* @param hasOptionsParam Add ActorCreationOptions if forActorCreation is true;
* Add CallOptions if forActorCreation is false;
* No additional param if hasOptionsParam is false.
*/
private void buildPyCalls(int numParameters, boolean forActor,
boolean forActorCreation, boolean hasOptionsParam) {
String modifiers = forActor ? "default" : "public static";
String argList = "";
String paramList = "";
for (int i = 0; i < numParameters; i++) {
@@ -239,8 +264,8 @@ public class RayCallGenerator extends BaseGenerator {
paramPrefix += "String moduleName, String className";
funcArgs += "moduleName, className";
} else if (forActor) {
paramPrefix += "RayPyActor pyActor, String functionName";
funcArgs += "pyActor, functionName";
paramPrefix += "String functionName";
funcArgs += "functionName";
} else {
paramPrefix += "String moduleName, String functionName";
funcArgs += "moduleName, functionName";
@@ -268,17 +293,26 @@ public class RayCallGenerator extends BaseGenerator {
}
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
String funcName = !forActorCreation ? "callPy" : "createPyActor";
String funcName = forActorCreation ? "createPyActor" : forActor ? "call" : "callPy";
String internalCallFunc = forActorCreation ? "createPyActor" :
forActor ? "callPyActor" : "callPy";
funcArgs += ", args";
// Method signature.
newLine(1, String.format(
"public static %s %s(%s%s) {",
"%s %s %s(%s%s) {", modifiers,
returnType, funcName, paramPrefix + paramList, optionsParam
));
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
if (forActor) {
newLine(2, String.format("return Ray.internal().%s((RayPyActor)this, %s%s);",
internalCallFunc, funcArgs, optionsArg));
} else {
newLine(2, String.format("return Ray.internal().%s(%s%s);",
internalCallFunc, funcArgs, optionsArg));
}
newLine(1, "}");
newLine("");
}
private List<String> generateParameters(int numParams) {
@@ -307,6 +341,9 @@ public class RayCallGenerator extends BaseGenerator {
+ "/api/src/main/java/org/ray/api/ActorCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generateActorCallDotJava(),
Charset.defaultCharset());
path = System.getProperty("user.dir")
+ "/api/src/main/java/org/ray/api/PyActorCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generatePyActorCallDotJava(),
Charset.defaultCharset());
}
}
@@ -61,7 +61,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testCallingPythonActor() {
RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes());
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
RayObject res = actor.call("increase", "1".getBytes());
Assert.assertEquals(res.get(), "2".getBytes());
}
@@ -108,7 +108,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
public static byte[] callPythonActorHandle(byte[] value) {
// This function will be called from test_cross_language_invocation.py
NativeRayPyActor actor = (NativeRayPyActor)NativeRayActor.fromBytes(value);
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
RayObject res = actor.call("increase", "1".getBytes());
Assert.assertEquals(res.get(), "3".getBytes());
return (byte[])res.get();
}