mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 06:08:03 +08:00
[Java] Replace binary rewrite with Remote Lambda Cache (SerdeLambda) (#2245)
* <feature> : serde lambda * <feature>:fixed CR with issue #2245 * <feature>: fixed CR
This commit is contained in:
committed by
Philipp Moritz
parent
62de86ff7a
commit
fa0ade2bc5
@@ -1,5 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
@@ -24,35 +23,35 @@
|
||||
<artifactId>ray-api</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ray</groupId>
|
||||
<artifactId>ray-hook</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>de.ruedigermoeller</groupId>
|
||||
<artifactId>fst</artifactId>
|
||||
<version>2.47</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
|
||||
<dependency>
|
||||
<groupId>com.github.davidmoten</groupId>
|
||||
<artifactId>flatbuffers-java</artifactId>
|
||||
<version>1.7.0.1</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
|
||||
<dependency>
|
||||
<groupId>redis.clients</groupId>
|
||||
<artifactId>jedis</artifactId>
|
||||
<version>2.8.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
</project>
|
||||
</project>
|
||||
@@ -4,6 +4,7 @@ import java.io.Serializable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -83,7 +84,9 @@ public class ArgumentsBuilder {
|
||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader)
|
||||
throws TaskExecutionException {
|
||||
FunctionArg[] fargs = task.args;
|
||||
// the last arg is className
|
||||
|
||||
FunctionArg[] fargs = Arrays.copyOf(task.args, task.args.length - 1);
|
||||
Object current = null;
|
||||
Object[] realArgs;
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ public class InvocationExecutor {
|
||||
}
|
||||
|
||||
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
|
||||
String taskdesc)
|
||||
String taskdesc)
|
||||
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
|
||||
Method m = pr.getRight().invokable;
|
||||
Map<?, UniqueID> userRayReturnIdMap = null;
|
||||
@@ -86,11 +86,12 @@ public class InvocationExecutor {
|
||||
}
|
||||
|
||||
// execute
|
||||
Object result;
|
||||
if (!UniqueIdHelper.isLambdaFunction(task.functionId)) {
|
||||
Object result = null;
|
||||
try {
|
||||
result = m.invoke(realArgs.getLeft(), realArgs.getRight());
|
||||
} else {
|
||||
result = m.invoke(realArgs.getLeft(), new Object[] {realArgs.getRight()});
|
||||
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
|
||||
RayLog.core.error("invoke failed:" + m);
|
||||
throw e;
|
||||
}
|
||||
|
||||
if (task.returnIds == null || task.returnIds.length == 0) {
|
||||
@@ -147,4 +148,4 @@ public class InvocationExecutor {
|
||||
private static void safePut(UniqueID objectId, Object obj) {
|
||||
RayRuntime.getInstance().putRaw(objectId, obj);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,14 @@
|
||||
package org.ray.core;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.hook.MethodId;
|
||||
import org.ray.hook.runtime.JarLoader;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
import org.ray.spi.RemoteFunctionManager;
|
||||
import org.ray.spi.model.FunctionArg;
|
||||
import org.ray.spi.model.RayActorMethods;
|
||||
import org.ray.spi.model.RayMethod;
|
||||
import org.ray.spi.model.RayTaskMethods;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
@@ -18,6 +17,7 @@ import org.ray.util.logger.RayLog;
|
||||
public class LocalFunctionManager {
|
||||
|
||||
private final RemoteFunctionManager remoteLoader;
|
||||
|
||||
private final ConcurrentHashMap<UniqueID, FunctionTable> functionTables
|
||||
= new ConcurrentHashMap<>();
|
||||
|
||||
@@ -29,74 +29,44 @@ public class LocalFunctionManager {
|
||||
this.remoteLoader = remoteLoader;
|
||||
}
|
||||
|
||||
private FunctionTable loadDriverFunctions(UniqueID driverId) {
|
||||
FunctionTable functionTable = functionTables.get(driverId);
|
||||
if (functionTable == null) {
|
||||
RayLog.core.info("DriverId " + driverId + " Try to load functions");
|
||||
ClassLoader classLoader = remoteLoader.loadResource(driverId);
|
||||
if (classLoader == null) {
|
||||
throw new RuntimeException(
|
||||
"Cannot find resource' classLoader for app " + driverId.toString());
|
||||
}
|
||||
functionTable = new FunctionTable(classLoader);
|
||||
functionTables.put(driverId, functionTable);
|
||||
}
|
||||
return functionTable;
|
||||
}
|
||||
|
||||
Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
|
||||
UniqueID methodId, String className) {
|
||||
// assert the driver's resource is load.
|
||||
FunctionTable functionTable = loadDriverFunctions(driverId);
|
||||
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
|
||||
RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className)
|
||||
: functionTable.getActorMethod(methodId, className);
|
||||
Preconditions
|
||||
.checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className,
|
||||
methodId, driverId);
|
||||
return Pair.of(functionTable.classLoader, method);
|
||||
}
|
||||
|
||||
/**
|
||||
* get local method for executing, which pulls information from remote repo on-demand, therefore
|
||||
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
|
||||
*/
|
||||
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID methodId,
|
||||
FunctionArg[] args) throws NoSuchMethodException,
|
||||
SecurityException, ClassNotFoundException {
|
||||
FunctionTable funcs = loadDriverFunctions(driverId);
|
||||
RayMethod m;
|
||||
|
||||
// hooked methods
|
||||
if (!UniqueIdHelper.isLambdaFunction(methodId)) {
|
||||
m = funcs.functions.get(methodId);
|
||||
if (null == m) {
|
||||
throw new RuntimeException(
|
||||
"DriverId " + driverId + " load remote function methodId:" + methodId + " failed");
|
||||
}
|
||||
} else { // remote lambda
|
||||
assert args.length >= 2;
|
||||
String fname = Serializer.decode(args[args.length - 2].data);
|
||||
Method fm = Class.forName(fname).getMethod("execute", Object[].class);
|
||||
m = new RayMethod(fm);
|
||||
}
|
||||
|
||||
return Pair.of(funcs.linkedFunctions.loader, m);
|
||||
}
|
||||
|
||||
private synchronized FunctionTable loadDriverFunctions(UniqueID driverId) {
|
||||
FunctionTable funcs = functionTables.get(driverId);
|
||||
if (null == funcs) {
|
||||
RayLog.core.debug("DriverId " + driverId + " Try to load functions");
|
||||
LoadedFunctions funcs2 = remoteLoader.loadFunctions(driverId);
|
||||
if (funcs2 == null) {
|
||||
throw new RuntimeException("Cannot find resource for app " + driverId.toString());
|
||||
}
|
||||
funcs = new FunctionTable();
|
||||
funcs.linkedFunctions = funcs2;
|
||||
for (MethodId mid : funcs.linkedFunctions.functions) {
|
||||
Method m = mid.load();
|
||||
assert (m != null);
|
||||
RayMethod v = new RayMethod(m);
|
||||
v.check();
|
||||
UniqueID k = new UniqueID(mid.getSha1Hash());
|
||||
String logInfo =
|
||||
"DriverId" + driverId + " load remote function " + m.getName() + ", hash = " + k
|
||||
.toString();
|
||||
RayLog.core.debug(logInfo);
|
||||
System.err.println(logInfo);
|
||||
funcs.functions.put(k, v);
|
||||
}
|
||||
|
||||
functionTables.put(driverId, funcs);
|
||||
} else { // reSync automatically
|
||||
// more functions are loaded
|
||||
if (funcs.linkedFunctions.functions.size() > funcs.functions.size()) {
|
||||
for (MethodId mid : funcs.linkedFunctions.functions) {
|
||||
UniqueID k = new UniqueID(mid.getSha1Hash());
|
||||
if (!funcs.functions.containsKey(k)) {
|
||||
Method m = mid.load();
|
||||
assert (m != null);
|
||||
RayMethod v = new RayMethod(m);
|
||||
v.check();
|
||||
funcs.functions.put(k, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return funcs;
|
||||
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
|
||||
UniqueID methodId,
|
||||
FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
|
||||
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
|
||||
String className = (String) Serializer.decode(args[args.length - 1].data);
|
||||
return getMethod(driverId, actorId, methodId, className);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -106,13 +76,47 @@ public class LocalFunctionManager {
|
||||
FunctionTable funcs = functionTables.get(driverId);
|
||||
if (funcs != null) {
|
||||
functionTables.remove(driverId);
|
||||
JarLoader.unloadJars(funcs.linkedFunctions.loader);
|
||||
remoteLoader.unloadFunctions(driverId);
|
||||
}
|
||||
}
|
||||
|
||||
static class FunctionTable {
|
||||
private static class FunctionTable {
|
||||
|
||||
public final ConcurrentHashMap<UniqueID, RayMethod> functions = new ConcurrentHashMap<>();
|
||||
public LoadedFunctions linkedFunctions;
|
||||
final ClassLoader classLoader;
|
||||
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
|
||||
final ConcurrentHashMap<String, RayActorMethods> actors = new ConcurrentHashMap<>();
|
||||
|
||||
FunctionTable(ClassLoader classLoader) {
|
||||
this.classLoader = classLoader;
|
||||
}
|
||||
|
||||
RayMethod getTaskMethod(UniqueID methodId, String className) {
|
||||
RayTaskMethods tasks = taskMethods.get(className);
|
||||
if (tasks == null) {
|
||||
tasks = RayTaskMethods.fromClass(className, classLoader);
|
||||
RayLog.core.info("create RayTaskMethods:" + tasks);
|
||||
taskMethods.put(className, tasks);
|
||||
}
|
||||
RayMethod m = tasks.functions.get(methodId);
|
||||
if (m != null) {
|
||||
return m;
|
||||
}
|
||||
// it is a actor static func.
|
||||
return getActorMethod(methodId, className, true);
|
||||
}
|
||||
|
||||
RayMethod getActorMethod(UniqueID methodId, String className) {
|
||||
return getActorMethod(methodId, className, false);
|
||||
}
|
||||
|
||||
private RayMethod getActorMethod(UniqueID methodId, String className, boolean isStatic) {
|
||||
RayActorMethods actor = actors.get(className);
|
||||
if (actor == null) {
|
||||
actor = RayActorMethods.fromClass(className, classLoader);
|
||||
RayLog.core.info("create RayActorMethods:" + actor);
|
||||
actors.put(className, actor);
|
||||
}
|
||||
return isStatic ? actor.staticFunctions.get(methodId) : actor.functions.get(methodId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.RayObjects;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.internal.Callable;
|
||||
import org.ray.api.internal.RayFunc;
|
||||
import org.ray.core.model.RayParameters;
|
||||
import org.ray.spi.LocalSchedulerLink;
|
||||
import org.ray.spi.LocalSchedulerProxy;
|
||||
@@ -261,49 +261,23 @@ public abstract class RayRuntime implements RayApi {
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args) {
|
||||
return worker.rpc(taskId, funcRun, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
|
||||
Object... args) {
|
||||
return worker.rpc(taskId, UniqueID.nil, funcCls, lambda, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
|
||||
Collection<RIDT> returnIds,
|
||||
Object... args) {
|
||||
return worker.rpcWithReturnLabels(taskId, funcRun, returnIds, args);
|
||||
public RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
|
||||
Object... args) {
|
||||
return worker.rpc(taskId, funcCls, lambda, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Collection<RIDT>
|
||||
returnids,
|
||||
Object... args) {
|
||||
RayFunc lambda, Collection<RIDT> returnids, Object... args) {
|
||||
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun,
|
||||
Integer returnCount, Object... args) {
|
||||
return worker.rpcWithReturnIndices(taskId, funcRun, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Integer returnCount, Object...
|
||||
args) {
|
||||
RayFunc lambda, Integer returnCount, Object... args) {
|
||||
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRemoteLambda() {
|
||||
return params.run_mode.isRemoteLambda();
|
||||
}
|
||||
|
||||
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
|
||||
throws TaskExecutionException {
|
||||
boolean wasBlocked = false;
|
||||
@@ -465,4 +439,4 @@ public abstract class RayRuntime implements RayApi {
|
||||
public RemoteFunctionManager getRemoteFunctionManager() {
|
||||
return remoteFunctionManager;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,12 +249,6 @@ public class UniqueIdHelper {
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public static boolean isLambdaFunction(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return wbb.getLong() == 0xffffffffffffffffL;
|
||||
}
|
||||
|
||||
public static void markCreateActorStage1Function(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
@@ -279,4 +273,4 @@ public class UniqueIdHelper {
|
||||
TASK,
|
||||
ACTOR,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package org.ray.core;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.commons.lang3.SerializationUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayList;
|
||||
@@ -12,12 +13,13 @@ import org.ray.api.RayMap;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayObjects;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.api.internal.Callable;
|
||||
import org.ray.hook.runtime.MethodSwitcher;
|
||||
import org.ray.api.internal.RayFunc;
|
||||
import org.ray.spi.LocalSchedulerProxy;
|
||||
import org.ray.spi.model.RayInvocation;
|
||||
import org.ray.spi.model.RayMethod;
|
||||
import org.ray.spi.model.TaskSpec;
|
||||
import org.ray.util.LambdaUtils;
|
||||
import org.ray.util.MethodId;
|
||||
import org.ray.util.exception.TaskExecutionException;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
@@ -47,13 +49,13 @@ public class Worker {
|
||||
RayLog.core.info("Task " + task.taskId + " start execute");
|
||||
Throwable ex = null;
|
||||
|
||||
|
||||
if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
|
||||
task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
|
||||
}
|
||||
|
||||
try {
|
||||
Pair<ClassLoader, RayMethod> pr = funcs.getMethod(task.driverId, task.functionId, task.args);
|
||||
Pair<ClassLoader, RayMethod> pr = funcs
|
||||
.getMethod(task.driverId, task.actorId, task.functionId, task.args);
|
||||
WorkerContext.prepare(task, pr.getLeft());
|
||||
InvocationExecutor.execute(task, pr);
|
||||
} catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) {
|
||||
@@ -72,161 +74,75 @@ public class Worker {
|
||||
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, Callable funcRun, int returnCount, Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
return submit(taskId, fid, returnCount, false, args);
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, UniqueID functionId, Class<?> funcCls, Serializable lambda,
|
||||
int returnCount, Object[] args) {
|
||||
byte[] fid = functionId.getBytes();
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
return submit(taskId, fid, returnCount, false, ls);
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, RayActor<?> actor, Callable funcRun, int returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
return actorTaskSubmit(taskId, fid, returnCount, false, args, actor);
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, UniqueID functionId, RayActor<?> actor, Class<?> funcCls,
|
||||
Serializable lambda, int returnCount, Object[] args) {
|
||||
byte[] fid = functionId.getBytes();
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
return actorTaskSubmit(taskId, fid, returnCount, false, ls, actor);
|
||||
}
|
||||
|
||||
private byte[] fidFromHook(Callable funcRun) {
|
||||
MethodSwitcher.IsRemoteCall.set(true);
|
||||
try {
|
||||
funcRun.run();
|
||||
} catch (Throwable e) {
|
||||
RayLog.core.error(
|
||||
"make sure you are using code rewritten using the rewrite tool, see JarRewriter for"
|
||||
+ " options", e);
|
||||
throw new RuntimeException("make sure you are using code rewritten using the rewrite tool,"
|
||||
+ "see JarRewriter for options");
|
||||
}
|
||||
byte[] fid = MethodSwitcher.MethodId.get();//get the identity of function from hook
|
||||
MethodSwitcher.IsRemoteCall.set(false);
|
||||
return fid;
|
||||
}
|
||||
|
||||
private RayObjects submit(UniqueID taskId,
|
||||
byte[] fid,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
if (taskId == null) {
|
||||
taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
}
|
||||
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
|
||||
return actorTaskSubmit(taskId, fid, returnCount, multiReturn, args, (RayActor<?>) args[0]);
|
||||
} else {
|
||||
return taskSubmit(taskId, fid, returnCount, multiReturn, args);
|
||||
}
|
||||
private RayObjects taskSubmit(UniqueID taskId,
|
||||
MethodId methodId,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.nil);
|
||||
return scheduler.submit(taskId, ri, returnCount, multiReturn);
|
||||
}
|
||||
|
||||
private RayObjects actorTaskSubmit(UniqueID taskId,
|
||||
byte[] fid,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args,
|
||||
RayActor<?> actor) {
|
||||
RayInvocation ri = new RayInvocation(fid, args, actor);
|
||||
MethodId methodId,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args,
|
||||
RayActor<?> actor) {
|
||||
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
|
||||
RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
|
||||
actor.setTaskCursor(returnObjs.pop().getId());
|
||||
return returnObjs;
|
||||
}
|
||||
|
||||
private RayObjects taskSubmit(UniqueID taskId,
|
||||
byte[] fid,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
RayInvocation ri = new RayInvocation(fid, args);
|
||||
return scheduler.submit(taskId, ri, returnCount, multiReturn);
|
||||
}
|
||||
|
||||
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, Callable funcRun,
|
||||
int returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
RayInvocation ri = new RayInvocation(fid, new Object[] {});
|
||||
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
|
||||
}
|
||||
|
||||
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, UniqueID functionId,
|
||||
Class<?> funcCls, Serializable lambda, int returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = functionId.getBytes();
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
RayInvocation ri = new RayInvocation(fid, ls);
|
||||
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
|
||||
}
|
||||
|
||||
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Callable funcRun,
|
||||
Collection<RIDT> returnids,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
private RayObjects submit(UniqueID taskId,
|
||||
MethodId methodId,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
if (taskId == null) {
|
||||
taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
}
|
||||
return scheduler.submit(taskId, new RayInvocation(fid, args), returnids);
|
||||
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
|
||||
return actorTaskSubmit(taskId, methodId, returnCount, multiReturn, args,
|
||||
(RayActor<?>) args[0]);
|
||||
} else {
|
||||
return taskSubmit(taskId, methodId, returnCount, multiReturn, args);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
|
||||
int returnCount, Object[] args) {
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
return submit(taskId, mid, returnCount, false, args);
|
||||
}
|
||||
|
||||
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId,
|
||||
Class<?> funcCls, RayFunc lambda, int returnCount, Object[] args) {
|
||||
Preconditions.checkNotNull(taskId);
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
|
||||
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
|
||||
}
|
||||
|
||||
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda,
|
||||
Collection<RIDT> returnids,
|
||||
Object[] args) {
|
||||
final byte[] fid = UniqueID.nil.getBytes();
|
||||
RayFunc lambda, Collection<RIDT> returnids,
|
||||
Object[] args) {
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
if (taskId == null) {
|
||||
taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
}
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
return scheduler.submit(taskId, new RayInvocation(fid, ls), returnids);
|
||||
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
|
||||
return scheduler.submit(taskId, ri, returnids);
|
||||
}
|
||||
|
||||
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Callable funcRun,
|
||||
Integer returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
RayObjects objs = submit(taskId, fid, returnCount, true, args);
|
||||
RayList<R> rets = new RayList<>();
|
||||
for (RayObject obj : objs.getObjs()) {
|
||||
rets.add(obj);
|
||||
}
|
||||
return rets;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Integer returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = UniqueID.nil.getBytes();
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
RayObjects objs = submit(taskId, fid, returnCount, true, ls);
|
||||
|
||||
RayFunc lambda, Integer returnCount,
|
||||
Object[] args) {
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
RayObjects objs = submit(taskId, mid, returnCount, true, args);
|
||||
RayList<R> rets = new RayList<>();
|
||||
for (RayObject obj : objs.getObjs()) {
|
||||
rets.add(obj);
|
||||
@@ -234,6 +150,27 @@ public class Worker {
|
||||
return rets;
|
||||
}
|
||||
|
||||
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
|
||||
RayActor<?> actor) {
|
||||
UniqueID driverId = WorkerContext.currentTask().driverId;
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 1);
|
||||
ls[args.length] = methodId.className;
|
||||
|
||||
RayMethod method = functions
|
||||
.getMethod(driverId, actor.getId(), new UniqueID(methodId.getSha1Hash()),
|
||||
methodId.className).getRight();
|
||||
|
||||
RayInvocation ri = new RayInvocation(methodId.className, method.getFuncId(),
|
||||
ls, method.remoteAnnotation, actor);
|
||||
return ri;
|
||||
}
|
||||
|
||||
private MethodId methodIdOf(RayFunc serialLambda) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(serialLambda);
|
||||
return mid;
|
||||
}
|
||||
|
||||
public UniqueID getCurrentTaskId() {
|
||||
return WorkerContext.currentTask().taskId;
|
||||
}
|
||||
@@ -246,4 +183,4 @@ public class Worker {
|
||||
public UniqueID[] getCurrentTaskReturnIDs() {
|
||||
return WorkerContext.currentTask().returnIds;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,40 +1,23 @@
|
||||
package org.ray.core.model;
|
||||
|
||||
public enum RunMode {
|
||||
SINGLE_PROCESS(true, false, true, false), // remote lambda, dev path, dev runtime
|
||||
SINGLE_BOX(true, false, true, true), // remote lambda, dev path, native runtime
|
||||
CLUSTER(false, true, false, true); // static rewrite, deploy path, naive runtime
|
||||
SINGLE_PROCESS(true, false), // dev path, dev runtime
|
||||
SINGLE_BOX(true, true), // dev path, native runtime
|
||||
CLUSTER(false, true); // deploy path, naive runtime
|
||||
|
||||
private final boolean remoteLambda;
|
||||
private final boolean staticRewrite;
|
||||
private final boolean devPathManager;
|
||||
private final boolean nativeRuntime;
|
||||
|
||||
RunMode(boolean remoteLambda, boolean staticRewrite, boolean devPathManager,
|
||||
boolean nativeRuntime) {
|
||||
this.remoteLambda = remoteLambda;
|
||||
this.staticRewrite = staticRewrite;
|
||||
RunMode(boolean devPathManager,
|
||||
boolean nativeRuntime) {
|
||||
this.devPathManager = devPathManager;
|
||||
this.nativeRuntime = nativeRuntime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>remoteLambda</tt>.
|
||||
*
|
||||
* @return property value of remoteLambda
|
||||
* the jar has add to java -cp, no need to load jar after started.
|
||||
*/
|
||||
public boolean isRemoteLambda() {
|
||||
return remoteLambda;
|
||||
}
|
||||
private final boolean devPathManager;
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>staticRewrite</tt>.
|
||||
*
|
||||
* @return property value of staticRewrite
|
||||
*/
|
||||
public boolean isStaticRewrite() {
|
||||
return staticRewrite;
|
||||
}
|
||||
private final boolean nativeRuntime;
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>devPathManager</tt>.
|
||||
@@ -53,4 +36,4 @@ public enum RunMode {
|
||||
public boolean isNativeRuntime() {
|
||||
return nativeRuntime;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ public class LocalSchedulerProxy {
|
||||
|
||||
task.args = ArgumentsBuilder.wrap(invocation);
|
||||
task.driverId = current.driverId;
|
||||
task.functionId = new UniqueID(invocation.getId());
|
||||
task.functionId = invocation.getId();
|
||||
task.parentCounter = -1; // TODO: this field is not used in core or python logically yet
|
||||
task.parentTaskId = current.taskId;
|
||||
task.actorHandleId = invocation.getActor().getActorHandleId();
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
package org.ray.spi;
|
||||
|
||||
import java.util.Set;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.hook.MethodId;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* mock version of remote function manager using local loaded jars + runtime hook.
|
||||
* mock version of remote function manager using local loaded jars.
|
||||
*/
|
||||
public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
private final LoadedFunctions loadedFunctions = new LoadedFunctions();
|
||||
|
||||
public NopRemoteFunctionManager(UniqueID driverId) {
|
||||
//onLoad(driverId, Agent.hookedMethods);
|
||||
//Agent.consumers.add(m -> { this.onLoad(m); });
|
||||
@@ -51,14 +45,9 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
}
|
||||
|
||||
@Override
|
||||
public LoadedFunctions loadFunctions(UniqueID driverId) {
|
||||
public ClassLoader loadResource(UniqueID driverId) {
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
if (loadedFunctions == null) {
|
||||
RayLog.rapp.error("cannot find functions for " + driverId);
|
||||
return null;
|
||||
} else {
|
||||
return loadedFunctions;
|
||||
}
|
||||
return this.getClass().getClassLoader();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -66,15 +55,4 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
// never
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
}
|
||||
|
||||
private void onLoad(UniqueID driverId, Set<MethodId> methods) {
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
for (MethodId mid : methods) {
|
||||
onLoad(mid);
|
||||
}
|
||||
}
|
||||
|
||||
private void onLoad(MethodId mid) {
|
||||
loadedFunctions.functions.add(mid);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package org.ray.spi;
|
||||
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
|
||||
/**
|
||||
* register and load functions from function table.
|
||||
@@ -53,14 +52,13 @@ public interface RemoteFunctionManager {
|
||||
void unregisterApp(UniqueID driverId);
|
||||
|
||||
/**
|
||||
* load resource and functions for this driver this function is used by the workers on demand when
|
||||
* a required function is not found in {@code LocalFunctionManager}.
|
||||
* load resource.
|
||||
*/
|
||||
LoadedFunctions loadFunctions(UniqueID driverId);
|
||||
ClassLoader loadResource(UniqueID driverId);
|
||||
|
||||
/**
|
||||
* unload functions for this driver
|
||||
* this function is used by the workers on demand when a driver is dead.
|
||||
*/
|
||||
void unloadFunctions(UniqueID driverId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
|
||||
|
||||
public final class RayActorMethods {
|
||||
|
||||
public final Class clazz;
|
||||
public final RayRemote remoteAnnotation;
|
||||
public final Map<UniqueID, RayMethod> functions;
|
||||
/**
|
||||
* the static function in Actor, call as task.
|
||||
*/
|
||||
public final Map<UniqueID, RayMethod> staticFunctions;
|
||||
|
||||
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
|
||||
Map<UniqueID, RayMethod> functions, Map<UniqueID, RayMethod> staticFunctions) {
|
||||
this.clazz = clazz;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
|
||||
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
|
||||
}
|
||||
|
||||
public static RayActorMethods fromClass(String clazzName, ClassLoader classLoader) {
|
||||
try {
|
||||
Class clazz = Class.forName(clazzName, true, classLoader);
|
||||
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
|
||||
Preconditions
|
||||
.checkNotNull(remoteAnnotation, "%s must declare @RayRemote", clazzName);
|
||||
Method[] methods = clazz.getDeclaredMethods();
|
||||
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
|
||||
Map<UniqueID, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
|
||||
|
||||
for (Method m : methods) {
|
||||
if (!Modifier.isPublic(m.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
RayMethod rayMethod = RayMethod.from(m, remoteAnnotation);
|
||||
if (Modifier.isStatic(m.getModifiers())) {
|
||||
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
|
||||
} else {
|
||||
functions.put(rayMethod.getFuncId(), rayMethod);
|
||||
}
|
||||
}
|
||||
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to get RayActorMethods from " + clazzName, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String
|
||||
.format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(),
|
||||
functions.values(),
|
||||
staticFunctions.size(), staticFunctions.values());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
|
||||
/**
|
||||
@@ -9,30 +10,30 @@ import org.ray.api.UniqueID;
|
||||
public class RayInvocation {
|
||||
|
||||
private static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);
|
||||
public final String className;
|
||||
/**
|
||||
* unique id for a method.
|
||||
*
|
||||
* @see UniqueID
|
||||
*/
|
||||
private final byte[] id;
|
||||
private final RayActor<?> actor;
|
||||
private final UniqueID id;
|
||||
private final RayRemote remoteAnnotation;
|
||||
/**
|
||||
* function arguments.
|
||||
*/
|
||||
private Object[] args;
|
||||
|
||||
public RayInvocation(byte[] id, Object[] args) {
|
||||
this(id, args, nil);
|
||||
}
|
||||
|
||||
public RayInvocation(byte[] id, Object[] args, RayActor<?> actor) {
|
||||
super();
|
||||
private RayActor<?> actor;
|
||||
|
||||
public RayInvocation(String className, UniqueID id, Object[] args, RayRemote remoteAnnotation,
|
||||
RayActor<?> actor) {
|
||||
this.className = className;
|
||||
this.id = id;
|
||||
this.args = args;
|
||||
this.actor = actor;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
}
|
||||
|
||||
public byte[] getId() {
|
||||
public UniqueID getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -48,4 +49,8 @@ public class RayInvocation {
|
||||
return actor;
|
||||
}
|
||||
|
||||
}
|
||||
public RayRemote getRemoteAnnotation() {
|
||||
return remoteAnnotation;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.util.MethodId;
|
||||
|
||||
/**
|
||||
* method info.
|
||||
@@ -9,12 +12,8 @@ public class RayMethod {
|
||||
|
||||
public final Method invokable;
|
||||
public final String fullName;
|
||||
// TODO: other annotated information
|
||||
|
||||
public RayMethod(Method m) {
|
||||
invokable = m;
|
||||
fullName = m.getDeclaringClass().getName() + "." + m.getName();
|
||||
}
|
||||
public final RayRemote remoteAnnotation;
|
||||
private final UniqueID funcId;
|
||||
|
||||
public void check() {
|
||||
for (Class<?> paramCls : invokable.getParameterTypes()) {
|
||||
@@ -24,4 +23,32 @@ public class RayMethod {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueID funcId) {
|
||||
this.invokable = m;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
this.funcId = funcId;
|
||||
fullName = m.getDeclaringClass().getName() + "." + m.getName();
|
||||
}
|
||||
|
||||
public static RayMethod from(Method m, RayRemote parentRemoteAnnotation) {
|
||||
Class<?> clazz = m.getDeclaringClass();
|
||||
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
|
||||
MethodId mid = MethodId.fromMethod(m);
|
||||
UniqueID funcId = new UniqueID(mid.getSha1Hash());
|
||||
RayMethod method = new RayMethod(m,
|
||||
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
|
||||
funcId);
|
||||
method.check();
|
||||
return method;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return fullName;
|
||||
}
|
||||
|
||||
public UniqueID getFuncId() {
|
||||
return funcId;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
|
||||
|
||||
public final class RayTaskMethods {
|
||||
|
||||
public final Class clazz;
|
||||
public final Map<UniqueID, RayMethod> functions;
|
||||
|
||||
public RayTaskMethods(Class clazz,
|
||||
Map<UniqueID, RayMethod> functions) {
|
||||
this.clazz = clazz;
|
||||
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
|
||||
}
|
||||
|
||||
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
|
||||
try {
|
||||
Class clazz = Class.forName(clazzName, true, classLoader);
|
||||
Method[] methods = clazz.getDeclaredMethods();
|
||||
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
|
||||
|
||||
for (Method m : methods) {
|
||||
if (!Modifier.isStatic(m.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
//task method only for static.
|
||||
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
|
||||
if (remoteAnnotation == null) {
|
||||
continue;
|
||||
}
|
||||
m.setAccessible(true);
|
||||
RayMethod rayMethod = RayMethod.from(m, null);
|
||||
functions.put(rayMethod.getFuncId(), rayMethod);
|
||||
}
|
||||
return new RayTaskMethods(clazz, functions);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String
|
||||
.format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values());
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user