Java call Python actor method use actor.call (#7614)

This commit is contained in:
fyrestone
2020-03-17 14:52:43 +08:00
committed by GitHub
parent ffa9df4683
commit 7697ea2be2
10 changed files with 127 additions and 46 deletions
@@ -135,7 +135,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), functionName);
@@ -171,8 +171,8 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
}
@Override
public RayObject callPy(RayPyActor pyActor, String functionName, Object[] args) {
return getCurrentRuntime().callPy(pyActor, functionName, args);
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args) {
return getCurrentRuntime().callPyActor(pyActor, functionName, args);
}
@Override
@@ -63,10 +63,6 @@ public class RayCallGenerator extends BaseGenerator {
buildPyCalls(i, false, false, false);
buildPyCalls(i, false, false, true);
}
// TODO(hchen): move Python actor call API to `RayPyActor` class.
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildPyCalls(i, true, false, false);
}
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, true, false);
buildPyCalls(i, false, true, true);
@@ -102,7 +98,29 @@ public class RayCallGenerator extends BaseGenerator {
buildCalls(i, true, false, true, false);
buildCalls(i, true, false, false, false);
}
newLine("}");
return sb.toString();
}
/**
* @return Whole file content of `PyActorCall.java`.
*/
private String generatePyActorCallDotJava() {
sb = new StringBuilder();
newLine("// Generated by `RayCallGenerator.java`. DO NOT EDIT.");
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("/**");
newLine(" * This class provides type-safe interfaces for remote actor calls.");
newLine(" **/");
newLine("@SuppressWarnings({\"rawtypes\", \"unchecked\"})");
newLine("interface PyActorCall {");
newLine("");
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildPyCalls(i, true, false, false);
}
newLine("}");
return sb.toString();
}
@@ -213,13 +231,20 @@ public class RayCallGenerator extends BaseGenerator {
}
/**
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
* Build `Ray.callPy`, `Ray.createPyActor` and `actor.call` methods with
* the given number of parameters.
*
* @param forActor Build actor api when true, otherwise build task api.
* @param numParameters the number of parameters
* @param forActor Build `actor.call` when true, otherwise build `Ray.callPy`.
* @param forActorCreation Build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
* @param hasOptionsParam Add ActorCreationOptions if forActorCreation is true;
* Add CallOptions if forActorCreation is false;
* No additional param if hasOptionsParam is false.
*/
private void buildPyCalls(int numParameters, boolean forActor,
boolean forActorCreation, boolean hasOptionsParam) {
String modifiers = forActor ? "default" : "public static";
String argList = "";
String paramList = "";
for (int i = 0; i < numParameters; i++) {
@@ -239,8 +264,8 @@ public class RayCallGenerator extends BaseGenerator {
paramPrefix += "String moduleName, String className";
funcArgs += "moduleName, className";
} else if (forActor) {
paramPrefix += "RayPyActor pyActor, String functionName";
funcArgs += "pyActor, functionName";
paramPrefix += "String functionName";
funcArgs += "functionName";
} else {
paramPrefix += "String moduleName, String functionName";
funcArgs += "moduleName, functionName";
@@ -268,17 +293,26 @@ public class RayCallGenerator extends BaseGenerator {
}
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
String funcName = !forActorCreation ? "callPy" : "createPyActor";
String funcName = forActorCreation ? "createPyActor" : forActor ? "call" : "callPy";
String internalCallFunc = forActorCreation ? "createPyActor" :
forActor ? "callPyActor" : "callPy";
funcArgs += ", args";
// Method signature.
newLine(1, String.format(
"public static %s %s(%s%s) {",
"%s %s %s(%s%s) {", modifiers,
returnType, funcName, paramPrefix + paramList, optionsParam
));
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
if (forActor) {
newLine(2, String.format("return Ray.internal().%s((RayPyActor)this, %s%s);",
internalCallFunc, funcArgs, optionsArg));
} else {
newLine(2, String.format("return Ray.internal().%s(%s%s);",
internalCallFunc, funcArgs, optionsArg));
}
newLine(1, "}");
newLine("");
}
private List<String> generateParameters(int numParams) {
@@ -307,6 +341,9 @@ public class RayCallGenerator extends BaseGenerator {
+ "/api/src/main/java/org/ray/api/ActorCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generateActorCallDotJava(),
Charset.defaultCharset());
path = System.getProperty("user.dir")
+ "/api/src/main/java/org/ray/api/PyActorCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generatePyActorCallDotJava(),
Charset.defaultCharset());
}
}