mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 05:52:36 +08:00
Java call Python use structured function descriptors (#7634)
This commit is contained in:
@@ -11,6 +11,9 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.function.PyActorClass;
|
||||
import org.ray.api.function.PyActorMethod;
|
||||
import org.ray.api.function.PyRemoteFunction;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
@@ -98,7 +101,19 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayFunc func, RayActor<?> actor, Object[] args) {
|
||||
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
|
||||
pyRemoteFunction.moduleName,
|
||||
"",
|
||||
pyRemoteFunction.functionName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
@@ -106,6 +121,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return callActorFunction(actor, functionDescriptor, args, numReturns);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), pyActorMethod.methodName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
|
||||
@@ -116,6 +140,17 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return (RayActor<T>) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
|
||||
pyActorClass.moduleName,
|
||||
pyActorClass.className,
|
||||
PYTHON_INIT_METHOD_NAME);
|
||||
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
private void checkPyArguments(Object[] args) {
|
||||
for (Object arg : args) {
|
||||
Preconditions.checkArgument(
|
||||
@@ -125,34 +160,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(String moduleName, String functionName, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, "",
|
||||
functionName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), functionName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, className,
|
||||
PYTHON_INIT_METHOD_NAME);
|
||||
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
return runnable;
|
||||
|
||||
@@ -8,6 +8,9 @@ import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.function.PyActorClass;
|
||||
import org.ray.api.function.PyActorMethod;
|
||||
import org.ray.api.function.PyRemoteFunction;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
@@ -150,8 +153,19 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayFunc func, RayActor<?> actor, Object[] args) {
|
||||
return getCurrentRuntime().callActor(func, actor, args);
|
||||
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
|
||||
CallOptions options) {
|
||||
return getCurrentRuntime().call(pyRemoteFunction, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
|
||||
return getCurrentRuntime().callActor(actor, func, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) {
|
||||
return getCurrentRuntime().callActor(pyActor, pyActorMethod, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -160,28 +174,17 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
|
||||
return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createActor(pyActorClass, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RuntimeContext getRuntimeContext() {
|
||||
return getCurrentRuntime().getRuntimeContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(String moduleName, String functionName, Object[] args,
|
||||
CallOptions options) {
|
||||
return getCurrentRuntime().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPyActor(RayPyActor pyActor, String functionName, Object[] args) {
|
||||
return getCurrentRuntime().callPyActor(pyActor, functionName, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return getCurrentRuntime();
|
||||
|
||||
@@ -23,6 +23,8 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("");
|
||||
newLine("package org.ray.api;");
|
||||
newLine("");
|
||||
newLine("import org.ray.api.function.PyActorClass;");
|
||||
newLine("import org.ray.api.function.PyRemoteFunction;");
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
newLine("import org.ray.api.function.RayFunc" + i + ";");
|
||||
}
|
||||
@@ -112,6 +114,8 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("");
|
||||
newLine("package org.ray.api;");
|
||||
newLine("");
|
||||
newLine("import org.ray.api.function.PyActorMethod;");
|
||||
newLine("");
|
||||
newLine("/**");
|
||||
newLine(" * This class provides type-safe interfaces for remote actor calls.");
|
||||
newLine(" **/");
|
||||
@@ -216,11 +220,11 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", args));
|
||||
|
||||
// 5) Construct the third line.
|
||||
String callFuncArgs = "f, ";
|
||||
String callFuncArgs = "";
|
||||
if (forActor) {
|
||||
callFuncArgs += "(RayActor) this, ";
|
||||
}
|
||||
callFuncArgs += "args, ";
|
||||
callFuncArgs += "f, args, ";
|
||||
callFuncArgs += forActor ? "" : hasOptionsParam ? "options, " : "null, ";
|
||||
callFuncArgs = callFuncArgs.substring(0, callFuncArgs.length() - 2);
|
||||
newLine(2, String.format("%sRay.internal().%s(%s);",
|
||||
@@ -231,12 +235,12 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
}
|
||||
|
||||
/**
|
||||
* Build `Ray.callPy`, `Ray.createPyActor` and `actor.call` methods with
|
||||
* Build `Ray.call`, `Ray.createActor` and `actor.call` methods with
|
||||
* the given number of parameters.
|
||||
*
|
||||
* @param numParameters the number of parameters
|
||||
* @param forActor Build `actor.call` when true, otherwise build `Ray.callPy`.
|
||||
* @param forActorCreation Build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
|
||||
* @param forActor Build `actor.call` when true, otherwise build `Ray.call`.
|
||||
* @param forActorCreation Build `Ray.createActor` when true, otherwise build `Ray.call`.
|
||||
* @param hasOptionsParam Add ActorCreationOptions if forActorCreation is true;
|
||||
* Add CallOptions if forActorCreation is false;
|
||||
* No additional param if hasOptionsParam is false.
|
||||
@@ -261,14 +265,14 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
String paramPrefix = "";
|
||||
String funcArgs = "";
|
||||
if (forActorCreation) {
|
||||
paramPrefix += "String moduleName, String className";
|
||||
funcArgs += "moduleName, className";
|
||||
paramPrefix += "PyActorClass pyActorClass";
|
||||
funcArgs += "pyActorClass";
|
||||
} else if (forActor) {
|
||||
paramPrefix += "String functionName";
|
||||
funcArgs += "functionName";
|
||||
paramPrefix += "PyActorMethod<R> pyActorMethod";
|
||||
funcArgs += "pyActorMethod";
|
||||
} else {
|
||||
paramPrefix += "String moduleName, String functionName";
|
||||
funcArgs += "moduleName, functionName";
|
||||
paramPrefix += "PyRemoteFunction<R> pyRemoteFunction";
|
||||
funcArgs += "pyRemoteFunction";
|
||||
}
|
||||
if (numParameters > 0) {
|
||||
paramPrefix += ", ";
|
||||
@@ -292,14 +296,15 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
|
||||
String funcName = forActorCreation ? "createPyActor" : forActor ? "call" : "callPy";
|
||||
String internalCallFunc = forActorCreation ? "createPyActor" :
|
||||
forActor ? "callPyActor" : "callPy";
|
||||
String genericType = forActorCreation ? "" : " <R>";
|
||||
String returnType = !forActorCreation ? "RayObject<R>" : "RayPyActor";
|
||||
String funcName = forActorCreation ? "createActor" : "call";
|
||||
String internalCallFunc = forActorCreation ? "createActor" :
|
||||
forActor ? "callActor" : "call";
|
||||
funcArgs += ", args";
|
||||
// Method signature.
|
||||
newLine(1, String.format(
|
||||
"%s %s %s(%s%s) {", modifiers,
|
||||
"%s%s %s %s(%s%s) {", modifiers, genericType,
|
||||
returnType, funcName, paramPrefix + paramList, optionsParam
|
||||
));
|
||||
// Method body.
|
||||
|
||||
Reference in New Issue
Block a user