[Java] Replace binary rewrite with Remote Lambda Cache (SerdeLambda) (#2245)

* <feature> : serde lambda

* <feature>:fixed CR

with issue #2245

* <feature>: fixed CR
This commit is contained in:
mylinyuzhi
2018-06-14 03:58:07 +08:00
committed by Philipp Moritz
parent 62de86ff7a
commit fa0ade2bc5
89 changed files with 2633 additions and 7668 deletions
+9 -10
View File
@@ -1,5 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
@@ -24,35 +23,35 @@
<artifactId>ray-api</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-hook</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>de.ruedigermoeller</groupId>
<artifactId>fst</artifactId>
<version>2.47</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
<dependency>
<groupId>com.github.davidmoten</groupId>
<artifactId>flatbuffers-java</artifactId>
<version>1.7.0.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.8.0</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
</project>
</project>
@@ -4,6 +4,7 @@ import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -83,7 +84,9 @@ public class ArgumentsBuilder {
@SuppressWarnings({"rawtypes", "unchecked"})
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader)
throws TaskExecutionException {
FunctionArg[] fargs = task.args;
// the last arg is className
FunctionArg[] fargs = Arrays.copyOf(task.args, task.args.length - 1);
Object current = null;
Object[] realArgs;
@@ -68,7 +68,7 @@ public class InvocationExecutor {
}
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
String taskdesc)
String taskdesc)
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
Method m = pr.getRight().invokable;
Map<?, UniqueID> userRayReturnIdMap = null;
@@ -86,11 +86,12 @@ public class InvocationExecutor {
}
// execute
Object result;
if (!UniqueIdHelper.isLambdaFunction(task.functionId)) {
Object result = null;
try {
result = m.invoke(realArgs.getLeft(), realArgs.getRight());
} else {
result = m.invoke(realArgs.getLeft(), new Object[] {realArgs.getRight()});
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
RayLog.core.error("invoke failed:" + m);
throw e;
}
if (task.returnIds == null || task.returnIds.length == 0) {
@@ -147,4 +148,4 @@ public class InvocationExecutor {
private static void safePut(UniqueID objectId, Object obj) {
RayRuntime.getInstance().putRaw(objectId, obj);
}
}
}
@@ -1,15 +1,14 @@
package org.ray.core;
import java.lang.reflect.Method;
import com.google.common.base.Preconditions;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.hook.MethodId;
import org.ray.hook.runtime.JarLoader;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.RayActorMethods;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.RayTaskMethods;
import org.ray.util.logger.RayLog;
/**
@@ -18,6 +17,7 @@ import org.ray.util.logger.RayLog;
public class LocalFunctionManager {
private final RemoteFunctionManager remoteLoader;
private final ConcurrentHashMap<UniqueID, FunctionTable> functionTables
= new ConcurrentHashMap<>();
@@ -29,74 +29,44 @@ public class LocalFunctionManager {
this.remoteLoader = remoteLoader;
}
private FunctionTable loadDriverFunctions(UniqueID driverId) {
FunctionTable functionTable = functionTables.get(driverId);
if (functionTable == null) {
RayLog.core.info("DriverId " + driverId + " Try to load functions");
ClassLoader classLoader = remoteLoader.loadResource(driverId);
if (classLoader == null) {
throw new RuntimeException(
"Cannot find resource' classLoader for app " + driverId.toString());
}
functionTable = new FunctionTable(classLoader);
functionTables.put(driverId, functionTable);
}
return functionTable;
}
Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
UniqueID methodId, String className) {
// assert the driver's resource is load.
FunctionTable functionTable = loadDriverFunctions(driverId);
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className)
: functionTable.getActorMethod(methodId, className);
Preconditions
.checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className,
methodId, driverId);
return Pair.of(functionTable.classLoader, method);
}
/**
* get local method for executing, which pulls information from remote repo on-demand, therefore
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
*/
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID methodId,
FunctionArg[] args) throws NoSuchMethodException,
SecurityException, ClassNotFoundException {
FunctionTable funcs = loadDriverFunctions(driverId);
RayMethod m;
// hooked methods
if (!UniqueIdHelper.isLambdaFunction(methodId)) {
m = funcs.functions.get(methodId);
if (null == m) {
throw new RuntimeException(
"DriverId " + driverId + " load remote function methodId:" + methodId + " failed");
}
} else { // remote lambda
assert args.length >= 2;
String fname = Serializer.decode(args[args.length - 2].data);
Method fm = Class.forName(fname).getMethod("execute", Object[].class);
m = new RayMethod(fm);
}
return Pair.of(funcs.linkedFunctions.loader, m);
}
private synchronized FunctionTable loadDriverFunctions(UniqueID driverId) {
FunctionTable funcs = functionTables.get(driverId);
if (null == funcs) {
RayLog.core.debug("DriverId " + driverId + " Try to load functions");
LoadedFunctions funcs2 = remoteLoader.loadFunctions(driverId);
if (funcs2 == null) {
throw new RuntimeException("Cannot find resource for app " + driverId.toString());
}
funcs = new FunctionTable();
funcs.linkedFunctions = funcs2;
for (MethodId mid : funcs.linkedFunctions.functions) {
Method m = mid.load();
assert (m != null);
RayMethod v = new RayMethod(m);
v.check();
UniqueID k = new UniqueID(mid.getSha1Hash());
String logInfo =
"DriverId" + driverId + " load remote function " + m.getName() + ", hash = " + k
.toString();
RayLog.core.debug(logInfo);
System.err.println(logInfo);
funcs.functions.put(k, v);
}
functionTables.put(driverId, funcs);
} else { // reSync automatically
// more functions are loaded
if (funcs.linkedFunctions.functions.size() > funcs.functions.size()) {
for (MethodId mid : funcs.linkedFunctions.functions) {
UniqueID k = new UniqueID(mid.getSha1Hash());
if (!funcs.functions.containsKey(k)) {
Method m = mid.load();
assert (m != null);
RayMethod v = new RayMethod(m);
v.check();
funcs.functions.put(k, v);
}
}
}
}
return funcs;
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
UniqueID methodId,
FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
String className = (String) Serializer.decode(args[args.length - 1].data);
return getMethod(driverId, actorId, methodId, className);
}
/**
@@ -106,13 +76,47 @@ public class LocalFunctionManager {
FunctionTable funcs = functionTables.get(driverId);
if (funcs != null) {
functionTables.remove(driverId);
JarLoader.unloadJars(funcs.linkedFunctions.loader);
remoteLoader.unloadFunctions(driverId);
}
}
static class FunctionTable {
private static class FunctionTable {
public final ConcurrentHashMap<UniqueID, RayMethod> functions = new ConcurrentHashMap<>();
public LoadedFunctions linkedFunctions;
final ClassLoader classLoader;
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
final ConcurrentHashMap<String, RayActorMethods> actors = new ConcurrentHashMap<>();
FunctionTable(ClassLoader classLoader) {
this.classLoader = classLoader;
}
RayMethod getTaskMethod(UniqueID methodId, String className) {
RayTaskMethods tasks = taskMethods.get(className);
if (tasks == null) {
tasks = RayTaskMethods.fromClass(className, classLoader);
RayLog.core.info("create RayTaskMethods:" + tasks);
taskMethods.put(className, tasks);
}
RayMethod m = tasks.functions.get(methodId);
if (m != null) {
return m;
}
// it is a actor static func.
return getActorMethod(methodId, className, true);
}
RayMethod getActorMethod(UniqueID methodId, String className) {
return getActorMethod(methodId, className, false);
}
private RayMethod getActorMethod(UniqueID methodId, String className, boolean isStatic) {
RayActorMethods actor = actors.get(className);
if (actor == null) {
actor = RayActorMethods.fromClass(className, classLoader);
RayLog.core.info("create RayActorMethods:" + actor);
actors.put(className, actor);
}
return isStatic ? actor.staticFunctions.get(methodId) : actor.functions.get(methodId);
}
}
}
}
@@ -19,7 +19,7 @@ import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.internal.Callable;
import org.ray.api.internal.RayFunc;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.LocalSchedulerProxy;
@@ -261,49 +261,23 @@ public abstract class RayRuntime implements RayApi {
}
@Override
public RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args) {
return worker.rpc(taskId, funcRun, returnCount, args);
}
@Override
public RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, UniqueID.nil, funcCls, lambda, returnCount, args);
}
@Override
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
Collection<RIDT> returnIds,
Object... args) {
return worker.rpcWithReturnLabels(taskId, funcRun, returnIds, args);
public RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, funcCls, lambda, returnCount, args);
}
@Override
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Collection<RIDT>
returnids,
Object... args) {
RayFunc lambda, Collection<RIDT> returnids, Object... args) {
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
}
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun,
Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcRun, returnCount, args);
}
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Integer returnCount, Object...
args) {
RayFunc lambda, Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
@Override
public boolean isRemoteLambda() {
return params.run_mode.isRemoteLambda();
}
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
throws TaskExecutionException {
boolean wasBlocked = false;
@@ -465,4 +439,4 @@ public abstract class RayRuntime implements RayApi {
public RemoteFunctionManager getRemoteFunctionManager() {
return remoteFunctionManager;
}
}
}
@@ -249,12 +249,6 @@ public class UniqueIdHelper {
return taskId;
}
public static boolean isLambdaFunction(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
return wbb.getLong() == 0xffffffffffffffffL;
}
public static void markCreateActorStage1Function(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
@@ -279,4 +273,4 @@ public class UniqueIdHelper {
TASK,
ACTOR,
}
}
}
@@ -1,10 +1,11 @@
package org.ray.core;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
@@ -12,12 +13,13 @@ import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.internal.Callable;
import org.ray.hook.runtime.MethodSwitcher;
import org.ray.api.internal.RayFunc;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.LambdaUtils;
import org.ray.util.MethodId;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -47,13 +49,13 @@ public class Worker {
RayLog.core.info("Task " + task.taskId + " start execute");
Throwable ex = null;
if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
}
try {
Pair<ClassLoader, RayMethod> pr = funcs.getMethod(task.driverId, task.functionId, task.args);
Pair<ClassLoader, RayMethod> pr = funcs
.getMethod(task.driverId, task.actorId, task.functionId, task.args);
WorkerContext.prepare(task, pr.getLeft());
InvocationExecutor.execute(task, pr);
} catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) {
@@ -72,161 +74,75 @@ public class Worker {
}
public RayObjects rpc(UniqueID taskId, Callable funcRun, int returnCount, Object[] args) {
byte[] fid = fidFromHook(funcRun);
return submit(taskId, fid, returnCount, false, args);
}
public RayObjects rpc(UniqueID taskId, UniqueID functionId, Class<?> funcCls, Serializable lambda,
int returnCount, Object[] args) {
byte[] fid = functionId.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
return submit(taskId, fid, returnCount, false, ls);
}
public RayObjects rpc(UniqueID taskId, RayActor<?> actor, Callable funcRun, int returnCount,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
return actorTaskSubmit(taskId, fid, returnCount, false, args, actor);
}
public RayObjects rpc(UniqueID taskId, UniqueID functionId, RayActor<?> actor, Class<?> funcCls,
Serializable lambda, int returnCount, Object[] args) {
byte[] fid = functionId.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
return actorTaskSubmit(taskId, fid, returnCount, false, ls, actor);
}
private byte[] fidFromHook(Callable funcRun) {
MethodSwitcher.IsRemoteCall.set(true);
try {
funcRun.run();
} catch (Throwable e) {
RayLog.core.error(
"make sure you are using code rewritten using the rewrite tool, see JarRewriter for"
+ " options", e);
throw new RuntimeException("make sure you are using code rewritten using the rewrite tool,"
+ "see JarRewriter for options");
}
byte[] fid = MethodSwitcher.MethodId.get();//get the identity of function from hook
MethodSwitcher.IsRemoteCall.set(false);
return fid;
}
private RayObjects submit(UniqueID taskId,
byte[] fid,
int returnCount,
boolean multiReturn,
Object[] args) {
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, fid, returnCount, multiReturn, args, (RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, fid, returnCount, multiReturn, args);
}
private RayObjects taskSubmit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnCount, multiReturn);
}
private RayObjects actorTaskSubmit(UniqueID taskId,
byte[] fid,
int returnCount,
boolean multiReturn,
Object[] args,
RayActor<?> actor) {
RayInvocation ri = new RayInvocation(fid, args, actor);
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args,
RayActor<?> actor) {
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
actor.setTaskCursor(returnObjs.pop().getId());
return returnObjs;
}
private RayObjects taskSubmit(UniqueID taskId,
byte[] fid,
int returnCount,
boolean multiReturn,
Object[] args) {
RayInvocation ri = new RayInvocation(fid, args);
return scheduler.submit(taskId, ri, returnCount, multiReturn);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, Callable funcRun,
int returnCount,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
RayInvocation ri = new RayInvocation(fid, new Object[] {});
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, UniqueID functionId,
Class<?> funcCls, Serializable lambda, int returnCount,
Object[] args) {
byte[] fid = functionId.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
RayInvocation ri = new RayInvocation(fid, ls);
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Callable funcRun,
Collection<RIDT> returnids,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
private RayObjects submit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
return scheduler.submit(taskId, new RayInvocation(fid, args), returnids);
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, methodId, returnCount, multiReturn, args,
(RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, methodId, returnCount, multiReturn, args);
}
}
public RayObjects rpc(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
int returnCount, Object[] args) {
MethodId mid = methodIdOf(lambda);
return submit(taskId, mid, returnCount, false, args);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId,
Class<?> funcCls, RayFunc lambda, int returnCount, Object[] args) {
Preconditions.checkNotNull(taskId);
MethodId mid = methodIdOf(lambda);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
Serializable lambda,
Collection<RIDT> returnids,
Object[] args) {
final byte[] fid = UniqueID.nil.getBytes();
RayFunc lambda, Collection<RIDT> returnids,
Object[] args) {
MethodId mid = methodIdOf(lambda);
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
return scheduler.submit(taskId, new RayInvocation(fid, ls), returnids);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnids);
}
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Callable funcRun,
Integer returnCount,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
RayObjects objs = submit(taskId, fid, returnCount, true, args);
RayList<R> rets = new RayList<>();
for (RayObject obj : objs.getObjs()) {
rets.add(obj);
}
return rets;
}
@SuppressWarnings("unchecked")
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Integer returnCount,
Object[] args) {
byte[] fid = UniqueID.nil.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
RayObjects objs = submit(taskId, fid, returnCount, true, ls);
RayFunc lambda, Integer returnCount,
Object[] args) {
MethodId mid = methodIdOf(lambda);
RayObjects objs = submit(taskId, mid, returnCount, true, args);
RayList<R> rets = new RayList<>();
for (RayObject obj : objs.getObjs()) {
rets.add(obj);
@@ -234,6 +150,27 @@ public class Worker {
return rets;
}
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
RayActor<?> actor) {
UniqueID driverId = WorkerContext.currentTask().driverId;
Object[] ls = Arrays.copyOf(args, args.length + 1);
ls[args.length] = methodId.className;
RayMethod method = functions
.getMethod(driverId, actor.getId(), new UniqueID(methodId.getSha1Hash()),
methodId.className).getRight();
RayInvocation ri = new RayInvocation(methodId.className, method.getFuncId(),
ls, method.remoteAnnotation, actor);
return ri;
}
private MethodId methodIdOf(RayFunc serialLambda) {
MethodId mid = MethodId.fromSerializedLambda(serialLambda);
return mid;
}
public UniqueID getCurrentTaskId() {
return WorkerContext.currentTask().taskId;
}
@@ -246,4 +183,4 @@ public class Worker {
public UniqueID[] getCurrentTaskReturnIDs() {
return WorkerContext.currentTask().returnIds;
}
}
}
@@ -1,40 +1,23 @@
package org.ray.core.model;
public enum RunMode {
SINGLE_PROCESS(true, false, true, false), // remote lambda, dev path, dev runtime
SINGLE_BOX(true, false, true, true), // remote lambda, dev path, native runtime
CLUSTER(false, true, false, true); // static rewrite, deploy path, naive runtime
SINGLE_PROCESS(true, false), // dev path, dev runtime
SINGLE_BOX(true, true), // dev path, native runtime
CLUSTER(false, true); // deploy path, naive runtime
private final boolean remoteLambda;
private final boolean staticRewrite;
private final boolean devPathManager;
private final boolean nativeRuntime;
RunMode(boolean remoteLambda, boolean staticRewrite, boolean devPathManager,
boolean nativeRuntime) {
this.remoteLambda = remoteLambda;
this.staticRewrite = staticRewrite;
RunMode(boolean devPathManager,
boolean nativeRuntime) {
this.devPathManager = devPathManager;
this.nativeRuntime = nativeRuntime;
}
/**
* Getter method for property <tt>remoteLambda</tt>.
*
* @return property value of remoteLambda
* the jar has add to java -cp, no need to load jar after started.
*/
public boolean isRemoteLambda() {
return remoteLambda;
}
private final boolean devPathManager;
/**
* Getter method for property <tt>staticRewrite</tt>.
*
* @return property value of staticRewrite
*/
public boolean isStaticRewrite() {
return staticRewrite;
}
private final boolean nativeRuntime;
/**
* Getter method for property <tt>devPathManager</tt>.
@@ -53,4 +36,4 @@ public enum RunMode {
public boolean isNativeRuntime() {
return nativeRuntime;
}
}
}
@@ -97,7 +97,7 @@ public class LocalSchedulerProxy {
task.args = ArgumentsBuilder.wrap(invocation);
task.driverId = current.driverId;
task.functionId = new UniqueID(invocation.getId());
task.functionId = invocation.getId();
task.parentCounter = -1; // TODO: this field is not used in core or python logically yet
task.parentTaskId = current.taskId;
task.actorHandleId = invocation.getActor().getActorHandleId();
@@ -1,18 +1,12 @@
package org.ray.spi;
import java.util.Set;
import org.ray.api.UniqueID;
import org.ray.hook.MethodId;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.util.logger.RayLog;
/**
* mock version of remote function manager using local loaded jars + runtime hook.
* mock version of remote function manager using local loaded jars.
*/
public class NopRemoteFunctionManager implements RemoteFunctionManager {
private final LoadedFunctions loadedFunctions = new LoadedFunctions();
public NopRemoteFunctionManager(UniqueID driverId) {
//onLoad(driverId, Agent.hookedMethods);
//Agent.consumers.add(m -> { this.onLoad(m); });
@@ -51,14 +45,9 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
}
@Override
public LoadedFunctions loadFunctions(UniqueID driverId) {
public ClassLoader loadResource(UniqueID driverId) {
//assert (startupDriverId().equals(driverId));
if (loadedFunctions == null) {
RayLog.rapp.error("cannot find functions for " + driverId);
return null;
} else {
return loadedFunctions;
}
return this.getClass().getClassLoader();
}
@Override
@@ -66,15 +55,4 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
// never
//assert (startupDriverId().equals(driverId));
}
private void onLoad(UniqueID driverId, Set<MethodId> methods) {
//assert (startupDriverId().equals(driverId));
for (MethodId mid : methods) {
onLoad(mid);
}
}
private void onLoad(MethodId mid) {
loadedFunctions.functions.add(mid);
}
}
}
@@ -1,7 +1,6 @@
package org.ray.spi;
import org.ray.api.UniqueID;
import org.ray.hook.runtime.LoadedFunctions;
/**
* register and load functions from function table.
@@ -53,14 +52,13 @@ public interface RemoteFunctionManager {
void unregisterApp(UniqueID driverId);
/**
* load resource and functions for this driver this function is used by the workers on demand when
* a required function is not found in {@code LocalFunctionManager}.
* load resource.
*/
LoadedFunctions loadFunctions(UniqueID driverId);
ClassLoader loadResource(UniqueID driverId);
/**
* unload functions for this driver
* this function is used by the workers on demand when a driver is dead.
*/
void unloadFunctions(UniqueID driverId);
}
}
@@ -0,0 +1,66 @@
package org.ray.spi.model;
import com.google.common.base.Preconditions;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
public final class RayActorMethods {
public final Class clazz;
public final RayRemote remoteAnnotation;
public final Map<UniqueID, RayMethod> functions;
/**
* the static function in Actor, call as task.
*/
public final Map<UniqueID, RayMethod> staticFunctions;
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
Map<UniqueID, RayMethod> functions, Map<UniqueID, RayMethod> staticFunctions) {
this.clazz = clazz;
this.remoteAnnotation = remoteAnnotation;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
}
public static RayActorMethods fromClass(String clazzName, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
Preconditions
.checkNotNull(remoteAnnotation, "%s must declare @RayRemote", clazzName);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
Map<UniqueID, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
for (Method m : methods) {
if (!Modifier.isPublic(m.getModifiers())) {
continue;
}
RayMethod rayMethod = RayMethod.from(m, remoteAnnotation);
if (Modifier.isStatic(m.getModifiers())) {
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
} else {
functions.put(rayMethod.getFuncId(), rayMethod);
}
}
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayActorMethods from " + clazzName, e);
}
}
@Override
public String toString() {
return String
.format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(),
functions.values(),
staticFunctions.size(), staticFunctions.values());
}
}
@@ -1,6 +1,7 @@
package org.ray.spi.model;
import org.ray.api.RayActor;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
/**
@@ -9,30 +10,30 @@ import org.ray.api.UniqueID;
public class RayInvocation {
private static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);
public final String className;
/**
* unique id for a method.
*
* @see UniqueID
*/
private final byte[] id;
private final RayActor<?> actor;
private final UniqueID id;
private final RayRemote remoteAnnotation;
/**
* function arguments.
*/
private Object[] args;
public RayInvocation(byte[] id, Object[] args) {
this(id, args, nil);
}
public RayInvocation(byte[] id, Object[] args, RayActor<?> actor) {
super();
private RayActor<?> actor;
public RayInvocation(String className, UniqueID id, Object[] args, RayRemote remoteAnnotation,
RayActor<?> actor) {
this.className = className;
this.id = id;
this.args = args;
this.actor = actor;
this.remoteAnnotation = remoteAnnotation;
}
public byte[] getId() {
public UniqueID getId() {
return id;
}
@@ -48,4 +49,8 @@ public class RayInvocation {
return actor;
}
}
public RayRemote getRemoteAnnotation() {
return remoteAnnotation;
}
}
@@ -1,6 +1,9 @@
package org.ray.spi.model;
import java.lang.reflect.Method;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.util.MethodId;
/**
* method info.
@@ -9,12 +12,8 @@ public class RayMethod {
public final Method invokable;
public final String fullName;
// TODO: other annotated information
public RayMethod(Method m) {
invokable = m;
fullName = m.getDeclaringClass().getName() + "." + m.getName();
}
public final RayRemote remoteAnnotation;
private final UniqueID funcId;
public void check() {
for (Class<?> paramCls : invokable.getParameterTypes()) {
@@ -24,4 +23,32 @@ public class RayMethod {
}
}
}
}
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueID funcId) {
this.invokable = m;
this.remoteAnnotation = remoteAnnotation;
this.funcId = funcId;
fullName = m.getDeclaringClass().getName() + "." + m.getName();
}
public static RayMethod from(Method m, RayRemote parentRemoteAnnotation) {
Class<?> clazz = m.getDeclaringClass();
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
MethodId mid = MethodId.fromMethod(m);
UniqueID funcId = new UniqueID(mid.getSha1Hash());
RayMethod method = new RayMethod(m,
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
funcId);
method.check();
return method;
}
@Override
public String toString() {
return fullName;
}
public UniqueID getFuncId() {
return funcId;
}
}
@@ -0,0 +1,54 @@
package org.ray.spi.model;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
public final class RayTaskMethods {
public final Class clazz;
public final Map<UniqueID, RayMethod> functions;
public RayTaskMethods(Class clazz,
Map<UniqueID, RayMethod> functions) {
this.clazz = clazz;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
}
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
for (Method m : methods) {
if (!Modifier.isStatic(m.getModifiers())) {
continue;
}
//task method only for static.
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
if (remoteAnnotation == null) {
continue;
}
m.setAccessible(true);
RayMethod rayMethod = RayMethod.from(m, null);
functions.put(rayMethod.getFuncId(), rayMethod);
}
return new RayTaskMethods(clazz, functions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e);
}
}
@Override
public String toString() {
return String
.format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values());
}
}