From 9d133e874c28902907a7ed1eccac130cabe85231 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Mon, 23 Nov 2020 14:07:10 +0800 Subject: [PATCH] [Java] support java actor class inheritance (#12001) --- .../functionmanager/FunctionManager.java | 23 +++- .../functionmanager/FunctionManagerTest.java | 101 +++++++++++++----- .../src/main/java/io/ray/test/ActorTest.java | 36 +++++++ 3 files changed, 133 insertions(+), 27 deletions(-) diff --git a/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java index 05b248dc7..c108139ac 100644 --- a/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java @@ -1,5 +1,6 @@ package io.ray.runtime.functionmanager; +import com.google.common.collect.Lists; import io.ray.api.function.RayFunc; import io.ray.api.id.JobId; import io.ray.runtime.util.LambdaUtils; @@ -204,12 +205,28 @@ public class FunctionManager { Map, RayFunction> map = new HashMap<>(); try { Class clazz = Class.forName(className, true, classLoader); - List executables = new ArrayList<>(); executables.addAll(Arrays.asList(clazz.getDeclaredMethods())); - executables.addAll(Arrays.asList(clazz.getConstructors())); + executables.addAll(Arrays.asList(clazz.getDeclaredConstructors())); - for (Executable e : executables) { + Class clz = clazz; + clz = clz.getSuperclass(); + while (clz != null && clz != Object.class) { + executables.addAll(Arrays.asList(clz.getDeclaredMethods())); + clz = clz.getSuperclass(); + } + + // Put interface methods ahead, so that in can be override by subclass methods in `map.put` + for (Class baseInterface : clazz.getInterfaces()) { + for (Method method : baseInterface.getDeclaredMethods()) { + if (method.isDefault()) { + executables.add(method); + } + } + } + + // Use reverse order so that child class methods can override super class methods. + for (Executable e : Lists.reverse(executables)) { e.setAccessible(true); final String methodName = e instanceof Method ? e.getName() : CONSTRUCTOR_NAME; final Type type = diff --git a/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java index f369cc939..c6270c997 100644 --- a/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java @@ -27,15 +27,36 @@ public class FunctionManagerTest { return null; } - public static class Bar { + public static class ParentClass { - public Bar() { + public Object foo() { + return null; } public Object bar() { return null; } + } + + public interface ChildClassInterface { + + default String interfaceName() { + return getClass().getName(); + } + + } + + public static class ChildClass extends ParentClass implements ChildClassInterface { + + public ChildClass() { + } + + @Override + public Object bar() { + return null; + } + public Object overloadFunction(int i) { return null; } @@ -46,24 +67,24 @@ public class FunctionManagerTest { } private static RayFunc0 fooFunc; - private static RayFunc1 barFunc; - private static RayFunc0 barConstructor; + private static RayFunc1 childClassBarFunc; + private static RayFunc0 childClassConstructor; private static JavaFunctionDescriptor fooDescriptor; - private static JavaFunctionDescriptor barDescriptor; - private static JavaFunctionDescriptor barConstructorDescriptor; + private static JavaFunctionDescriptor childClassBarDescriptor; + private static JavaFunctionDescriptor childClassConstructorDescriptor; private static JavaFunctionDescriptor overloadFunctionDescriptorInt; private static JavaFunctionDescriptor overloadFunctionDescriptorDouble; @BeforeClass public static void beforeClass() { fooFunc = FunctionManagerTest::foo; - barConstructor = Bar::new; - barFunc = Bar::bar; + childClassConstructor = ChildClass::new; + childClassBarFunc = ChildClass::bar; fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo", "()Ljava/lang/Object;"); - barDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), "bar", + childClassBarDescriptor = new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;"); - barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), + childClassConstructorDescriptor = new JavaFunctionDescriptor(ChildClass.class.getName(), FunctionManager.CONSTRUCTOR_NAME, "()V"); overloadFunctionDescriptorInt = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), @@ -81,14 +102,14 @@ public class FunctionManagerTest { Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); // Test actor method - func = functionManager.getFunction(JobId.NIL, barFunc); + func = functionManager.getFunction(JobId.NIL, childClassBarFunc); Assert.assertFalse(func.isConstructor()); - Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); + Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor); // Test actor constructor - func = functionManager.getFunction(JobId.NIL, barConstructor); + func = functionManager.getFunction(JobId.NIL, childClassConstructor); Assert.assertTrue(func.isConstructor()); - Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); + Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor); } @Test @@ -100,14 +121,14 @@ public class FunctionManagerTest { Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); // Test actor method - func = functionManager.getFunction(JobId.NIL, barDescriptor); + func = functionManager.getFunction(JobId.NIL, childClassBarDescriptor); Assert.assertFalse(func.isConstructor()); - Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); + Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor); // Test actor constructor - func = functionManager.getFunction(JobId.NIL, barConstructorDescriptor); + func = functionManager.getFunction(JobId.NIL, childClassConstructorDescriptor); Assert.assertTrue(func.isConstructor()); - Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); + Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor); // Test raise overload exception Assert.expectThrows(RuntimeException.class, () -> { @@ -117,25 +138,57 @@ public class FunctionManagerTest { }); } + @Test + public void testInheritance() { + final FunctionManager functionManager = new FunctionManager(null); + // Check inheritance can work and FunctionManager can find method in parent class. + fooDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "foo", + "()Ljava/lang/Object;"); + Assert.assertEquals(functionManager.getFunction(JobId.NIL, fooDescriptor) + .executable.getDeclaringClass(), ParentClass.class); + RayFunction fooFunc = functionManager.getFunction(JobId.NIL, + new JavaFunctionDescriptor(ChildClass.class.getName(), "foo", + "()Ljava/lang/Object;")); + Assert.assertEquals(fooFunc.executable.getDeclaringClass(), ParentClass.class); + + // Check FunctionManager can use method in child class if child class methods overrides methods + // in parent class. + childClassBarDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "bar", + "()Ljava/lang/Object;"); + Assert.assertEquals(functionManager.getFunction(JobId.NIL, childClassBarDescriptor) + .executable.getDeclaringClass(), ParentClass.class); + RayFunction barFunc = functionManager.getFunction(JobId.NIL, + new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", + "()Ljava/lang/Object;")); + Assert.assertEquals(barFunc.executable.getDeclaringClass(), ChildClass.class); + + // Check interface default methods. + RayFunction interfaceNameFunc = functionManager.getFunction(JobId.NIL, + new JavaFunctionDescriptor(ChildClass.class.getName(), "interfaceName", + "()Ljava/lang/String;")); + Assert.assertEquals(interfaceNameFunc.executable.getDeclaringClass(), + ChildClassInterface.class); + } + @Test public void testLoadFunctionTableForClass() { JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader()); Map, RayFunction> res = functionTable - .loadFunctionsForClass(Bar.class.getName()); + .loadFunctionsForClass(ChildClass.class.getName()); // The result should be 4 entries: // 1, the constructor with signature // 2, the constructor without signature // 3, bar with signature // 4, bar without signature - Assert.assertEquals(res.size(), 7); + Assert.assertEquals(res.size(), 11); Assert.assertTrue(res.containsKey( - ImmutablePair.of(barDescriptor.name, barDescriptor.signature))); + ImmutablePair.of(childClassBarDescriptor.name, childClassBarDescriptor.signature))); Assert.assertTrue(res.containsKey( - ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.signature))); + ImmutablePair.of(childClassConstructorDescriptor.name, childClassConstructorDescriptor.signature))); Assert.assertTrue(res.containsKey( - ImmutablePair.of(barDescriptor.name, ""))); + ImmutablePair.of(childClassBarDescriptor.name, ""))); Assert.assertTrue(res.containsKey( - ImmutablePair.of(barConstructorDescriptor.name, ""))); + ImmutablePair.of(childClassConstructorDescriptor.name, ""))); Assert.assertTrue(res.containsKey( ImmutablePair.of(overloadFunctionDescriptorInt.name, overloadFunctionDescriptorInt.signature))); Assert.assertTrue(res.containsKey( diff --git a/java/test/src/main/java/io/ray/test/ActorTest.java b/java/test/src/main/java/io/ray/test/ActorTest.java index 045dd0191..a7e9c6ac6 100644 --- a/java/test/src/main/java/io/ray/test/ActorTest.java +++ b/java/test/src/main/java/io/ray/test/ActorTest.java @@ -144,4 +144,40 @@ public class ActorTest extends BaseTest { // Free deletes big objects from plasma store. Assert.expectThrows(UnreconstructableException.class, () -> largeValue.get()); } + + public interface ChildClassInterface { + + default String interfaceName() { + return ChildClassInterface.class.getName(); + } + + } + + public static class ChildClass extends Counter implements ChildClassInterface { + + public ChildClass(int initValue) { + super(initValue); + } + + @Override + public void increase(int delta) { + super.increase(-delta); + } + + } + + @Test(groups = {"cluster"}) + public void testInheritance() { + ActorHandle counter = Ray.actor(ChildClass::new, 100).remote(); + counter.task(ChildClass::increase, 10).remote(); + Assert.assertEquals(counter.task(ChildClass::getValue).remote().get(), Integer.valueOf(90)); + // Since `increase` method is overrided, call by super class method reference should still + // execute child class methods. + counter.task(Counter::increase, 10).remote(); + Assert.assertEquals(counter.task(Counter::getValue).remote().get(), Integer.valueOf(80)); + // test interface default methods + Assert.assertEquals(counter.task(ChildClassInterface::interfaceName).remote().get(), + ChildClassInterface.class.getName()); + } + }