Cross language serialization for primitive types (#7711)

* Cross language serialization for Java and Python

* Use strict types when serializing in Python

* Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0

* Disable gc for optimizing msgpack loads

* Fix merge bug

* Use returnType when Java calls Python; Fix ClassLoaderTest

* Fix RayMethodsTest

* Fix checkstyle

* Fix lint

* prepare_args raises an exception when trying to transfer a non-deserializable object to another language

* Fix CrossLanguageInvocationTest.java; Python msgpack treats float as double

* Minor fixes

* Fix compile error on Linux

* Fix lint in java/BUILD.bazel

* Fix test_failure

* Fix lint

* Class<?> to Class<T>; Refine metadata bytes.

* Rename FST to Fst; sort java dependencies

* Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py

* Improve CrossLanguageInvocationTest

* Refactor MessagePackSerializer.java

* Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java

* Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java

* Fix bug

* Remove custom cross language type support

* Replace Serializer.Meta with MutableBoolean

* Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java

* Refine MessagePackSerializer.pack

* Ray.get supports RayObject as input

* Improve comments and error info

* Remove classLoader argument from serializer

* Separate msgpack from pickle5 in Python

* Pair<byte[], MutableBoolean> to Pair<byte[], Boolean>

* Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead

* Refine test

* small fixes

Co-authored-by: 刘宝 <po.lb@antfin.com>
Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
fyrestone
2020-04-08 21:10:57 +08:00
committed by GitHub
parent e8c19aba41
commit fc6259a656
42 changed files with 1057 additions and 313 deletions
@@ -143,7 +143,7 @@ public class ActorTest extends BaseTest {
try {
// Try getting the object again, this should throw an UnreconstructableException.
// Use `Ray.get()` to bypass the cache in `RayObjectImpl`.
Ray.get(value.getId());
Ray.get(value.getId(), value.getType());
Assert.fail("This line should not be reachable.");
} catch (UnreconstructableException e) {
Assert.assertEquals(value.getId(), e.objectId);
@@ -4,6 +4,7 @@ import java.io.File;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Optional;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import org.apache.commons.io.FileUtils;
@@ -101,12 +102,14 @@ public class ClassLoaderTest extends BaseTest {
"()V");
RayActor<?> actor1 = createActor(constructor);
FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I");
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0], 1).get();
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0],
Optional.of(Integer.class)).get();
RayActor<?> actor2;
while (true) {
// Create another actor which share the same process of actor 1.
actor2 = createActor(constructor);
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0], 1).get();
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0],
Optional.of(Integer.class)).get();
if (actor2Pid == pid) {
break;
}
@@ -116,15 +119,17 @@ public class ClassLoaderTest extends BaseTest {
"getClassLoaderHashCode",
"()I");
RayObject<Integer> hashCode1 = callActorFunction(actor1, getClassLoaderHashCode, new Object[0],
1);
Optional.of(Integer.class));
RayObject<Integer> hashCode2 = callActorFunction(actor2, getClassLoaderHashCode, new Object[0],
1);
Optional.of(Integer.class));
Assert.assertEquals(hashCode1.get(), hashCode2.get());
FunctionDescriptor increase = new JavaFunctionDescriptor("ClassLoaderTester", "increase",
"()I");
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0], 1);
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0], 1);
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0],
Optional.of(Integer.class));
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0],
Optional.of(Integer.class));
Assert.assertNotEquals(value1.get(), value2.get());
}
@@ -138,11 +143,12 @@ public class ClassLoaderTest extends BaseTest {
}
/**
 * Calls the {@code callActorFunction} method declared on {@code AbstractRayRuntime} through
 * reflection, using the runtime returned by {@code TestUtils.getUnderlyingRuntime()}.
 *
 * @param rayActor the actor handle passed as the first argument of the reflective call
 * @param functionDescriptor descriptor of the actor method to call
 * @param args arguments forwarded to the actor method
 * @return the value returned by the reflective call, cast (unchecked) to {@code RayObject<T>}
 * @throws Exception if the method lookup or the reflective invocation fails
 */
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) throws Exception {
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType)
throws Exception {
// The method is looked up by name and parameter types on AbstractRayRuntime itself.
Method callActorFunctionMethod = AbstractRayRuntime.class.getDeclaredMethod("callActorFunction",
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
BaseActor.class, FunctionDescriptor.class, Object[].class, Optional.class);
// Make the method callable from this test class before invoking it.
callActorFunctionMethod.setAccessible(true);
return (RayObject<T>) callActorFunctionMethod
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns);
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, returnType);
}
}
@@ -23,7 +23,7 @@ public class ClientExceptionTest extends BaseTest {
public void testWaitAndCrash() {
TestUtils.skipTestUnderSingleProcess();
ObjectId randomId = ObjectId.fromRandom();
RayObject<String> notExisting = new RayObjectImpl(randomId);
RayObject<String> notExisting = new RayObjectImpl(randomId, String.class);
Thread thread = new Thread(() -> {
try {
@@ -5,6 +5,9 @@ import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.Ray;
@@ -51,18 +54,85 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
/**
 * Sends values of several types to a Python remote function and checks the value that comes
 * back; also checks that passing a Java {@code List} throws.
 */
@Test
public void testCallingPythonFunction() {
RayObject<byte[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func", byte[].class),
"hello".getBytes());
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
Object[] inputs = new Object[]{
true, // Boolean
Byte.MAX_VALUE, // Byte
Short.MAX_VALUE, // Short
Integer.MAX_VALUE, // Integer
Long.MAX_VALUE, // Long
// A BigInteger of up to 2^64-1 is supported; please refer to:
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
// A BigInteger larger than 2^64-1 can only be transferred among Java workers.
BigInteger.valueOf(Long.MAX_VALUE), // BigInteger
"Hello World!", // String
1.234f, // Float
1.234, // Double
"example binary".getBytes()}; // byte[]
// Each input is sent with its own runtime class as the declared return type.
for (Object o : inputs) {
RayObject res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()),
o);
Assert.assertEquals(res.get(), o);
}
// null
{
Object input = null;
RayObject<Object> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input);
Object r = res.get();
Assert.assertEquals(r, input);
}
// array
{
int[] input = new int[]{1, 2};
RayObject<int[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input);
int[] r = res.get();
Assert.assertEquals(r, input);
}
// array of Object
{
Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true,
new int[]{1, 2}};
RayObject<Object[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input);
Object[] r = res.get();
// If the value type is declared as Object, all numbers are returned as Number.
Assert.assertEquals(((Number) r[0]).intValue(), input[0]);
Assert.assertEquals(((Number) r[1]).floatValue(), input[1]);
Assert.assertEquals(((Number) r[2]).doubleValue(), input[2]);
// String cast
Assert.assertEquals((String) r[3], input[3]);
// binary cast
Assert.assertEquals((byte[]) r[4], input[4]);
// null
Assert.assertEquals(r[5], input[5]);
// Boolean cast
Assert.assertEquals((Boolean) r[6], input[6]);
// array cast
Object[] r7array = (Object[]) r[7];
int[] input7array = (int[]) input[7];
Assert.assertEquals(((Number) r7array[0]).intValue(), input7array[0]);
Assert.assertEquals(((Number) r7array[1]).intValue(), input7array[1]);
}
// Unsupported types: all Java-specific types, e.g. List / Map...
{
// The call is expected to throw, so the assertion inside the lambda is not expected to run.
Assert.expectThrows(Exception.class, () -> {
List<Integer> input = Arrays.asList(1, 2);
RayObject<List<Integer>> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input",
(Class<List<Integer>>) input.getClass()), input);
List<Integer> r = res.get();
Assert.assertEquals(r, input);
});
}
}
/**
 * Calls the Python remote function {@code py_func_call_java_function} and checks the value it
 * returns.
 */
@Test
public void testPythonCallJavaFunction() {
RayObject<byte[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", byte[].class),
"hello".getBytes());
Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes());
RayObject<String> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", String.class));
// The Python function returns the literal "success" only when all of its own checks pass.
Assert.assertEquals(res.get(), "success");
}
@Test
@@ -117,11 +187,33 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Assert.assertEquals(res.get(), "3".getBytes());
}
public static byte[] bytesEcho(byte[] value) {
/**
 * Bundles the four arguments, in declaration order, into a new {@code Object[]}.
 *
 * @return a four-element array {@code {i, s, f, o}}
 */
public static Object[] pack(int i, String s, double f, Object[] o) {
// This function will be called from test_cross_language_invocation.py
String valueStr = new String(value);
LOGGER.debug(String.format("bytesEcho called with: %s", valueStr));
return ("[Java]bytesEcho -> " + valueStr).getBytes();
return new Object[]{i, s, f, o};
}
/** Returns the given object unchanged; looked up by name from the Python test function. */
public static Object returnInput(Object o) {
return o;
}
/** Returns the given boolean unchanged; looked up by name from the Python test function. */
public static boolean returnInputBoolean(boolean b) {
return b;
}
/** Returns the given int unchanged; looked up by name from the Python test function. */
public static int returnInputInt(int i) {
return i;
}
/** Returns the given double unchanged; looked up by name from the Python test function. */
public static double returnInputDouble(double d) {
return d;
}
/** Returns the given String unchanged; looked up by name from the Python test function. */
public static String returnInputString(String s) {
return s;
}
/** Returns the given int array unchanged; looked up by name from the Python test function. */
public static int[] returnInputIntArray(int[] l) {
return l;
}
public static byte[] callPythonActorHandle(byte[] value) {
@@ -135,6 +227,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
}
public static class TestActor {
public TestActor(byte[] v) {
value = v;
}
@@ -45,7 +45,7 @@ public class DynamicResourceTest extends BaseTest {
// Assert ray call result.
result = Ray.wait(ImmutableList.of(obj), 1, 1000);
Assert.assertEquals(result.getReady().size(), 1);
Assert.assertEquals(Ray.get(obj.getId()), "hi");
Assert.assertEquals(obj.get(), "hi");
}
@@ -148,7 +148,7 @@ public class FailureTest extends BaseTest {
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
Instant start = Instant.now();
try {
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
Ray.get(Arrays.asList(obj1, obj2));
Assert.fail("Should throw RayException.");
} catch (RayException e) {
Instant end = Instant.now();
@@ -104,7 +104,7 @@ public class MultiThreadingTest extends BaseTest {
runTestCaseInMultipleThreads(() -> {
int arg = random.nextInt();
RayObject<Integer> obj = Ray.put(arg);
Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
Assert.assertEquals(arg, (int) obj.get());
}, LOOP_COUNTER);
TestUtils.warmUpCluster();
@@ -141,7 +141,7 @@ public class MultiThreadingTest extends BaseTest {
final RayActor<Echo> fooActor = Ray.createActor(Echo::new);
final Runnable[] runnables = new Runnable[]{
() -> Ray.put(1),
() -> Ray.get(fooObject.getId()),
() -> Ray.get(fooObject.getId(), fooObject.getType()),
fooObject::get,
() -> Ray.wait(ImmutableList.of(fooObject)),
Ray::getRuntimeContext,
@@ -16,8 +16,22 @@ public class ObjectStoreTest extends BaseTest {
/**
 * Stores an Integer, a null String and a nested list with {@code Ray.put} and checks that
 * {@code get()} on each returned {@code RayObject} yields the original value.
 */
@Test
public void testPutAndGet() {
RayObject<Integer> obj = Ray.put(1);
Assert.assertEquals(1, (int) obj.get());
{
// Boxed integer.
RayObject<Integer> obj = Ray.put(1);
Assert.assertEquals(1, (int) obj.get());
}
{
// Null value.
String s = null;
RayObject<String> obj = Ray.put(s);
Assert.assertNull(obj.get());
}
{
// Nested list.
List<List<String>> l = ImmutableList.of(ImmutableList.of("abc"));
RayObject<List<List<String>>> obj = Ray.put(l);
Assert.assertEquals(obj.get(), l);
}
}
@Test
@@ -25,6 +39,6 @@ public class ObjectStoreTest extends BaseTest {
List<Integer> ints = ImmutableList.of(1, 2, 3, 4, 5);
List<ObjectId> ids = ints.stream().map(obj -> Ray.put(obj).getId())
.collect(Collectors.toList());
Assert.assertEquals(ints, Ray.get(ids));
Assert.assertEquals(ints, Ray.get(ids, Integer.class));
}
}
@@ -15,9 +15,9 @@ public class PlasmaStoreTest extends BaseTest {
ObjectId objectId = ObjectId.fromRandom();
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
objectStore.put("1", objectId);
Assert.assertEquals(Ray.get(objectId), "1");
Assert.assertEquals(Ray.get(objectId, String.class), "1");
objectStore.put("2", objectId);
// Putting a second object with a duplicate ID should fail, but the failure is ignored.
Assert.assertEquals(Ray.get(objectId), "1");
Assert.assertEquals(Ray.get(objectId, String.class), "1");
}
}
@@ -87,7 +87,7 @@ public class RayCallTest extends BaseTest {
ObjectId randomObjectId = ObjectId.fromRandom();
Ray.call(RayCallTest::testNoReturn, randomObjectId);
Assert.assertEquals(((int) Ray.get(randomObjectId)), 1);
Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1);
}
private static int testNoParam() {
@@ -2,9 +2,7 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.ray.api.function.PyActorClass;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
@@ -15,10 +13,9 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
RayPyActor pyActor = Ray.createActor(new PyActorClass("test", "RaySerializerTest"));
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
RayPyActor result = (RayPyActor) ObjectSerializer
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
.deserialize(nativeRayObject, null, Object.class);
Assert.assertEquals(result.getId(), pyActor.getId());
Assert.assertEquals(result.getModuleName(), "test");
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
@@ -28,7 +28,7 @@ public class StressTest extends BaseTest {
resultIds.add(Ray.call(StressTest::echo, 1).getId());
}
for (Integer result : Ray.<Integer>get(resultIds)) {
for (Integer result : Ray.<Integer>get(resultIds, Integer.class)) {
Assert.assertEquals(result, Integer.valueOf(1));
}
}
@@ -67,7 +67,7 @@ public class StressTest extends BaseTest {
objectIds.add(actor.call(Actor::ping).getId());
}
int sum = 0;
for (Integer result : Ray.<Integer>get(objectIds)) {
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
sum += result;
}
return sum;
@@ -84,7 +84,7 @@ public class StressTest extends BaseTest {
objectIds.add(worker.call(Worker::ping, 100).getId());
}
for (Integer result : Ray.<Integer>get(objectIds)) {
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
Assert.assertEquals(result, Integer.valueOf(100));
}
}
@@ -5,18 +5,47 @@ import ray
@ray.remote
def py_func(value):
assert isinstance(value, bytes)
return b"Response from Python: " + value
# Echo helper: returns its argument unchanged.
def py_return_input(v):
return v
@ray.remote
def py_func_call_java_function(value):
assert isinstance(value, bytes)
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"bytesEcho")
r = f.remote(value)
return b"[Python]py_func -> " + ray.get(r)
def py_func_call_java_function():
    """Call the Java helpers of CrossLanguageInvocationTest and check the results.

    Each helper is looked up with ``ray.java_function``, invoked remotely and
    its result compared against the value that was sent.

    Returns:
        "success" when every check passes; otherwise the formatted traceback
        of the first failure (the exception is reported, not raised).
    """
    # Local import: only the failure path needs it.
    import traceback

    java_class = "org.ray.api.test.CrossLanguageInvocationTest"
    try:
        # None
        r = ray.java_function(java_class, "returnInput").remote(None)
        assert ray.get(r) is None
        # bool
        r = ray.java_function(java_class, "returnInputBoolean").remote(True)
        assert ray.get(r) is True
        # int
        r = ray.java_function(java_class, "returnInputInt").remote(100)
        assert ray.get(r) == 100
        # double
        r = ray.java_function(java_class, "returnInputDouble").remote(1.23)
        assert ray.get(r) == 1.23
        # string
        r = ray.java_function(java_class, "returnInputString").remote(
            "Hello World!")
        assert ray.get(r) == "Hello World!"
        # list (tuple will be packed by pickle,
        # so only list can be transferred across language)
        r = ray.java_function(java_class, "returnInputIntArray").remote(
            [1, 2, 3])
        assert ray.get(r) == [1, 2, 3]
        # pack
        f = ray.java_function(java_class, "pack")
        # Named `args` rather than `input` so the builtin is not shadowed.
        args = [100, "hello", 1.23, [1, "2", 3.0]]
        r = f.remote(*args)
        assert ray.get(r) == args
        return "success"
    except Exception:
        # A bare `assert` raises AssertionError with an empty message, so
        # str(ex) would be ""; the traceback shows which check failed.
        return traceback.format_exc()
@ray.remote