mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:37:39 +08:00
Cross language serialization for primitive types (#7711)
* Cross language serialization for Java and Python * Use strict types when Python serializing * Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0 * Disable gc for optimizing msgpack loads * Fix merge bug * Java call Python use returnType; Fix ClassLoaderTest * Fix RayMethodsTest * Fix checkstyle * Fix lint * prepare_args raises exception if try to transfer a non-deserializable object to another language * Fix CrossLanguageInvocationTest.java, Python msgpack treat float as double * Minor fixes * Fix compile error on linux * Fix lint in java/BUILD.bazel * Fix test_failure * Fix lint * Class<?> to Class<T>; Refine metadata bytes. * Rename FST to Fst; sort java dependencies * Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py * Improve CrossLanguageInvocationTest * Refactor MessagePackSerializer.java * Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java * Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java * Fix bug * Remove custom cross language type support * Replace Serializer.Meta with MutableBoolean * Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java * Refine MessagePackSerializer.pack * Ray.get support RayObject as input * Improve comments and error info * Remove classLoader argument from serializer * Separate msgpack from pickle5 in Python * Pair<byte[], MutableBoolean> to Pair<byte[], Boolean> * Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead * Refine test * small fixes Co-authored-by: 刘宝 <po.lb@antfin.com> Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
@@ -143,7 +143,7 @@ public class ActorTest extends BaseTest {
|
||||
try {
|
||||
// Try getting the object again, this should throw an UnreconstructableException.
|
||||
// Use `Ray.get()` to bypass the cache in `RayObjectImpl`.
|
||||
Ray.get(value.getId());
|
||||
Ray.get(value.getId(), value.getType());
|
||||
Assert.fail("This line should not be reachable.");
|
||||
} catch (UnreconstructableException e) {
|
||||
Assert.assertEquals(value.getId(), e.objectId);
|
||||
|
||||
@@ -4,6 +4,7 @@ import java.io.File;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Optional;
|
||||
import javax.tools.JavaCompiler;
|
||||
import javax.tools.ToolProvider;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
@@ -101,12 +102,14 @@ public class ClassLoaderTest extends BaseTest {
|
||||
"()V");
|
||||
RayActor<?> actor1 = createActor(constructor);
|
||||
FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I");
|
||||
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0], 1).get();
|
||||
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0],
|
||||
Optional.of(Integer.class)).get();
|
||||
RayActor<?> actor2;
|
||||
while (true) {
|
||||
// Create another actor which share the same process of actor 1.
|
||||
actor2 = createActor(constructor);
|
||||
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0], 1).get();
|
||||
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0],
|
||||
Optional.of(Integer.class)).get();
|
||||
if (actor2Pid == pid) {
|
||||
break;
|
||||
}
|
||||
@@ -116,15 +119,17 @@ public class ClassLoaderTest extends BaseTest {
|
||||
"getClassLoaderHashCode",
|
||||
"()I");
|
||||
RayObject<Integer> hashCode1 = callActorFunction(actor1, getClassLoaderHashCode, new Object[0],
|
||||
1);
|
||||
Optional.of(Integer.class));
|
||||
RayObject<Integer> hashCode2 = callActorFunction(actor2, getClassLoaderHashCode, new Object[0],
|
||||
1);
|
||||
Optional.of(Integer.class));
|
||||
Assert.assertEquals(hashCode1.get(), hashCode2.get());
|
||||
|
||||
FunctionDescriptor increase = new JavaFunctionDescriptor("ClassLoaderTester", "increase",
|
||||
"()I");
|
||||
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0], 1);
|
||||
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0], 1);
|
||||
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0],
|
||||
Optional.of(Integer.class));
|
||||
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0],
|
||||
Optional.of(Integer.class));
|
||||
Assert.assertNotEquals(value1.get(), value2.get());
|
||||
}
|
||||
|
||||
@@ -138,11 +143,12 @@ public class ClassLoaderTest extends BaseTest {
|
||||
}
|
||||
|
||||
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) throws Exception {
|
||||
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType)
|
||||
throws Exception {
|
||||
Method callActorFunctionMethod = AbstractRayRuntime.class.getDeclaredMethod("callActorFunction",
|
||||
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
|
||||
BaseActor.class, FunctionDescriptor.class, Object[].class, Optional.class);
|
||||
callActorFunctionMethod.setAccessible(true);
|
||||
return (RayObject<T>) callActorFunctionMethod
|
||||
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns);
|
||||
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, returnType);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ public class ClientExceptionTest extends BaseTest {
|
||||
public void testWaitAndCrash() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
ObjectId randomId = ObjectId.fromRandom();
|
||||
RayObject<String> notExisting = new RayObjectImpl(randomId);
|
||||
RayObject<String> notExisting = new RayObjectImpl(randomId, String.class);
|
||||
|
||||
Thread thread = new Thread(() -> {
|
||||
try {
|
||||
|
||||
@@ -5,6 +5,9 @@ import com.google.common.collect.ImmutableMap;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.Ray;
|
||||
@@ -51,18 +54,85 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
||||
|
||||
@Test
|
||||
public void testCallingPythonFunction() {
|
||||
RayObject<byte[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_func", byte[].class),
|
||||
"hello".getBytes());
|
||||
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
|
||||
Object[] inputs = new Object[]{
|
||||
true, // Boolean
|
||||
Byte.MAX_VALUE, // Byte
|
||||
Short.MAX_VALUE, // Short
|
||||
Integer.MAX_VALUE, // Integer
|
||||
Long.MAX_VALUE, // Long
|
||||
// BigInteger can support max value of 2^64-1, please refer to:
|
||||
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
|
||||
// If BigInteger larger than 2^64-1, the value can only be transferred among Java workers.
|
||||
BigInteger.valueOf(Long.MAX_VALUE), // BigInteger
|
||||
"Hello World!", // String
|
||||
1.234f, // Float
|
||||
1.234, // Double
|
||||
"example binary".getBytes()}; // byte[]
|
||||
for (Object o : inputs) {
|
||||
RayObject res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()),
|
||||
o);
|
||||
Assert.assertEquals(res.get(), o);
|
||||
}
|
||||
// null
|
||||
{
|
||||
Object input = null;
|
||||
RayObject<Object> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input);
|
||||
Object r = res.get();
|
||||
Assert.assertEquals(r, input);
|
||||
}
|
||||
// array
|
||||
{
|
||||
int[] input = new int[]{1, 2};
|
||||
RayObject<int[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input);
|
||||
int[] r = res.get();
|
||||
Assert.assertEquals(r, input);
|
||||
}
|
||||
// array of Object
|
||||
{
|
||||
Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true,
|
||||
new int[]{1, 2}};
|
||||
RayObject<Object[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input);
|
||||
Object[] r = res.get();
|
||||
// If we tell the value type is Object, then all numbers will be Number type.
|
||||
Assert.assertEquals(((Number) r[0]).intValue(), input[0]);
|
||||
Assert.assertEquals(((Number) r[1]).floatValue(), input[1]);
|
||||
Assert.assertEquals(((Number) r[2]).doubleValue(), input[2]);
|
||||
// String cast
|
||||
Assert.assertEquals((String) r[3], input[3]);
|
||||
// binary cast
|
||||
Assert.assertEquals((byte[]) r[4], input[4]);
|
||||
// null
|
||||
Assert.assertEquals(r[5], input[5]);
|
||||
// Boolean cast
|
||||
Assert.assertEquals((Boolean) r[6], input[6]);
|
||||
// array cast
|
||||
Object[] r7array = (Object[]) r[7];
|
||||
int[] input7array = (int[]) input[7];
|
||||
Assert.assertEquals(((Number) r7array[0]).intValue(), input7array[0]);
|
||||
Assert.assertEquals(((Number) r7array[1]).intValue(), input7array[1]);
|
||||
}
|
||||
// Unsupported types, all Java specific types, e.g. List / Map...
|
||||
{
|
||||
Assert.expectThrows(Exception.class, () -> {
|
||||
List<Integer> input = Arrays.asList(1, 2);
|
||||
RayObject<List<Integer>> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input",
|
||||
(Class<List<Integer>>) input.getClass()), input);
|
||||
List<Integer> r = res.get();
|
||||
Assert.assertEquals(r, input);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPythonCallJavaFunction() {
|
||||
RayObject<byte[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", byte[].class),
|
||||
"hello".getBytes());
|
||||
Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes());
|
||||
RayObject<String> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", String.class));
|
||||
Assert.assertEquals(res.get(), "success");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -117,11 +187,33 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
||||
Assert.assertEquals(res.get(), "3".getBytes());
|
||||
}
|
||||
|
||||
public static byte[] bytesEcho(byte[] value) {
|
||||
public static Object[] pack(int i, String s, double f, Object[] o) {
|
||||
// This function will be called from test_cross_language_invocation.py
|
||||
String valueStr = new String(value);
|
||||
LOGGER.debug(String.format("bytesEcho called with: %s", valueStr));
|
||||
return ("[Java]bytesEcho -> " + valueStr).getBytes();
|
||||
return new Object[]{i, s, f, o};
|
||||
}
|
||||
|
||||
public static Object returnInput(Object o) {
|
||||
return o;
|
||||
}
|
||||
|
||||
public static boolean returnInputBoolean(boolean b) {
|
||||
return b;
|
||||
}
|
||||
|
||||
public static int returnInputInt(int i) {
|
||||
return i;
|
||||
}
|
||||
|
||||
public static double returnInputDouble(double d) {
|
||||
return d;
|
||||
}
|
||||
|
||||
public static String returnInputString(String s) {
|
||||
return s;
|
||||
}
|
||||
|
||||
public static int[] returnInputIntArray(int[] l) {
|
||||
return l;
|
||||
}
|
||||
|
||||
public static byte[] callPythonActorHandle(byte[] value) {
|
||||
@@ -135,6 +227,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
||||
}
|
||||
|
||||
public static class TestActor {
|
||||
|
||||
public TestActor(byte[] v) {
|
||||
value = v;
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ public class DynamicResourceTest extends BaseTest {
|
||||
// Assert ray call result.
|
||||
result = Ray.wait(ImmutableList.of(obj), 1, 1000);
|
||||
Assert.assertEquals(result.getReady().size(), 1);
|
||||
Assert.assertEquals(Ray.get(obj.getId()), "hi");
|
||||
Assert.assertEquals(obj.get(), "hi");
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ public class FailureTest extends BaseTest {
|
||||
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
|
||||
Instant start = Instant.now();
|
||||
try {
|
||||
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
|
||||
Ray.get(Arrays.asList(obj1, obj2));
|
||||
Assert.fail("Should throw RayException.");
|
||||
} catch (RayException e) {
|
||||
Instant end = Instant.now();
|
||||
|
||||
@@ -104,7 +104,7 @@ public class MultiThreadingTest extends BaseTest {
|
||||
runTestCaseInMultipleThreads(() -> {
|
||||
int arg = random.nextInt();
|
||||
RayObject<Integer> obj = Ray.put(arg);
|
||||
Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
|
||||
Assert.assertEquals(arg, (int) obj.get());
|
||||
}, LOOP_COUNTER);
|
||||
|
||||
TestUtils.warmUpCluster();
|
||||
@@ -141,7 +141,7 @@ public class MultiThreadingTest extends BaseTest {
|
||||
final RayActor<Echo> fooActor = Ray.createActor(Echo::new);
|
||||
final Runnable[] runnables = new Runnable[]{
|
||||
() -> Ray.put(1),
|
||||
() -> Ray.get(fooObject.getId()),
|
||||
() -> Ray.get(fooObject.getId(), fooObject.getType()),
|
||||
fooObject::get,
|
||||
() -> Ray.wait(ImmutableList.of(fooObject)),
|
||||
Ray::getRuntimeContext,
|
||||
|
||||
@@ -16,8 +16,22 @@ public class ObjectStoreTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testPutAndGet() {
|
||||
RayObject<Integer> obj = Ray.put(1);
|
||||
Assert.assertEquals(1, (int) obj.get());
|
||||
{
|
||||
RayObject<Integer> obj = Ray.put(1);
|
||||
Assert.assertEquals(1, (int) obj.get());
|
||||
}
|
||||
|
||||
{
|
||||
String s = null;
|
||||
RayObject<String> obj = Ray.put(s);
|
||||
Assert.assertNull(obj.get());
|
||||
}
|
||||
|
||||
{
|
||||
List<List<String>> l = ImmutableList.of(ImmutableList.of("abc"));
|
||||
RayObject<List<List<String>>> obj = Ray.put(l);
|
||||
Assert.assertEquals(obj.get(), l);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -25,6 +39,6 @@ public class ObjectStoreTest extends BaseTest {
|
||||
List<Integer> ints = ImmutableList.of(1, 2, 3, 4, 5);
|
||||
List<ObjectId> ids = ints.stream().map(obj -> Ray.put(obj).getId())
|
||||
.collect(Collectors.toList());
|
||||
Assert.assertEquals(ints, Ray.get(ids));
|
||||
Assert.assertEquals(ints, Ray.get(ids, Integer.class));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,9 +15,9 @@ public class PlasmaStoreTest extends BaseTest {
|
||||
ObjectId objectId = ObjectId.fromRandom();
|
||||
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
|
||||
objectStore.put("1", objectId);
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
Assert.assertEquals(Ray.get(objectId, String.class), "1");
|
||||
objectStore.put("2", objectId);
|
||||
// Putting the second object with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
Assert.assertEquals(Ray.get(objectId, String.class), "1");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ public class RayCallTest extends BaseTest {
|
||||
|
||||
ObjectId randomObjectId = ObjectId.fromRandom();
|
||||
Ray.call(RayCallTest::testNoReturn, randomObjectId);
|
||||
Assert.assertEquals(((int) Ray.get(randomObjectId)), 1);
|
||||
Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1);
|
||||
}
|
||||
|
||||
private static int testNoParam() {
|
||||
|
||||
@@ -2,9 +2,7 @@ package org.ray.api.test;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.function.PyActorClass;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
import org.testng.Assert;
|
||||
@@ -15,10 +13,9 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
|
||||
@Test
|
||||
public void testSerializePyActor() {
|
||||
RayPyActor pyActor = Ray.createActor(new PyActorClass("test", "RaySerializerTest"));
|
||||
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
|
||||
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
|
||||
RayPyActor result = (RayPyActor) ObjectSerializer
|
||||
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
|
||||
.deserialize(nativeRayObject, null, Object.class);
|
||||
Assert.assertEquals(result.getId(), pyActor.getId());
|
||||
Assert.assertEquals(result.getModuleName(), "test");
|
||||
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
|
||||
|
||||
@@ -28,7 +28,7 @@ public class StressTest extends BaseTest {
|
||||
resultIds.add(Ray.call(StressTest::echo, 1).getId());
|
||||
}
|
||||
|
||||
for (Integer result : Ray.<Integer>get(resultIds)) {
|
||||
for (Integer result : Ray.<Integer>get(resultIds, Integer.class)) {
|
||||
Assert.assertEquals(result, Integer.valueOf(1));
|
||||
}
|
||||
}
|
||||
@@ -67,7 +67,7 @@ public class StressTest extends BaseTest {
|
||||
objectIds.add(actor.call(Actor::ping).getId());
|
||||
}
|
||||
int sum = 0;
|
||||
for (Integer result : Ray.<Integer>get(objectIds)) {
|
||||
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
|
||||
sum += result;
|
||||
}
|
||||
return sum;
|
||||
@@ -84,7 +84,7 @@ public class StressTest extends BaseTest {
|
||||
objectIds.add(worker.call(Worker::ping, 100).getId());
|
||||
}
|
||||
|
||||
for (Integer result : Ray.<Integer>get(objectIds)) {
|
||||
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
|
||||
Assert.assertEquals(result, Integer.valueOf(100));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,18 +5,47 @@ import ray
|
||||
|
||||
|
||||
@ray.remote
|
||||
def py_func(value):
|
||||
assert isinstance(value, bytes)
|
||||
return b"Response from Python: " + value
|
||||
def py_return_input(v):
|
||||
return v
|
||||
|
||||
|
||||
@ray.remote
|
||||
def py_func_call_java_function(value):
|
||||
assert isinstance(value, bytes)
|
||||
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"bytesEcho")
|
||||
r = f.remote(value)
|
||||
return b"[Python]py_func -> " + ray.get(r)
|
||||
def py_func_call_java_function():
|
||||
try:
|
||||
# None
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInput").remote(None)
|
||||
assert ray.get(r) is None
|
||||
# bool
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputBoolean").remote(True)
|
||||
assert ray.get(r) is True
|
||||
# int
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputInt").remote(100)
|
||||
assert ray.get(r) == 100
|
||||
# double
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputDouble").remote(1.23)
|
||||
assert ray.get(r) == 1.23
|
||||
# string
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputString").remote("Hello World!")
|
||||
assert ray.get(r) == "Hello World!"
|
||||
# list (tuple will be packed by pickle,
|
||||
# so only list can be transferred across language)
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputIntArray").remote([1, 2, 3])
|
||||
assert ray.get(r) == [1, 2, 3]
|
||||
# pack
|
||||
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"pack")
|
||||
input = [100, "hello", 1.23, [1, "2", 3.0]]
|
||||
r = f.remote(*input)
|
||||
assert ray.get(r) == input
|
||||
return "success"
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
||||
Reference in New Issue
Block a user