[xlang] Cross language Python support (#6709)

This commit is contained in:
fyrestone
2020-02-08 13:01:28 +08:00
committed by GitHub
parent f146d05b36
commit 0648bd28ef
59 changed files with 1412 additions and 580 deletions
@@ -84,8 +84,8 @@ public class FunctionManager {
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
final String className = serializedLambda.getImplClass().replace('/', '.');
final String methodName = serializedLambda.getImplMethodName();
final String typeDescriptor = serializedLambda.getImplMethodSignature();
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
final String signature = serializedLambda.getImplMethodSignature();
functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature);
RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor);
}
return getFunction(jobId, functionDescriptor);
@@ -167,13 +167,26 @@ public class FunctionManager {
}
}
}
return classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.typeDescriptor));
final Pair<String, String> key = ImmutablePair.of(descriptor.name, descriptor.signature);
RayFunction func = classFunctions.get(key);
if (func == null) {
if (classFunctions.containsKey(key)) {
throw new RuntimeException(
String.format("RayFunction %s is overloaded, the signature can't be empty.",
descriptor.toString()));
} else {
throw new RuntimeException(
String.format("RayFunction %s not found", descriptor.toString()));
}
}
return func;
}
/**
* Load all functions from a class.
*/
Map<Pair<String, String>, RayFunction> loadFunctionsForClass(String className) {
// If RayFunction is null, the function is overloaded.
Map<Pair<String, String>, RayFunction> map = new HashMap<>();
try {
Class clazz = Class.forName(className, true, classLoader);
@@ -187,10 +200,17 @@ public class FunctionManager {
final String methodName = e instanceof Method ? e.getName() : CONSTRUCTOR_NAME;
final Type type =
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
final String typeDescriptor = type.getDescriptor();
final String signature = type.getDescriptor();
RayFunction rayFunction = new RayFunction(e, classLoader,
new JavaFunctionDescriptor(className, methodName, typeDescriptor));
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
new JavaFunctionDescriptor(className, methodName, signature));
map.put(ImmutablePair.of(methodName, signature), rayFunction);
// For cross language call java function without signature
final Pair<String, String> emptyDescriptor = ImmutablePair.of(methodName, "");
if (map.containsKey(emptyDescriptor)) {
map.put(emptyDescriptor, null); // Mark this function as overloaded.
} else {
map.put(emptyDescriptor, rayFunction);
}
}
} catch (Exception e) {
throw new RuntimeException("Failed to load functions from class " + className, e);
@@ -19,14 +19,14 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor {
*/
public final String name;
/**
* Function's type descriptor.
* Function's signature.
*/
public final String typeDescriptor;
public final String signature;
public JavaFunctionDescriptor(String className, String name, String typeDescriptor) {
public JavaFunctionDescriptor(String className, String name, String signature) {
this.className = className;
this.name = name;
this.typeDescriptor = typeDescriptor;
this.signature = signature;
}
@Override
@@ -45,17 +45,17 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor {
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
return Objects.equal(className, that.className) &&
Objects.equal(name, that.name) &&
Objects.equal(typeDescriptor, that.typeDescriptor);
Objects.equal(signature, that.signature);
}
@Override
public int hashCode() {
return Objects.hashCode(className, name, typeDescriptor);
return Objects.hashCode(className, name, signature);
}
@Override
public List<String> toList() {
return ImmutableList.of(className, name, typeDescriptor);
return ImmutableList.of(className, name, signature);
}
@Override
@@ -28,7 +28,7 @@ public class PyFunctionDescriptor implements FunctionDescriptor {
@Override
public List<String> toList() {
return Arrays.asList(moduleName, className, functionName);
return Arrays.asList(moduleName, className, functionName, "" /* function hash */);
}
@Override
@@ -146,8 +146,12 @@ public class RunManager {
e.printStackTrace();
}
if (!p.isAlive()) {
throw new RuntimeException(
String.format("Failed to start %s. Exit code: %d.", name, p.exitValue()));
String message = String.format("Failed to start %s. Exit code: %d.",
name, p.exitValue());
if (rayConfig.redirectOutput) {
message += String.format(" Logs are redirected to %s and %s.", stdout, stderr);
}
throw new RuntimeException(message);
}
processes.add(Pair.of(name, p));
if (LOGGER.isInfoEnabled()) {
@@ -28,6 +28,7 @@ import org.ray.runtime.actor.LocalModeRayActor;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Common;
import org.ray.runtime.generated.Common.ActorCreationTaskSpec;
import org.ray.runtime.generated.Common.ActorTaskSpec;
import org.ray.runtime.generated.Common.Language;
@@ -153,14 +154,20 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
List<FunctionArg> args) {
byte[] taskIdBytes = new byte[TaskId.LENGTH];
new Random().nextBytes(taskIdBytes);
List<String> functionDescriptorList = functionDescriptor.toList();
Preconditions.checkState(functionDescriptorList.size() >= 3);
return TaskSpec.newBuilder()
.setType(taskType)
.setLanguage(Language.JAVA)
.setJobId(
ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
.setTaskId(ByteString.copyFrom(taskIdBytes))
.addAllFunctionDescriptor(functionDescriptor.toList().stream().map(ByteString::copyFromUtf8)
.collect(Collectors.toList()))
.setFunctionDescriptor(org.ray.runtime.generated.Common.FunctionDescriptor.newBuilder()
.setJavaFunctionDescriptor(
org.ray.runtime.generated.Common.JavaFunctionDescriptor.newBuilder()
.setClassName(functionDescriptorList.get(0))
.setFunctionName(functionDescriptorList.get(1))
.setSignature(functionDescriptorList.get(2))))
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build()
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data))
@@ -307,9 +314,17 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
List<ByteString> functionDescriptor = taskSpec.getFunctionDescriptorList();
return new JavaFunctionDescriptor(functionDescriptor.get(0).toStringUtf8(),
functionDescriptor.get(1).toStringUtf8(), functionDescriptor.get(2).toStringUtf8());
org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor =
taskSpec.getFunctionDescriptor();
if (functionDescriptor.getFunctionDescriptorCase() ==
Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
return new JavaFunctionDescriptor(
functionDescriptor.getJavaFunctionDescriptor().getClassName(),
functionDescriptor.getJavaFunctionDescriptor().getFunctionName(),
functionDescriptor.getJavaFunctionDescriptor().getSignature());
} else {
throw new RuntimeException("Can't build non java function descriptor");
}
}
private static List<FunctionArg> getFunctionArgs(TaskSpec taskSpec) {
@@ -48,11 +48,11 @@ public abstract class TaskExecutor {
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
// Find the executable object.
RayFunction rayFunction = runtime.getFunctionManager()
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
Preconditions.checkNotNull(rayFunction);
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
RayFunction rayFunction = null;
try {
// Find the executable object.
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
@@ -91,7 +91,9 @@ public abstract class TaskExecutor {
} catch (Exception e) {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
if (rayFunction.hasReturn()) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = functionDescriptor.signature.equals("");
if (hasReturn || isCrossLanguage) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));
}
@@ -37,6 +37,14 @@ public class FunctionManagerTest {
public Object bar() {
return null;
}
public Object overloadFunction(int i) {
return null;
}
public Object overloadFunction(double d) {
return null;
}
}
private static RayFunc0<Object> fooFunc;
@@ -45,6 +53,8 @@ public class FunctionManagerTest {
private static JavaFunctionDescriptor fooDescriptor;
private static JavaFunctionDescriptor barDescriptor;
private static JavaFunctionDescriptor barConstructorDescriptor;
private static JavaFunctionDescriptor overloadFunctionDescriptorInt;
private static JavaFunctionDescriptor overloadFunctionDescriptorDouble;
@BeforeClass
public static void beforeClass() {
@@ -58,6 +68,10 @@ public class FunctionManagerTest {
barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(),
FunctionManager.CONSTRUCTOR_NAME,
"()V");
overloadFunctionDescriptorInt = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
"overloadFunction", "(I)Ljava/lang/Object;");
overloadFunctionDescriptorDouble = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
"overloadFunction", "(D)Ljava/lang/Object;");
}
@Test
@@ -102,6 +116,13 @@ public class FunctionManagerTest {
Assert.assertTrue(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor);
Assert.assertNotNull(func.getRayRemoteAnnotation());
// Test raise overload exception
Assert.expectThrows(RuntimeException.class, () -> {
functionManager.getFunction(JobId.NIL,
new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
"overloadFunction", ""));
});
}
@Test
@@ -109,12 +130,31 @@ public class FunctionManagerTest {
JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader());
Map<Pair<String, String>, RayFunction> res = functionTable
.loadFunctionsForClass(Bar.class.getName());
// The result should 2 entries, one for the constructor, the other for bar.
Assert.assertEquals(res.size(), 2);
// The result should be 4 entries:
// 1, the constructor with signature
// 2, the constructor without signature
// 3, bar with signature
// 4, bar without signature
Assert.assertEquals(res.size(), 7);
Assert.assertTrue(res.containsKey(
ImmutablePair.of(barDescriptor.name, barDescriptor.typeDescriptor)));
ImmutablePair.of(barDescriptor.name, barDescriptor.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor)));
ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(barDescriptor.name, "")));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(barConstructorDescriptor.name, "")));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(overloadFunctionDescriptorInt.name, overloadFunctionDescriptorInt.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(overloadFunctionDescriptorDouble.name, overloadFunctionDescriptorDouble.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(overloadFunctionDescriptorInt.name, "")));
Pair<String, String> overloadKey = ImmutablePair.of(overloadFunctionDescriptorInt.name, "");
RayFunction func = res.get(overloadKey);
// The function is overloaded.
Assert.assertTrue(res.containsKey(overloadKey));
Assert.assertNull(func);
}
@Test