[Java] Refine python function (#8943)

This commit is contained in:
chaokunyang
2020-06-16 16:22:49 +08:00
committed by GitHub
parent 14405b90d5
commit cb6f337372
13 changed files with 164 additions and 103 deletions
@@ -8,7 +8,7 @@ import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyRemoteFunction;
import io.ray.api.function.PyFunction;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.actor.NativePyActorHandle;
import java.io.File;
@@ -70,7 +70,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
"example binary".getBytes()}; // byte[]
for (Object o : inputs) {
ObjectRef res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()),
PyFunction.of(PYTHON_MODULE, "py_return_input", o.getClass()),
o).remote();
Assert.assertEquals(res.get(), o);
}
@@ -78,7 +78,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
{
Object input = null;
ObjectRef<Object> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input).remote();
PyFunction.of(PYTHON_MODULE, "py_return_input", Object.class), input).remote();
Object r = res.get();
Assert.assertEquals(r, input);
}
@@ -86,7 +86,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
{
int[] input = new int[]{1, 2};
ObjectRef<int[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input).remote();
PyFunction.of(PYTHON_MODULE, "py_return_input", int[].class), input).remote();
int[] r = res.get();
Assert.assertEquals(r, input);
}
@@ -95,7 +95,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true,
new int[]{1, 2}};
ObjectRef<Object[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input).remote();
PyFunction.of(PYTHON_MODULE, "py_return_input", Object[].class), input).remote();
Object[] r = res.get();
// If we tell the value type is Object, then all numbers will be Number type.
Assert.assertEquals(((Number) r[0]).intValue(), input[0]);
@@ -120,7 +120,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Assert.expectThrows(Exception.class, () -> {
List<Integer> input = Arrays.asList(1, 2);
ObjectRef<List<Integer>> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input",
PyFunction.of(PYTHON_MODULE, "py_return_input",
(Class<List<Integer>>) input.getClass()), input).remote();
List<Integer> r = res.get();
Assert.assertEquals(r, input);
@@ -130,7 +130,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testPythonCallJavaFunction() {
ObjectRef<String> res = Ray.task(new PyRemoteFunction<>(
ObjectRef<String> res = Ray.task(PyFunction.of(
PYTHON_MODULE, "py_func_call_java_function", String.class)).remote();
Assert.assertEquals(res.get(), "success");
}
@@ -138,9 +138,9 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testCallingPythonActor() {
PyActorHandle actor = Ray.actor(
new PyActorClass(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
PyActorClass.of(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
ObjectRef<byte[]> res = actor.task(
new PyActorMethod<>("increase", byte[].class),
PyActorMethod.of("increase", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "2".getBytes());
}
@@ -148,7 +148,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testPythonCallJavaActor() {
ObjectRef<byte[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_actor", byte[].class),
PyFunction.of(PYTHON_MODULE, "py_func_call_java_actor", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "Counter1".getBytes());
@@ -158,7 +158,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
public void testPassActorHandleFromPythonToJava() {
// Call a python function which creates a python actor
// and pass the actor handle to callPythonActorHandle.
ObjectRef<byte[]> res = Ray.task(new PyRemoteFunction<>(
ObjectRef<byte[]> res = Ray.task(PyFunction.of(
PYTHON_MODULE, "py_func_pass_python_actor_handle", byte[].class)).remote();
Assert.assertEquals(res.get(), "3".getBytes());
}
@@ -170,18 +170,18 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Preconditions.checkState(javaActor instanceof NativeActorHandle);
byte[] actorHandleBytes = ((NativeActorHandle) javaActor).toBytes();
ObjectRef<byte[]> res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE,
PyFunction.of(PYTHON_MODULE,
"py_func_call_java_actor_from_handle",
byte[].class),
actorHandleBytes).remote();
Assert.assertEquals(res.get(), "12".getBytes());
// Create a python actor, and pass actor handle to python.
PyActorHandle pyActor = Ray.actor(
new PyActorClass(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
PyActorClass.of(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
Preconditions.checkState(pyActor instanceof NativeActorHandle);
actorHandleBytes = ((NativeActorHandle) pyActor).toBytes();
res = Ray.task(
new PyRemoteFunction<>(PYTHON_MODULE,
PyFunction.of(PYTHON_MODULE,
"py_func_call_python_actor_from_handle",
byte[].class),
actorHandleBytes).remote();
@@ -221,7 +221,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
// This function will be called from test_cross_language_invocation.py
NativePyActorHandle actor = (NativePyActorHandle) NativeActorHandle.fromBytes(value);
ObjectRef<byte[]> res = actor.task(
new PyActorMethod<>("increase", byte[].class),
PyActorMethod.of("increase", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "3".getBytes());
return (byte[]) res.get();
@@ -13,7 +13,7 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
PyActorHandle pyActor = Ray.actor(
new PyActorClass("test", "RaySerializerTest")).remote();
PyActorClass.of("test", "RaySerializerTest")).remote();
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
PyActorHandle result = (PyActorHandle) ObjectSerializer
.deserialize(nativeRayObject, null, Object.class);