mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 04:44:28 +08:00
[xlang] Cross language Python support (#6709)
This commit is contained in:
@@ -84,8 +84,8 @@ public class FunctionManager {
|
||||
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
|
||||
final String className = serializedLambda.getImplClass().replace('/', '.');
|
||||
final String methodName = serializedLambda.getImplMethodName();
|
||||
final String typeDescriptor = serializedLambda.getImplMethodSignature();
|
||||
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
|
||||
final String signature = serializedLambda.getImplMethodSignature();
|
||||
functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature);
|
||||
RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor);
|
||||
}
|
||||
return getFunction(jobId, functionDescriptor);
|
||||
@@ -167,13 +167,26 @@ public class FunctionManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
return classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.typeDescriptor));
|
||||
final Pair<String, String> key = ImmutablePair.of(descriptor.name, descriptor.signature);
|
||||
RayFunction func = classFunctions.get(key);
|
||||
if (func == null) {
|
||||
if (classFunctions.containsKey(key)) {
|
||||
throw new RuntimeException(
|
||||
String.format("RayFunction %s is overloaded, the signature can't be empty.",
|
||||
descriptor.toString()));
|
||||
} else {
|
||||
throw new RuntimeException(
|
||||
String.format("RayFunction %s not found", descriptor.toString()));
|
||||
}
|
||||
}
|
||||
return func;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all functions from a class.
|
||||
*/
|
||||
Map<Pair<String, String>, RayFunction> loadFunctionsForClass(String className) {
|
||||
// If RayFunction is null, the function is overloaded.
|
||||
Map<Pair<String, String>, RayFunction> map = new HashMap<>();
|
||||
try {
|
||||
Class clazz = Class.forName(className, true, classLoader);
|
||||
@@ -187,10 +200,17 @@ public class FunctionManager {
|
||||
final String methodName = e instanceof Method ? e.getName() : CONSTRUCTOR_NAME;
|
||||
final Type type =
|
||||
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
|
||||
final String typeDescriptor = type.getDescriptor();
|
||||
final String signature = type.getDescriptor();
|
||||
RayFunction rayFunction = new RayFunction(e, classLoader,
|
||||
new JavaFunctionDescriptor(className, methodName, typeDescriptor));
|
||||
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
|
||||
new JavaFunctionDescriptor(className, methodName, signature));
|
||||
map.put(ImmutablePair.of(methodName, signature), rayFunction);
|
||||
// For cross language call java function without signature
|
||||
final Pair<String, String> emptyDescriptor = ImmutablePair.of(methodName, "");
|
||||
if (map.containsKey(emptyDescriptor)) {
|
||||
map.put(emptyDescriptor, null); // Mark this function as overloaded.
|
||||
} else {
|
||||
map.put(emptyDescriptor, rayFunction);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to load functions from class " + className, e);
|
||||
|
||||
+7
-7
@@ -19,14 +19,14 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor {
|
||||
*/
|
||||
public final String name;
|
||||
/**
|
||||
* Function's type descriptor.
|
||||
* Function's signature.
|
||||
*/
|
||||
public final String typeDescriptor;
|
||||
public final String signature;
|
||||
|
||||
public JavaFunctionDescriptor(String className, String name, String typeDescriptor) {
|
||||
public JavaFunctionDescriptor(String className, String name, String signature) {
|
||||
this.className = className;
|
||||
this.name = name;
|
||||
this.typeDescriptor = typeDescriptor;
|
||||
this.signature = signature;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -45,17 +45,17 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor {
|
||||
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
|
||||
return Objects.equal(className, that.className) &&
|
||||
Objects.equal(name, that.name) &&
|
||||
Objects.equal(typeDescriptor, that.typeDescriptor);
|
||||
Objects.equal(signature, that.signature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
return Objects.hashCode(className, name, signature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> toList() {
|
||||
return ImmutableList.of(className, name, typeDescriptor);
|
||||
return ImmutableList.of(className, name, signature);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
+1
-1
@@ -28,7 +28,7 @@ public class PyFunctionDescriptor implements FunctionDescriptor {
|
||||
|
||||
@Override
|
||||
public List<String> toList() {
|
||||
return Arrays.asList(moduleName, className, functionName);
|
||||
return Arrays.asList(moduleName, className, functionName, "" /* function hash */);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -146,8 +146,12 @@ public class RunManager {
|
||||
e.printStackTrace();
|
||||
}
|
||||
if (!p.isAlive()) {
|
||||
throw new RuntimeException(
|
||||
String.format("Failed to start %s. Exit code: %d.", name, p.exitValue()));
|
||||
String message = String.format("Failed to start %s. Exit code: %d.",
|
||||
name, p.exitValue());
|
||||
if (rayConfig.redirectOutput) {
|
||||
message += String.format(" Logs are redirected to %s and %s.", stdout, stderr);
|
||||
}
|
||||
throw new RuntimeException(message);
|
||||
}
|
||||
processes.add(Pair.of(name, p));
|
||||
if (LOGGER.isInfoEnabled()) {
|
||||
|
||||
@@ -28,6 +28,7 @@ import org.ray.runtime.actor.LocalModeRayActor;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.generated.Common;
|
||||
import org.ray.runtime.generated.Common.ActorCreationTaskSpec;
|
||||
import org.ray.runtime.generated.Common.ActorTaskSpec;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
@@ -153,14 +154,20 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
List<FunctionArg> args) {
|
||||
byte[] taskIdBytes = new byte[TaskId.LENGTH];
|
||||
new Random().nextBytes(taskIdBytes);
|
||||
List<String> functionDescriptorList = functionDescriptor.toList();
|
||||
Preconditions.checkState(functionDescriptorList.size() >= 3);
|
||||
return TaskSpec.newBuilder()
|
||||
.setType(taskType)
|
||||
.setLanguage(Language.JAVA)
|
||||
.setJobId(
|
||||
ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
|
||||
.setTaskId(ByteString.copyFrom(taskIdBytes))
|
||||
.addAllFunctionDescriptor(functionDescriptor.toList().stream().map(ByteString::copyFromUtf8)
|
||||
.collect(Collectors.toList()))
|
||||
.setFunctionDescriptor(org.ray.runtime.generated.Common.FunctionDescriptor.newBuilder()
|
||||
.setJavaFunctionDescriptor(
|
||||
org.ray.runtime.generated.Common.JavaFunctionDescriptor.newBuilder()
|
||||
.setClassName(functionDescriptorList.get(0))
|
||||
.setFunctionName(functionDescriptorList.get(1))
|
||||
.setSignature(functionDescriptorList.get(2))))
|
||||
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
|
||||
.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build()
|
||||
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data))
|
||||
@@ -307,9 +314,17 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
}
|
||||
|
||||
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
|
||||
List<ByteString> functionDescriptor = taskSpec.getFunctionDescriptorList();
|
||||
return new JavaFunctionDescriptor(functionDescriptor.get(0).toStringUtf8(),
|
||||
functionDescriptor.get(1).toStringUtf8(), functionDescriptor.get(2).toStringUtf8());
|
||||
org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor =
|
||||
taskSpec.getFunctionDescriptor();
|
||||
if (functionDescriptor.getFunctionDescriptorCase() ==
|
||||
Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
|
||||
return new JavaFunctionDescriptor(
|
||||
functionDescriptor.getJavaFunctionDescriptor().getClassName(),
|
||||
functionDescriptor.getJavaFunctionDescriptor().getFunctionName(),
|
||||
functionDescriptor.getJavaFunctionDescriptor().getSignature());
|
||||
} else {
|
||||
throw new RuntimeException("Can't build non java function descriptor");
|
||||
}
|
||||
}
|
||||
|
||||
private static List<FunctionArg> getFunctionArgs(TaskSpec taskSpec) {
|
||||
|
||||
@@ -48,11 +48,11 @@ public abstract class TaskExecutor {
|
||||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
// Find the executable object.
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
|
||||
Preconditions.checkNotNull(rayFunction);
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
RayFunction rayFunction = null;
|
||||
try {
|
||||
// Find the executable object.
|
||||
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
|
||||
@@ -91,7 +91,9 @@ public abstract class TaskExecutor {
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
if (rayFunction.hasReturn()) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
boolean isCrossLanguage = functionDescriptor.signature.equals("");
|
||||
if (hasReturn || isCrossLanguage) {
|
||||
returnObjects.add(ObjectSerializer
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
}
|
||||
|
||||
+44
-4
@@ -37,6 +37,14 @@ public class FunctionManagerTest {
|
||||
public Object bar() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public Object overloadFunction(int i) {
|
||||
return null;
|
||||
}
|
||||
|
||||
public Object overloadFunction(double d) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static RayFunc0<Object> fooFunc;
|
||||
@@ -45,6 +53,8 @@ public class FunctionManagerTest {
|
||||
private static JavaFunctionDescriptor fooDescriptor;
|
||||
private static JavaFunctionDescriptor barDescriptor;
|
||||
private static JavaFunctionDescriptor barConstructorDescriptor;
|
||||
private static JavaFunctionDescriptor overloadFunctionDescriptorInt;
|
||||
private static JavaFunctionDescriptor overloadFunctionDescriptorDouble;
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() {
|
||||
@@ -58,6 +68,10 @@ public class FunctionManagerTest {
|
||||
barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(),
|
||||
FunctionManager.CONSTRUCTOR_NAME,
|
||||
"()V");
|
||||
overloadFunctionDescriptorInt = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
|
||||
"overloadFunction", "(I)Ljava/lang/Object;");
|
||||
overloadFunctionDescriptorDouble = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
|
||||
"overloadFunction", "(D)Ljava/lang/Object;");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -102,6 +116,13 @@ public class FunctionManagerTest {
|
||||
Assert.assertTrue(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
|
||||
// Test raise overload exception
|
||||
Assert.expectThrows(RuntimeException.class, () -> {
|
||||
functionManager.getFunction(JobId.NIL,
|
||||
new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
|
||||
"overloadFunction", ""));
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -109,12 +130,31 @@ public class FunctionManagerTest {
|
||||
JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader());
|
||||
Map<Pair<String, String>, RayFunction> res = functionTable
|
||||
.loadFunctionsForClass(Bar.class.getName());
|
||||
// The result should 2 entries, one for the constructor, the other for bar.
|
||||
Assert.assertEquals(res.size(), 2);
|
||||
// The result should be 4 entries:
|
||||
// 1, the constructor with signature
|
||||
// 2, the constructor without signature
|
||||
// 3, bar with signature
|
||||
// 4, bar without signature
|
||||
Assert.assertEquals(res.size(), 7);
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(barDescriptor.name, barDescriptor.typeDescriptor)));
|
||||
ImmutablePair.of(barDescriptor.name, barDescriptor.signature)));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor)));
|
||||
ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.signature)));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(barDescriptor.name, "")));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(barConstructorDescriptor.name, "")));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(overloadFunctionDescriptorInt.name, overloadFunctionDescriptorInt.signature)));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(overloadFunctionDescriptorDouble.name, overloadFunctionDescriptorDouble.signature)));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(overloadFunctionDescriptorInt.name, "")));
|
||||
Pair<String, String> overloadKey = ImmutablePair.of(overloadFunctionDescriptorInt.name, "");
|
||||
RayFunction func = res.get(overloadKey);
|
||||
// The function is overloaded.
|
||||
Assert.assertTrue(res.containsKey(overloadKey));
|
||||
Assert.assertNull(func);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
Reference in New Issue
Block a user