[Java] Refine python function (#8943)

This commit is contained in:
chaokunyang
2020-06-16 16:22:49 +08:00
committed by GitHub
parent 14405b90d5
commit cb6f337372
13 changed files with 164 additions and 103 deletions
+15 -15
View File
@@ -8,7 +8,7 @@ import io.ray.api.call.PyTaskCaller;
import io.ray.api.call.TaskCaller;
import io.ray.api.call.VoidTaskCaller;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyRemoteFunction;
import io.ray.api.function.PyFunction;
import io.ray.api.function.RayFunc0;
import io.ray.api.function.RayFunc1;
import io.ray.api.function.RayFunc2;
@@ -1942,39 +1942,39 @@ class RayCall {
// ===========================
// Cross-language methods.
// ===========================
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction) {
Object[] args = new Object[]{};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction, Object obj0) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction, Object obj0) {
Object[] args = new Object[]{obj0};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction, Object obj0, Object obj1) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction, Object obj0, Object obj1, Object obj2) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction, Object obj0, Object obj1, Object obj2, Object obj3) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static <R> PyTaskCaller<R> task(PyRemoteFunction<R> pyRemoteFunction, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
public static <R> PyTaskCaller<R> task(PyFunction<R> pyFunction, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return new PyTaskCaller<>(pyRemoteFunction, args);
return new PyTaskCaller<>(pyFunction, args);
}
public static PyActorCreator actor(PyActorClass pyActorClass) {
@@ -2,7 +2,7 @@ package io.ray.api.call;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyRemoteFunction;
import io.ray.api.function.PyFunction;
/**
* A helper to call python remote function.
@@ -10,10 +10,10 @@ import io.ray.api.function.PyRemoteFunction;
* @param <R> The type of the python function return value
*/
public class PyTaskCaller<R> extends BaseTaskCaller<PyTaskCaller<R>> {
private final PyRemoteFunction<R> func;
private final PyFunction<R> func;
private final Object[] args;
public PyTaskCaller(PyRemoteFunction<R> func, Object[] args) {
public PyTaskCaller(PyFunction<R> func, Object[] args) {
this.func = func;
this.args = args;
}
@@ -18,7 +18,7 @@ package io.ray.api.function;
* we can create this Python actor from Java:
*
* {@code
* PyActorHandle actor = Ray.createActor(new PyActorClass("example_package.example_module", "A"),
* PyActorHandle actor = Ray.createActor(PyActorClass.of("example_package.example_module", "A"),
* "the value for x");
* }
* </pre>
@@ -29,8 +29,20 @@ public class PyActorClass {
// The name of this actor class
public final String className;
public PyActorClass(String moduleName, String className) {
private PyActorClass(String moduleName, String className) {
this.moduleName = moduleName;
this.className = className;
}
/**
* Create a python actor class.
*
* @param moduleName The full module name of this actor class
* @param className The name of this actor class
* @return a python actor class
*/
public static PyActorClass of(String moduleName, String className) {
return new PyActorClass(moduleName, className);
}
}
@@ -1,8 +1,8 @@
package io.ray.api.function;
/**
* A class that represents a method of a Python actor.
*
* A class that represents a method of a Python actor.
* <p>
* Note, information about the actor will be inferred from the actor handle,
* so it's not specified in this class.
*
@@ -24,7 +24,7 @@ package io.ray.api.function;
*
* {@code
* // A.foo returns a string, so we have to set the returnType to String.class
* ObjectRef<String> res = actor.call(new PyActorMethod<>("foo", String.class));
* ObjectRef<String> res = actor.call(PyActorMethod.of("foo", String.class));
* String x = res.get();
* }
* </pre>
@@ -35,8 +35,31 @@ public class PyActorMethod<R> {
// Type of the return value of this actor method
public final Class<R> returnType;
public PyActorMethod(String methodName, Class<R> returnType) {
private PyActorMethod(String methodName, Class<R> returnType) {
this.methodName = methodName;
this.returnType = returnType;
}
/**
* Create a python actor method.
*
* @param methodName The name of this actor method
* @return a python actor method.
*/
public static PyActorMethod<Object> of(String methodName) {
return of(methodName, Object.class);
}
/**
* Create a python actor method.
*
* @param methodName The name of this actor method
* @param returnType Class of the return value of this actor method
* @param <R> The type of the return value of this actor method
* @return a python actor method.
*/
public static <R> PyActorMethod<R> of(String methodName, Class<R> returnType) {
return new PyActorMethod<>(methodName, returnType);
}
}
@@ -0,0 +1,74 @@
package io.ray.api.function;
/**
* A class that represents a Python remote function.
*
* <pre>
* example_package/
* ├──__init__.py
* └──example_module.py
*
* in example_module.py there is a function.
*
* \@ray.remote
* def bar(v):
* return v
*
* then we can call the Python function bar:
*
* {@code
* // bar returns input, so we have to set the returnType to int.class if bar accepts an int
* ObjectRef<Integer> res = actor.call(
* PyFunction.of("example_package.example_module", "bar", Integer.class),
* 1);
* Integer value = res.get();
*
* // bar returns input, so we have to set the returnType to String.class if bar accepts a string
* ObjectRef<String> res = actor.call(
* PyFunction.of("example_package.example_module", "bar", String.class),
* "Hello world!");
* String value = res.get();
* }
* </pre>
*/
public class PyFunction<R> {
// The full module name of this function
public final String moduleName;
// The name of this function
public final String functionName;
// Type of the return value of this function
public final Class<R> returnType;
private PyFunction(String moduleName, String functionName, Class<R> returnType) {
this.moduleName = moduleName;
this.functionName = functionName;
this.returnType = returnType;
}
/**
* Create a python function.
*
* @param moduleName The full module name of this function
* @param functionName The name of this function
* @return a python function.
*/
public static PyFunction<Object> of(
String moduleName, String functionName) {
return of(moduleName, functionName, Object.class);
}
/**
* Create a python function.
*
* @param moduleName The full module name of this function
* @param functionName The name of this function
* @param returnType Class of the return value of this function
* @param <R> Type of the return value of this function
* @return a python function.
*/
public static <R> PyFunction<R> of(
String moduleName, String functionName, Class<R> returnType) {
return new PyFunction<>(moduleName, functionName, returnType);
}
}
@@ -1,47 +0,0 @@
package io.ray.api.function;
/**
* A class that represents a Python remote function.
*
* <pre>
* example_package/
* ├──__init__.py
* └──example_module.py
*
* in example_module.py there is a function.
*
* \@ray.remote
* def bar(v):
* return v
*
* then we can call the Python function bar:
*
* {@code
* // bar returns input, so we have to set the returnType to int.class if bar accepts an int
* ObjectRef<Integer> res = actor.call(
* new PyRemoteFunction<>("example_package.example_module", "bar", Integer.class),
* 1);
* Integer value = res.get();
*
* // bar returns input, so we have to set the returnType to String.class if bar accepts a string
* ObjectRef<String> res = actor.call(
* new PyRemoteFunction<>("example_package.example_module", "bar", String.class),
* "Hello world!");
* String value = res.get();
* }
* </pre>
*/
public class PyRemoteFunction<R> {
// The full module name of this function
public final String moduleName;
// The name of this function
public final String functionName;
// Type of the return value of this function
public final Class<R> returnType;
public PyRemoteFunction(String moduleName, String functionName, Class<R> returnType) {
this.moduleName = moduleName;
this.functionName = functionName;
this.returnType = returnType;
}
}
@@ -7,7 +7,7 @@ import io.ray.api.PyActorHandle;
import io.ray.api.WaitResult;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyRemoteFunction;
import io.ray.api.function.PyFunction;
import io.ray.api.function.RayFunc;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
@@ -103,12 +103,12 @@ public interface RayRuntime {
/**
* Invoke a remote Python function.
*
* @param pyRemoteFunction The Python function.
* @param pyFunction The Python function.
* @param args Arguments of the function.
* @param options The options for this call.
* @return The result object.
*/
ObjectRef call(PyRemoteFunction pyRemoteFunction, Object[] args, CallOptions options);
ObjectRef call(PyFunction pyFunction, Object[] args, CallOptions options);
/**
* Invoke a remote function on an actor.
@@ -11,7 +11,7 @@ import io.ray.api.WaitResult;
import io.ray.api.exception.RayException;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyRemoteFunction;
import io.ray.api.function.PyFunction;
import io.ray.api.function.RayFunc;
import io.ray.api.id.ObjectId;
import io.ray.api.options.ActorCreationOptions;
@@ -107,15 +107,14 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
@Override
public ObjectRef call(PyRemoteFunction pyRemoteFunction, Object[] args,
CallOptions options) {
public ObjectRef call(PyFunction pyFunction, Object[] args, CallOptions options) {
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyRemoteFunction.moduleName,
pyFunction.moduleName,
"",
pyRemoteFunction.functionName);
pyFunction.functionName);
// Python functions always have a return value, even if it's `None`.
return callNormalFunction(functionDescriptor, args,
/*returnType=*/Optional.of(pyRemoteFunction.returnType), options);
/*returnType=*/Optional.of(pyFunction.returnType), options);
}
@Override
@@ -29,7 +29,7 @@ public class RayCallGenerator extends BaseGenerator {
newLine("import io.ray.api.call.TaskCaller;");
newLine("import io.ray.api.call.VoidTaskCaller;");
newLine("import io.ray.api.function.PyActorClass;");
newLine("import io.ray.api.function.PyRemoteFunction;");
newLine("import io.ray.api.function.PyFunction;");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
newLine("import io.ray.api.function.RayFunc" + i + ";");
}
@@ -273,8 +273,8 @@ public class RayCallGenerator extends BaseGenerator {
paramPrefix += "PyActorMethod<R> pyActorMethod";
funcArgs += "pyActorMethod";
} else {
paramPrefix += "PyRemoteFunction<R> pyRemoteFunction";
funcArgs += "pyRemoteFunction";
paramPrefix += "PyFunction<R> pyFunction";
funcArgs += "pyFunction";
}
if (numParameters > 0) {
paramPrefix += ", ";
@@ -8,7 +8,7 @@ import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyRemoteFunction;
import io.ray.api.function.PyFunction;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.actor.NativePyActorHandle;
import java.io.File;
@@ -70,7 +70,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
"example binary".getBytes()}; // byte[]
for (Object o : inputs) {
ObjectRef res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()),
PyFunction.of(PYTHON_MODULE, "py_return_input", o.getClass()),
o).remote();
Assert.assertEquals(res.get(), o);
}
@@ -78,7 +78,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
{
Object input = null;
ObjectRef<Object> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input).remote();
PyFunction.of(PYTHON_MODULE, "py_return_input", Object.class), input).remote();
Object r = res.get();
Assert.assertEquals(r, input);
}
@@ -86,7 +86,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
{
int[] input = new int[]{1, 2};
ObjectRef<int[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input).remote();
PyFunction.of(PYTHON_MODULE, "py_return_input", int[].class), input).remote();
int[] r = res.get();
Assert.assertEquals(r, input);
}
@@ -95,7 +95,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true,
new int[]{1, 2}};
ObjectRef<Object[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input).remote();
PyFunction.of(PYTHON_MODULE, "py_return_input", Object[].class), input).remote();
Object[] r = res.get();
// If we tell the value type is Object, then all numbers will be Number type.
Assert.assertEquals(((Number) r[0]).intValue(), input[0]);
@@ -120,7 +120,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Assert.expectThrows(Exception.class, () -> {
List<Integer> input = Arrays.asList(1, 2);
ObjectRef<List<Integer>> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input",
PyFunction.of(PYTHON_MODULE, "py_return_input",
(Class<List<Integer>>) input.getClass()), input).remote();
List<Integer> r = res.get();
Assert.assertEquals(r, input);
@@ -130,7 +130,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testPythonCallJavaFunction() {
ObjectRef<String> res = Ray.task(new PyRemoteFunction<>(
ObjectRef<String> res = Ray.task(PyFunction.of(
PYTHON_MODULE, "py_func_call_java_function", String.class)).remote();
Assert.assertEquals(res.get(), "success");
}
@@ -138,9 +138,9 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testCallingPythonActor() {
PyActorHandle actor = Ray.actor(
new PyActorClass(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
PyActorClass.of(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
ObjectRef<byte[]> res = actor.task(
new PyActorMethod<>("increase", byte[].class),
PyActorMethod.of("increase", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "2".getBytes());
}
@@ -148,7 +148,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testPythonCallJavaActor() {
ObjectRef<byte[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_actor", byte[].class),
PyFunction.of(PYTHON_MODULE, "py_func_call_java_actor", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "Counter1".getBytes());
@@ -158,7 +158,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
public void testPassActorHandleFromPythonToJava() {
// Call a python function which creates a python actor
// and pass the actor handle to callPythonActorHandle.
ObjectRef<byte[]> res = Ray.task(new PyRemoteFunction<>(
ObjectRef<byte[]> res = Ray.task(PyFunction.of(
PYTHON_MODULE, "py_func_pass_python_actor_handle", byte[].class)).remote();
Assert.assertEquals(res.get(), "3".getBytes());
}
@@ -170,18 +170,18 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Preconditions.checkState(javaActor instanceof NativeActorHandle);
byte[] actorHandleBytes = ((NativeActorHandle) javaActor).toBytes();
ObjectRef<byte[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE,
PyFunction.of(PYTHON_MODULE,
"py_func_call_java_actor_from_handle",
byte[].class),
actorHandleBytes).remote();
Assert.assertEquals(res.get(), "12".getBytes());
// Create a python actor, and pass actor handle to python.
PyActorHandle pyActor = Ray.actor(
new PyActorClass(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
PyActorClass.of(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
Preconditions.checkState(pyActor instanceof NativeActorHandle);
actorHandleBytes = ((NativeActorHandle) pyActor).toBytes();
res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE,
PyFunction.of(PYTHON_MODULE,
"py_func_call_python_actor_from_handle",
byte[].class),
actorHandleBytes).remote();
@@ -221,7 +221,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
// This function will be called from test_cross_language_invocation.py
NativePyActorHandle actor = (NativePyActorHandle) NativeActorHandle.fromBytes(value);
ObjectRef<byte[]> res = actor.task(
new PyActorMethod<>("increase", byte[].class),
PyActorMethod.of("increase", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "3".getBytes());
return (byte[]) res.get();
@@ -13,7 +13,7 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
PyActorHandle pyActor = Ray.actor(
new PyActorClass("test", "RaySerializerTest")).remote();
PyActorClass.of("test", "RaySerializerTest")).remote();
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
PyActorHandle result = (PyActorHandle) ObjectSerializer
.deserialize(nativeRayObject, null, Object.class);