Java call Python use structured function descriptors (#7634)

This commit is contained in:
fyrestone
2020-03-20 17:29:45 +08:00
committed by GitHub
parent 7d08b418fc
commit a1ae935839
14 changed files with 475 additions and 305 deletions
@@ -11,6 +11,9 @@ import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.function.PyActorClass;
import org.ray.api.function.PyActorMethod;
import org.ray.api.function.PyRemoteFunction;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFuncVoid;
import org.ray.api.id.ObjectId;
@@ -98,7 +101,19 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public RayObject callActor(RayFunc func, RayActor<?> actor, Object[] args) {
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
CallOptions options) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyRemoteFunction.moduleName,
"",
pyRemoteFunction.functionName);
// Python functions always have a return value, even if it's `None`.
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
}
@Override
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), func)
.functionDescriptor;
@@ -106,6 +121,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return callActorFunction(actor, functionDescriptor, args, numReturns);
}
@Override
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), pyActorMethod.methodName);
// Python functions always have a return value, even if it's `None`.
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
}
@Override
@SuppressWarnings("unchecked")
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
@@ -116,6 +140,17 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return (RayActor<T>) createActorImpl(functionDescriptor, args, options);
}
@Override
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
ActorCreationOptions options) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyActorClass.moduleName,
pyActorClass.className,
PYTHON_INIT_METHOD_NAME);
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
}
private void checkPyArguments(Object[] args) {
for (Object arg : args) {
Preconditions.checkArgument(
@@ -125,34 +160,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
}
@Override
public RayObject callPy(String moduleName, String functionName, Object[] args,
CallOptions options) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, "",
functionName);
// Python functions always have a return value, even if it's `None`.
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
}
@Override
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), functionName);
// Python functions always have a return value, even if it's `None`.
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
}
@Override
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
ActorCreationOptions options) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, className,
PYTHON_INIT_METHOD_NAME);
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
}
@Override
public Runnable wrapRunnable(Runnable runnable) {
return runnable;
@@ -8,6 +8,9 @@ import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.function.PyActorClass;
import org.ray.api.function.PyActorMethod;
import org.ray.api.function.PyRemoteFunction;
import org.ray.api.function.RayFunc;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
@@ -150,8 +153,19 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
}
@Override
public RayObject callActor(RayFunc func, RayActor<?> actor, Object[] args) {
return getCurrentRuntime().callActor(func, actor, args);
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
CallOptions options) {
return getCurrentRuntime().call(pyRemoteFunction, args, options);
}
@Override
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
return getCurrentRuntime().callActor(actor, func, args);
}
@Override
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) {
return getCurrentRuntime().callActor(pyActor, pyActorMethod, args);
}
@Override
@@ -160,28 +174,17 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
}
@Override
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
ActorCreationOptions options) {
return getCurrentRuntime().createActor(pyActorClass, args, options);
}
@Override
public RuntimeContext getRuntimeContext() {
return getCurrentRuntime().getRuntimeContext();
}
@Override
public RayObject callPy(String moduleName, String functionName, Object[] args,
CallOptions options) {
return getCurrentRuntime().callPy(moduleName, functionName, args, options);
}
@Override
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args) {
return getCurrentRuntime().callPyActor(pyActor, functionName, args);
}
@Override
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
ActorCreationOptions options) {
return getCurrentRuntime().createPyActor(moduleName, className, args, options);
}
@Override
public Object getAsyncContext() {
return getCurrentRuntime();
@@ -23,6 +23,8 @@ public class RayCallGenerator extends BaseGenerator {
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.function.PyActorClass;");
newLine("import org.ray.api.function.PyRemoteFunction;");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
newLine("import org.ray.api.function.RayFunc" + i + ";");
}
@@ -112,6 +114,8 @@ public class RayCallGenerator extends BaseGenerator {
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.function.PyActorMethod;");
newLine("");
newLine("/**");
newLine(" * This class provides type-safe interfaces for remote actor calls.");
newLine(" **/");
@@ -216,11 +220,11 @@ public class RayCallGenerator extends BaseGenerator {
newLine(2, String.format("Object[] args = new Object[]{%s};", args));
// 5) Construct the third line.
String callFuncArgs = "f, ";
String callFuncArgs = "";
if (forActor) {
callFuncArgs += "(RayActor) this, ";
}
callFuncArgs += "args, ";
callFuncArgs += "f, args, ";
callFuncArgs += forActor ? "" : hasOptionsParam ? "options, " : "null, ";
callFuncArgs = callFuncArgs.substring(0, callFuncArgs.length() - 2);
newLine(2, String.format("%sRay.internal().%s(%s);",
@@ -231,12 +235,12 @@ public class RayCallGenerator extends BaseGenerator {
}
/**
* Build `Ray.callPy`, `Ray.createPyActor` and `actor.call` methods with
* Build `Ray.call`, `Ray.createActor` and `actor.call` methods with
* the given number of parameters.
*
* @param numParameters the number of parameters
* @param forActor Build `actor.call` when true, otherwise build `Ray.callPy`.
* @param forActorCreation Build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
* @param forActor Build `actor.call` when true, otherwise build `Ray.call`.
* @param forActorCreation Build `Ray.createActor` when true, otherwise build `Ray.call`.
* @param hasOptionsParam Add ActorCreationOptions if forActorCreation is true;
* Add CallOptions if forActorCreation is false;
* No additional param if hasOptionsParam is false.
@@ -261,14 +265,14 @@ public class RayCallGenerator extends BaseGenerator {
String paramPrefix = "";
String funcArgs = "";
if (forActorCreation) {
paramPrefix += "String moduleName, String className";
funcArgs += "moduleName, className";
paramPrefix += "PyActorClass pyActorClass";
funcArgs += "pyActorClass";
} else if (forActor) {
paramPrefix += "String functionName";
funcArgs += "functionName";
paramPrefix += "PyActorMethod<R> pyActorMethod";
funcArgs += "pyActorMethod";
} else {
paramPrefix += "String moduleName, String functionName";
funcArgs += "moduleName, functionName";
paramPrefix += "PyRemoteFunction<R> pyRemoteFunction";
funcArgs += "pyRemoteFunction";
}
if (numParameters > 0) {
paramPrefix += ", ";
@@ -292,14 +296,15 @@ public class RayCallGenerator extends BaseGenerator {
}
}
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
String funcName = forActorCreation ? "createPyActor" : forActor ? "call" : "callPy";
String internalCallFunc = forActorCreation ? "createPyActor" :
forActor ? "callPyActor" : "callPy";
String genericType = forActorCreation ? "" : " <R>";
String returnType = !forActorCreation ? "RayObject<R>" : "RayPyActor";
String funcName = forActorCreation ? "createActor" : "call";
String internalCallFunc = forActorCreation ? "createActor" :
forActor ? "callActor" : "call";
funcArgs += ", args";
// Method signature.
newLine(1, String.format(
"%s %s %s(%s%s) {", modifiers,
"%s%s %s %s(%s%s) {", modifiers, genericType,
returnType, funcName, paramPrefix + paramList, optionsParam
));
// Method body.