[java] put function meta in task spec and load functions with function meta (#2881)

This PR adds a `function_desc` field into task spec. a function descriptor is a list of strings that can uniquely describe a function.
- For a Python function, it should be: [module_name, class_name, function_name]
- For a Java function, it should be: [class_name, method_name, type_descriptor]

There're a couple of purposes to add this field:

In this PR:
- Java worker needs to know function's class name to load it. Previously, since task spec didn't have such a field to hold this info, we did a hack by appending the class name to the argument list. With this change, we fixed that hack and significantly simplified function management in Java.

Will be done in subsequent PRs:
- Support cross-language invocation (#2576): currently Python worker manages functions by saving them in GCS and pass function id in task spec. However, if we want to call a Python function from Java, we cannot save it in GCS and get the function id. But instead, we can pass the function descriptor (module name, class name, function name) in task spec and use it to load the function.
- Support deployment: one major problem of Python worker's current function management mechanism is #2327. In prod env, we should have a mechanism to deploy code and dependencies to the cluster. And when code is already deployed, we don't need to save functions to GCS any more and can use `function_desc` to manage functions.
This commit is contained in:
Hao Chen
2018-09-26 14:05:05 +08:00
committed by Robert Nishihara
parent 3cccb49191
commit 971df5ea8a
34 changed files with 664 additions and 1414 deletions
-2
View File
@@ -35,8 +35,6 @@
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<!-- <scope>test</scope> -->
</dependency>
<!-- https://mvnrepository.com/artifact/commons-collections/commons-collections -->
@@ -1,211 +0,0 @@
package org.ray.api.test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.function.RayFunc0;
import org.ray.api.function.RayFunc1;
import org.ray.api.function.RayFunc3;
import org.ray.runtime.util.MethodId;
import org.ray.runtime.util.logger.RayLog;
public class LambdaUtilsTest {
static final String CLASS_NAME = LambdaUtilsTest.class.getName();
static final Method CALL0;
static final Method CALL1;
static final Method CALL2;
static final Method CALL3;
static {
try {
CALL0 = LambdaUtilsTest.class.getDeclaredMethod("call0", new Class[0]);
CALL1 = LambdaUtilsTest.class.getDeclaredMethod("call1", new Class[]{Long.class});
CALL2 = LambdaUtilsTest.class.getDeclaredMethod("call2", new Class[0]);
CALL3 = LambdaUtilsTest.class
.getDeclaredMethod("call3", new Class[]{Long.class, String.class});
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public static <T0, T1, T2, R0> void testRemoteLambdaParse(RayFunc3<T0, T1, T2, R0> f, int n,
boolean forceNew, boolean debug)
throws Exception {
if (debug) {
RayLog.core.info("parse#" + f.getClass().getName());
}
long start = System.nanoTime();
for (int i = 0; i < n; i++) {
MethodId mid = MethodId.fromSerializedLambda(f, forceNew);
}
long end = System.nanoTime();
RayLog.core.info(String.format("remoteLambdaParse(new=%s):total=%sms, one=%s", forceNew,
TimeUnit.NANOSECONDS.toMillis(end - start),
(end - start) / n));
}
public static <T0, T1, T2, R0> void testRemoteLambdaSerde(RayFunc3<T0, T1, T2, R0> f, int n,
boolean de, boolean debug)
throws Exception {
if (debug) {
RayLog.core.info("se#" + f.getClass().getName());
}
long start = System.nanoTime();
for (int i = 0; i < n; i++) {
ByteArrayOutputStream bytes = new ByteArrayOutputStream(1024);
ObjectOutputStream out = new ObjectOutputStream(bytes);
out.writeObject(f);
out.close();
if (de) {
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()));
RayFunc3 def = (RayFunc3) in.readObject();
in.close();
if (debug) {
RayLog.core.info("de#" + def.getClass().getName());
}
}
}
long end = System.nanoTime();
RayLog.core.info(
String.format("remoteLambdaSer(de=%s):total=%sms,one=%s", de,
TimeUnit.NANOSECONDS.toMillis(end - start),
(end - start) / n));
}
public static void testCall0(RayFunc0 f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL0);
Assert.assertTrue(mid.isStatic);
}
public static <T, R> void testCall1(RayFunc1<T, R> f, T t) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL1);
Assert.assertTrue(mid.isStatic);
}
public static <T, R> void testCall2(RayFunc1<T, R> f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL2);
Assert.assertTrue(!mid.isStatic);
}
public static <T0, T1, T2, R0> void testCall3(RayFunc3<T0, T1, T2, R0> f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL3);
Assert.assertTrue(!mid.isStatic);
}
public static String call0() {
long t = System.currentTimeMillis();
RayLog.core.info("call0:" + t);
return String.valueOf(t);
}
public static String call1(Long v) {
for (int i = 0; i < 100; i++) {
v += i;
}
RayLog.core.info("call1:" + v);
return String.valueOf(v);
}
@Test
public void testLambdaSer() throws Exception {
testCall0(LambdaUtilsTest::call0);
testCall1(LambdaUtilsTest::call1, Long.valueOf(System.currentTimeMillis()));
testCall2(LambdaUtilsTest::call2);
testCall3(LambdaUtilsTest::call3);
}
/**
* to test the serdeLambda's perf.
*/
public void testBenchmark() throws Exception {
//test serde
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
//warmup
RayLog.core.info("warmup:serde################");
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
RayLog.core.info("benchmark:serde################");
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
RayLog.core.info("benchmark:ser################");
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
//test serde one new call's time, no class cache
long start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
long end = System.nanoTime();
RayLog.core.info("one sertime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
end = System.nanoTime();
RayLog.core.info("one sertime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
end = System.nanoTime();
RayLog.core.info("one sertime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
end = System.nanoTime();
RayLog.core.info("one serdetime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
end = System.nanoTime();
RayLog.core.info("one serdetime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
end = System.nanoTime();
RayLog.core.info("one serdetime:" + (end - start));
//test lambda
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, true, true);
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, false, true);
//warmup
RayLog.core.info("warmup:parse################");
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
RayLog.core.info("benchmark:parseNew################");
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
RayLog.core.info("benchmark:parseCache################");
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
}
public String call2() {
long t = System.currentTimeMillis();
RayLog.core.info("call2:" + t);
return "call2:" + t;
}
public String call3(Long v, String s) {
for (int i = 0; i < 100; i++) {
v += i;
}
RayLog.core.info("call3:" + v);
return String.valueOf(v);
}
}
@@ -1,44 +0,0 @@
package org.ray.api.test;
import java.lang.reflect.Executable;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.function.RayFunc2;
import org.ray.runtime.util.MethodId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MethodIdTest {
private static final Logger LOGGER = LoggerFactory.getLogger(MethodIdTest.class);
@Test
public void testNormalMethod() throws Exception {
RayFunc2<Integer, String, String> f = MethodIdTest::foo;
MethodId m1 = MethodId.fromSerializedLambda(f);
Executable e = MethodIdTest.class.getDeclaredMethod("foo", int.class, String.class);
MethodId m2 = MethodId.fromExecutable(e);
LOGGER.info("{}, {}", m1, m2);
Assert.assertEquals(m1, m2);
}
@Test
public void testConstructor() throws Exception {
RayFunc2<Integer, String, Foo> f = Foo::new;
MethodId m1 = MethodId.fromSerializedLambda(f);
Executable e = Foo.class.getConstructor(int.class, String.class);
MethodId m2 = MethodId.fromExecutable(e);
LOGGER.info("{}, {}", m1, m2);
Assert.assertEquals(m1, m2);
}
public static String foo(int a, String b) {
return a + b;
}
public static class Foo {
public Foo(int a, String b) {}
}
}
@@ -1,32 +0,0 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.annotation.RayRemote;
import org.ray.runtime.functionmanager.RayActorMethods;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RayActorMethodsTest {
private static final Logger LOGGER = LoggerFactory.getLogger(RayActorMethodsTest.class);
@RayRemote
public static class ExampleActor {
public void func1() {}
public void func2() {}
public static void func3() {}
}
@Test
public void testActorMethods() {
RayActorMethods methods = RayActorMethods
.fromClass(ExampleActor.class.getName(), RayActorMethodsTest.class.getClassLoader());
LOGGER.info(methods.toString());
Assert.assertEquals(methods.functions.size(), 2);
Assert.assertEquals(methods.staticFunctions.size(), 1);
}
}
@@ -1,43 +0,0 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.ray.runtime.functionmanager.RayMethod;
import org.ray.runtime.functionmanager.RayTaskMethods;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RayTaskMethodsTest {
private static final Logger LOGGER = LoggerFactory.getLogger(RayTaskMethodsTest.class);
private static class Foo {
public Foo() {}
public Foo(int x) {}
public static void f1() {}
public void f2() {}
}
@Test
public void testTask() {
RayTaskMethods methods = RayTaskMethods
.fromClass(Foo.class.getName(), Foo.class.getClassLoader());
LOGGER.info(methods.toString());
int numMethods = 0;
int numConstructors = 0;
for (RayMethod m : methods.functions.values()) {
if (m.isConstructor()) {
numConstructors += 1;
} else {
numMethods += 1;
}
}
Assert.assertEquals(numMethods, 1);
Assert.assertEquals(numConstructors, 2);
}
}