mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[java] put function meta in task spec and load functions with function meta (#2881)
This PR adds a `function_desc` field into task spec. a function descriptor is a list of strings that can uniquely describe a function. - For a Python function, it should be: [module_name, class_name, function_name] - For a Java function, it should be: [class_name, method_name, type_descriptor] There're a couple of purposes to add this field: In this PR: - Java worker needs to know function's class name to load it. Previously, since task spec didn't have such a field to hold this info, we did a hack by appending the class name to the argument list. With this change, we fixed that hack and significantly simplified function management in Java. Will be done in subsequent PRs: - Support cross-language invocation (#2576): currently Python worker manages functions by saving them in GCS and pass function id in task spec. However, if we want to call a Python function from Java, we cannot save it in GCS and get the function id. But instead, we can pass the function descriptor (module name, class name, function name) in task spec and use it to load the function. - Support deployment: one major problem of Python worker's current function management mechanism is #2327. In prod env, we should have a mechanism to deploy code and dependencies to the cluster. And when code is already deployed, we don't need to save functions to GCS any more and can use `function_desc` to manage functions.
This commit is contained in:
committed by
Robert Nishihara
parent
3cccb49191
commit
971df5ea8a
@@ -3,6 +3,7 @@
|
||||
"http://www.puppycrawl.com/dtds/suppressions_1_1.dtd">
|
||||
|
||||
<suppressions>
|
||||
<suppress checks="OperatorWrap" files=".*" />
|
||||
<suppress checks="MemberNameCheck" files="PathConfig.java"/>
|
||||
<suppress checks="MemberNameCheck" files="RayParameters.java"/>
|
||||
<suppress checks="AbbreviationAsWordInNameCheck" files="RayParameters.java"/>
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
package org.ray.cli;
|
||||
|
||||
import com.beust.jcommander.JCommander;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import net.lingala.zip4j.core.ZipFile;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.PathConfig;
|
||||
import org.ray.runtime.config.RayParameters;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.NativeRemoteFunctionManager;
|
||||
import org.ray.runtime.functionmanager.RemoteFunctionManager;
|
||||
import org.ray.runtime.gcs.KeyValueStoreLink;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.gcs.StateStoreProxy;
|
||||
import org.ray.runtime.gcs.StateStoreProxyImpl;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
import org.ray.runtime.runner.worker.DefaultDriver;
|
||||
import org.ray.runtime.util.FileUtil;
|
||||
import org.ray.runtime.util.config.ConfigReader;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
|
||||
@@ -140,64 +135,18 @@ public class RayCli {
|
||||
ConfigReader config = new ConfigReader(configPath, "ray.java.start.deploy=true");
|
||||
PathConfig paths = new PathConfig(config);
|
||||
RayParameters params = new RayParameters(config);
|
||||
|
||||
params.redis_address = cmdSubmit.redisAddress;
|
||||
params.run_mode = RunMode.CLUSTER;
|
||||
|
||||
|
||||
KeyValueStoreLink kvStore = new RedisClient();
|
||||
kvStore.setAddr(cmdSubmit.redisAddress);
|
||||
StateStoreProxy stateStoreProxy = new StateStoreProxyImpl(kvStore);
|
||||
stateStoreProxy.initializeGlobalState();
|
||||
|
||||
RemoteFunctionManager functionManager = new NativeRemoteFunctionManager(kvStore);
|
||||
|
||||
// Register app to Redis.
|
||||
byte[] zip = FileUtil.fileToBytes(cmdSubmit.packageZip);
|
||||
|
||||
String packageName = cmdSubmit.packageZip.substring(
|
||||
cmdSubmit.packageZip.lastIndexOf('/') + 1,
|
||||
cmdSubmit.packageZip.lastIndexOf('.'));
|
||||
|
||||
UniqueId resourceId = functionManager.registerResource(zip);
|
||||
|
||||
// Init RayLog before using it.
|
||||
RayLog.init(params.log_dir);
|
||||
|
||||
RayLog.rapp.debug(
|
||||
"registerResource " + resourceId + " for package " + packageName + " done");
|
||||
|
||||
UniqueId appId = params.driver_id;
|
||||
functionManager.registerApp(appId, resourceId);
|
||||
RayLog.rapp.debug("registerApp " + appId + " for resouorce " + resourceId + " done");
|
||||
|
||||
// Unzip the package file.
|
||||
String appDir = "/tmp/" + cmdSubmit.className;
|
||||
String extPath = appDir + "/" + packageName;
|
||||
if (!FileUtil.createDir(extPath, false)) {
|
||||
throw new RuntimeException("create dir " + extPath + " failed ");
|
||||
}
|
||||
|
||||
ZipFile zipFile = new ZipFile(cmdSubmit.packageZip);
|
||||
zipFile.extractAll(extPath);
|
||||
|
||||
// Build the args for driver process.
|
||||
File originDirFile = new File(extPath);
|
||||
File[] topFiles = originDirFile.listFiles();
|
||||
String topDir = null;
|
||||
for (File file : topFiles) {
|
||||
if (file.isDirectory()) {
|
||||
topDir = file.getName();
|
||||
}
|
||||
}
|
||||
RayLog.rapp.debug("topDir of app classes: " + topDir);
|
||||
if (topDir == null) {
|
||||
RayLog.rapp.error("Can't find topDir of app classes, the app directory " + appDir);
|
||||
return;
|
||||
}
|
||||
|
||||
String additionalClassPath = appDir + "/" + packageName + "/" + topDir + "/*";
|
||||
RayLog.rapp.debug("Find app class path " + additionalClassPath);
|
||||
|
||||
// Start driver process.
|
||||
RunManager runManager = new RunManager(params, paths, config);
|
||||
@@ -209,18 +158,15 @@ public class RayCli {
|
||||
params.node_ip_address,
|
||||
cmdSubmit.className,
|
||||
cmdSubmit.classArgs,
|
||||
additionalClassPath,
|
||||
"",
|
||||
null);
|
||||
|
||||
if (null == proc) {
|
||||
RayLog.rapp.error(
|
||||
"Create process for app " + packageName + " in local directory " + appDir
|
||||
+ " failed");
|
||||
if (null == proc) {
|
||||
RayLog.rapp.error("Failed to start driver.");
|
||||
return;
|
||||
}
|
||||
|
||||
RayLog.rapp
|
||||
.info("Create app " + appDir + " for package " + packageName + " succeeded");
|
||||
RayLog.rapp.info("Driver started.");
|
||||
}
|
||||
|
||||
private static String getConfigPath(String config) {
|
||||
|
||||
@@ -105,6 +105,11 @@
|
||||
<artifactId>mockito-all</artifactId>
|
||||
<version>1.10.19</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.11</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
@@ -66,5 +66,12 @@
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- test dependencies -->
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -17,15 +17,13 @@ import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.PathConfig;
|
||||
import org.ray.runtime.config.RayParameters;
|
||||
import org.ray.runtime.functionmanager.LocalFunctionManager;
|
||||
import org.ray.runtime.functionmanager.RayMethod;
|
||||
import org.ray.runtime.functionmanager.RemoteFunctionManager;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.MethodId;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.util.UniqueIdHelper;
|
||||
import org.ray.runtime.util.config.ConfigReader;
|
||||
@@ -44,8 +42,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
protected Worker worker;
|
||||
protected RayletClient rayletClient;
|
||||
protected ObjectStoreProxy objectStoreProxy;
|
||||
protected LocalFunctionManager functions;
|
||||
protected RemoteFunctionManager remoteFunctionManager;
|
||||
protected FunctionManager functionManager;
|
||||
protected PathConfig pathConfig;
|
||||
|
||||
/**
|
||||
@@ -121,13 +118,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
protected void init(
|
||||
RayletClient slink,
|
||||
ObjectStoreLink plink,
|
||||
RemoteFunctionManager remoteLoader,
|
||||
PathConfig pathManager
|
||||
) {
|
||||
remoteFunctionManager = remoteLoader;
|
||||
pathConfig = pathManager;
|
||||
|
||||
functions = new LocalFunctionManager(remoteLoader);
|
||||
functionManager = new FunctionManager();
|
||||
rayletClient = slink;
|
||||
|
||||
objectStoreProxy = new ObjectStoreProxy(plink);
|
||||
@@ -308,6 +303,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
RayActorImpl actorImpl = (RayActorImpl)actor;
|
||||
TaskSpec spec = createTaskSpec(func, actorImpl, args, false);
|
||||
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
@@ -357,33 +353,20 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
actorCreationId = returnIds[0];
|
||||
}
|
||||
|
||||
MethodId methodId = MethodId.fromSerializedLambda(func);
|
||||
|
||||
// NOTE: we append the class name at the end of arguments,
|
||||
// so that we can look up the method based on the class name.
|
||||
// TODO(hchen): move class name to task spec.
|
||||
args = Arrays.copyOf(args, args.length + 1);
|
||||
args[args.length - 1] = methodId.className;
|
||||
|
||||
RayMethod rayMethod = functions.getMethod(
|
||||
current.driverId, actor.getId(), new UniqueId(methodId.getSha1Hash()), methodId.className
|
||||
).getRight();
|
||||
UniqueId funcId = rayMethod.getFuncId();
|
||||
|
||||
RayFunction rayFunction = functionManager.getFunction(current.driverId, func);
|
||||
return new TaskSpec(
|
||||
current.driverId,
|
||||
taskId,
|
||||
current.taskId,
|
||||
-1,
|
||||
actorCreationId,
|
||||
actor.getId(),
|
||||
actor.getHandleId(),
|
||||
actor.increaseTaskCounter(),
|
||||
funcId,
|
||||
ArgumentsBuilder.wrap(args),
|
||||
returnIds,
|
||||
actor.getHandleId(),
|
||||
actorCreationId,
|
||||
ResourceUtil.getResourcesMapFromArray(rayMethod.remoteAnnotation),
|
||||
actor.getTaskCursor()
|
||||
ResourceUtil.getResourcesMapFromArray(rayFunction.getRayRemoteAnnotation()),
|
||||
rayFunction.getFunctionDescriptor()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -399,8 +382,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return rayletClient;
|
||||
}
|
||||
|
||||
public LocalFunctionManager getLocalFunctionManager() {
|
||||
return functions;
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ package org.ray.runtime;
|
||||
|
||||
import org.ray.runtime.config.PathConfig;
|
||||
import org.ray.runtime.config.RayParameters;
|
||||
import org.ray.runtime.functionmanager.NopRemoteFunctionManager;
|
||||
import org.ray.runtime.functionmanager.RemoteFunctionManager;
|
||||
import org.ray.runtime.objectstore.MockObjectStore;
|
||||
import org.ray.runtime.raylet.MockRayletClient;
|
||||
|
||||
@@ -12,11 +10,9 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
@Override
|
||||
public void start(RayParameters params) {
|
||||
PathConfig pathConfig = new PathConfig(configReader);
|
||||
RemoteFunctionManager rfm = new NopRemoteFunctionManager(params.driver_id);
|
||||
MockObjectStore store = new MockObjectStore();
|
||||
MockRayletClient scheduler = new MockRayletClient(this, store);
|
||||
init(scheduler, store, rfm, pathConfig);
|
||||
scheduler.setLocalFunctionManager(this.functions);
|
||||
init(scheduler, store, pathConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -8,9 +8,6 @@ import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.ray.runtime.config.PathConfig;
|
||||
import org.ray.runtime.config.RayParameters;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.functionmanager.NativeRemoteFunctionManager;
|
||||
import org.ray.runtime.functionmanager.NopRemoteFunctionManager;
|
||||
import org.ray.runtime.functionmanager.RemoteFunctionManager;
|
||||
import org.ray.runtime.gcs.AddressInfo;
|
||||
import org.ray.runtime.gcs.KeyValueStoreLink;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
@@ -61,10 +58,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
// initialize remote function manager
|
||||
RemoteFunctionManager funcMgr = params.run_mode.isDevPathManager()
|
||||
? new NopRemoteFunctionManager(params.driver_id) : new NativeRemoteFunctionManager(kvStore);
|
||||
|
||||
// initialize worker context
|
||||
if (params.worker_mode == WorkerMode.DRIVER) {
|
||||
// TODO: The relationship between workerID, driver_id and dummy_task.driver_id should be
|
||||
@@ -96,7 +89,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
WorkerContext.currentTask().taskId
|
||||
);
|
||||
|
||||
init(rayletClient, plink, funcMgr, pathConfig);
|
||||
init(rayletClient, plink, pathConfig);
|
||||
|
||||
// register
|
||||
registerWorker(isWorker, params.node_ip_address, params.object_store_name,
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.RayMethod;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* The worker, which pulls tasks from {@code org.ray.spi.LocalSchedulerProxy} and executes them
|
||||
* The worker, which pulls tasks from {@link org.ray.runtime.raylet.RayletClient} and executes them
|
||||
* continuously.
|
||||
*/
|
||||
public class Worker {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
|
||||
public Worker(AbstractRayRuntime runtime) {
|
||||
@@ -22,7 +25,7 @@ public class Worker {
|
||||
|
||||
public void loop() {
|
||||
while (true) {
|
||||
RayLog.core.info(Thread.currentThread().getName() + ":fetching new task...");
|
||||
LOGGER.info("Fetching new task in thread {}.", Thread.currentThread().getName());
|
||||
TaskSpec task = runtime.getRayletClient().getTask();
|
||||
execute(task);
|
||||
}
|
||||
@@ -32,27 +35,26 @@ public class Worker {
|
||||
* Execute a task.
|
||||
*/
|
||||
public void execute(TaskSpec spec) {
|
||||
RayLog.core.info("Executing task {}", spec.taskId);
|
||||
LOGGER.info("Executing task {}", spec.taskId);
|
||||
LOGGER.debug("Executing task {}", spec);
|
||||
UniqueId returnId = spec.returnIds[0];
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
try {
|
||||
// Get method
|
||||
Pair<ClassLoader, RayMethod> pair = runtime.getLocalFunctionManager().getMethod(
|
||||
spec.driverId, spec.actorId, spec.functionId, spec.args);
|
||||
ClassLoader classLoader = pair.getLeft();
|
||||
RayMethod method = pair.getRight();
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(spec.driverId, spec.functionDescriptor);
|
||||
// Set context
|
||||
WorkerContext.prepare(spec, classLoader);
|
||||
Thread.currentThread().setContextClassLoader(classLoader);
|
||||
WorkerContext.prepare(spec, rayFunction.classLoader);
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
// Get local actor object and arguments.
|
||||
Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null;
|
||||
Object[] args = ArgumentsBuilder.unwrap(spec, classLoader);
|
||||
Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader);
|
||||
// Execute the task.
|
||||
Object result;
|
||||
if (!method.isConstructor()) {
|
||||
result = method.getMethod().invoke(actor, args);
|
||||
if (!rayFunction.isConstructor()) {
|
||||
result = rayFunction.getMethod().invoke(actor, args);
|
||||
} else {
|
||||
result = method.getConstructor().newInstance(args);
|
||||
result = rayFunction.getConstructor().newInstance(args);
|
||||
}
|
||||
// Set result
|
||||
if (!spec.isActorCreationTask()) {
|
||||
|
||||
@@ -34,15 +34,20 @@ public class WorkerContext {
|
||||
WorkerContext ctx = new WorkerContext();
|
||||
currentWorkerCtx.set(ctx);
|
||||
|
||||
TaskSpec dummy = new TaskSpec();
|
||||
dummy.parentTaskId = UniqueId.NIL;
|
||||
if (params.worker_mode == WorkerMode.DRIVER) {
|
||||
dummy.taskId = UniqueId.randomId();
|
||||
} else {
|
||||
dummy.taskId = UniqueId.NIL;
|
||||
}
|
||||
dummy.actorId = UniqueId.NIL;
|
||||
dummy.driverId = params.driver_id;
|
||||
TaskSpec dummy = new TaskSpec(
|
||||
params.driver_id,
|
||||
params.worker_mode == WorkerMode.DRIVER ? UniqueId.randomId() : UniqueId.NIL,
|
||||
UniqueId.NIL,
|
||||
0,
|
||||
UniqueId.NIL,
|
||||
UniqueId.NIL,
|
||||
UniqueId.NIL,
|
||||
0,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
prepare(dummy, null);
|
||||
|
||||
return ctx;
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
|
||||
/**
|
||||
* Represents the function's metadata.
|
||||
*/
|
||||
public final class FunctionDescriptor {
|
||||
|
||||
/**
|
||||
* Function's class name.
|
||||
*/
|
||||
public final String className;
|
||||
/**
|
||||
* Function's name.
|
||||
*/
|
||||
public final String name;
|
||||
/**
|
||||
* Function's type descriptor.
|
||||
*/
|
||||
public final String typeDescriptor;
|
||||
|
||||
public FunctionDescriptor(String className, String name, String typeDescriptor) {
|
||||
this.className = className;
|
||||
this.name = name;
|
||||
this.typeDescriptor = typeDescriptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return className + "." + name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
FunctionDescriptor that = (FunctionDescriptor) o;
|
||||
return Objects.equal(className, that.className) &&
|
||||
Objects.equal(name, that.name) &&
|
||||
Objects.equal(typeDescriptor, that.typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.WeakHashMap;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.objectweb.asm.Type;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.util.LambdaUtils;
|
||||
|
||||
/**
|
||||
* Manages functions by driver id.
|
||||
*/
|
||||
public class FunctionManager {
|
||||
|
||||
static final String CONSTRUCTOR_NAME = "<init>";
|
||||
|
||||
/**
|
||||
* Cache from a RayFunc object to its corresponding FunctionDescriptor. Because
|
||||
* `LambdaUtils.getSerializedLambda` is expensive.
|
||||
*/
|
||||
private static final ThreadLocal<WeakHashMap<Class<RayFunc>, FunctionDescriptor>>
|
||||
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
|
||||
|
||||
/**
|
||||
* Mapping from the driver id to the functions that belong to this driver.
|
||||
*/
|
||||
private Map<UniqueId, DriverFunctionTable> driverFunctionTables = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Get the RayFunction from a RayFunc instance (a lambda).
|
||||
*
|
||||
* @param driverId current driver id.
|
||||
* @param func The lambda.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, RayFunc func) {
|
||||
FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
if (functionDescriptor == null) {
|
||||
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
|
||||
final String className = serializedLambda.getImplClass().replace('/', '.');
|
||||
final String methodName = serializedLambda.getImplMethodName();
|
||||
final String typeDescriptor = serializedLambda.getImplMethodSignature();
|
||||
functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor);
|
||||
}
|
||||
return getFunction(driverId, functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the RayFunction from a function descriptor.
|
||||
*
|
||||
* @param driverId Current driver id.
|
||||
* @param functionDescriptor The function descriptor.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) {
|
||||
DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
|
||||
if (driverFunctionTable == null) {
|
||||
//TODO(hchen): distinguish class loader by driver id.
|
||||
ClassLoader classLoader = getClass().getClassLoader();
|
||||
driverFunctionTable = new DriverFunctionTable(classLoader);
|
||||
driverFunctionTables.put(driverId, driverFunctionTable);
|
||||
}
|
||||
return driverFunctionTable.getFunction(functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages all functions that belong to one driver.
|
||||
*/
|
||||
static class DriverFunctionTable {
|
||||
|
||||
/**
|
||||
* The driver's corresponding class loader.
|
||||
*/
|
||||
ClassLoader classLoader;
|
||||
/**
|
||||
* Functions per class, per function name + type descriptor.
|
||||
*/
|
||||
Map<String, Map<Pair<String, String>, RayFunction>> functions;
|
||||
|
||||
DriverFunctionTable(ClassLoader classLoader) {
|
||||
this.classLoader = classLoader;
|
||||
this.functions = new HashMap<>();
|
||||
}
|
||||
|
||||
RayFunction getFunction(FunctionDescriptor descriptor) {
|
||||
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
|
||||
if (classFunctions == null) {
|
||||
classFunctions = loadFunctionsForClass(descriptor.className);
|
||||
functions.put(descriptor.className, classFunctions);
|
||||
}
|
||||
return classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.typeDescriptor));
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all functions from a class.
|
||||
*/
|
||||
Map<Pair<String, String>, RayFunction> loadFunctionsForClass(String className) {
|
||||
Map<Pair<String, String>, RayFunction> map = new HashMap<>();
|
||||
try {
|
||||
Class clazz = Class.forName(className, true, classLoader);
|
||||
|
||||
List<Executable> executables = new ArrayList<>();
|
||||
executables.addAll(Arrays.asList(clazz.getDeclaredMethods()));
|
||||
executables.addAll(Arrays.asList(clazz.getConstructors()));
|
||||
|
||||
for (Executable e : executables) {
|
||||
e.setAccessible(true);
|
||||
final String methodName = e instanceof Method ? e.getName() : CONSTRUCTOR_NAME;
|
||||
final Type type =
|
||||
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
|
||||
final String typeDescriptor = type.getDescriptor();
|
||||
RayFunction rayFunction = new RayFunction(e, classLoader,
|
||||
new FunctionDescriptor(className, methodName, typeDescriptor));
|
||||
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to load functions from class " + className, e);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* local function manager which pulls remote functions on demand.
|
||||
*/
|
||||
public class LocalFunctionManager {
|
||||
|
||||
private final RemoteFunctionManager remoteLoader;
|
||||
|
||||
private final ConcurrentHashMap<UniqueId, FunctionTable> functionTables
|
||||
= new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* initialize load function manager using remote function manager to pull remote functions on
|
||||
* demand.
|
||||
*/
|
||||
public LocalFunctionManager(RemoteFunctionManager remoteLoader) {
|
||||
this.remoteLoader = remoteLoader;
|
||||
}
|
||||
|
||||
private FunctionTable loadDriverFunctions(UniqueId driverId) {
|
||||
FunctionTable functionTable = functionTables.get(driverId);
|
||||
if (functionTable == null) {
|
||||
RayLog.core.info("DriverId " + driverId + " Try to load functions");
|
||||
ClassLoader classLoader = remoteLoader.loadResource(driverId);
|
||||
if (classLoader == null) {
|
||||
throw new RuntimeException(
|
||||
"Cannot find resource' classLoader for app " + driverId.toString());
|
||||
}
|
||||
functionTable = new FunctionTable(classLoader);
|
||||
functionTables.put(driverId, functionTable);
|
||||
}
|
||||
return functionTable;
|
||||
}
|
||||
|
||||
public Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
|
||||
UniqueId methodId, String className) {
|
||||
// assert the driver's resource is load.
|
||||
FunctionTable functionTable = loadDriverFunctions(driverId);
|
||||
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
|
||||
RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className)
|
||||
: functionTable.getActorMethod(methodId, className);
|
||||
Preconditions
|
||||
.checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className,
|
||||
methodId, driverId);
|
||||
return Pair.of(functionTable.classLoader, method);
|
||||
}
|
||||
|
||||
/**
|
||||
* get local method for executing, which pulls information from remote repo on-demand, therefore
|
||||
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
|
||||
*/
|
||||
public Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
|
||||
UniqueId methodId, FunctionArg[] args) {
|
||||
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
|
||||
String className = (String) Serializer.decode(args[args.length - 1].data);
|
||||
return getMethod(driverId, actorId, methodId, className);
|
||||
}
|
||||
|
||||
/**
|
||||
* unload the functions when the driver is declared dead.
|
||||
*/
|
||||
public synchronized void removeApp(UniqueId driverId) {
|
||||
FunctionTable funcs = functionTables.get(driverId);
|
||||
if (funcs != null) {
|
||||
functionTables.remove(driverId);
|
||||
remoteLoader.unloadFunctions(driverId);
|
||||
}
|
||||
}
|
||||
|
||||
private static class FunctionTable {
|
||||
|
||||
final ClassLoader classLoader;
|
||||
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
|
||||
final ConcurrentHashMap<String, RayActorMethods> actorMethods = new ConcurrentHashMap<>();
|
||||
|
||||
FunctionTable(ClassLoader classLoader) {
|
||||
this.classLoader = classLoader;
|
||||
}
|
||||
|
||||
RayMethod getTaskMethod(UniqueId methodId, String className) {
|
||||
RayTaskMethods taskMethods = this.taskMethods.get(className);
|
||||
if (taskMethods == null) {
|
||||
taskMethods = RayTaskMethods.fromClass(className, classLoader);
|
||||
RayLog.core.info("create RayTaskMethods: {}", taskMethods);
|
||||
this.taskMethods.put(className, taskMethods);
|
||||
}
|
||||
RayMethod m = taskMethods.functions.get(methodId);
|
||||
if (m != null) {
|
||||
return m;
|
||||
}
|
||||
// it is a actor static func.
|
||||
return getActorMethod(methodId, className, true);
|
||||
}
|
||||
|
||||
RayMethod getActorMethod(UniqueId methodId, String className) {
|
||||
return getActorMethod(methodId, className, false);
|
||||
}
|
||||
|
||||
private RayMethod getActorMethod(UniqueId methodId, String className, boolean isStatic) {
|
||||
RayActorMethods actorMethods = this.actorMethods.get(className);
|
||||
if (actorMethods == null) {
|
||||
actorMethods = RayActorMethods.fromClass(className, classLoader);
|
||||
RayLog.core.info("create RayActorMethods: {}", actorMethods);
|
||||
this.actorMethods.put(className, actorMethods);
|
||||
}
|
||||
return isStatic ? actorMethods.staticFunctions.get(methodId)
|
||||
: actorMethods.functions.get(methodId);
|
||||
}
|
||||
}
|
||||
}
|
||||
-131
@@ -1,131 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.io.File;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import net.lingala.zip4j.core.ZipFile;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.gcs.KeyValueStoreLink;
|
||||
import org.ray.runtime.util.FileUtil;
|
||||
import org.ray.runtime.util.JarLoader;
|
||||
import org.ray.runtime.util.Sha1Digestor;
|
||||
import org.ray.runtime.util.SystemUtil;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* native implementation of remote function manager.
|
||||
*/
|
||||
public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
private final ConcurrentHashMap<UniqueId, ClassLoader> loadedApps = new ConcurrentHashMap<>();
|
||||
private MessageDigest md;
|
||||
private final String appDir = System.getProperty("user.dir") + "/apps";
|
||||
private final KeyValueStoreLink kvStore;
|
||||
|
||||
|
||||
public NativeRemoteFunctionManager(KeyValueStoreLink kvStore) throws NoSuchAlgorithmException {
|
||||
this.kvStore = kvStore;
|
||||
md = MessageDigest.getInstance("SHA-1");
|
||||
File appDir = new File(this.appDir);
|
||||
if (!appDir.exists()) {
|
||||
appDir.mkdirs();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId registerResource(byte[] resourceZip) {
|
||||
byte[] digest = Sha1Digestor.digest(resourceZip);
|
||||
assert (digest.length == UniqueId.LENGTH);
|
||||
|
||||
UniqueId resourceId = new UniqueId(digest);
|
||||
|
||||
// TODO: resources must be saved in persistent store
|
||||
kvStore.set(resourceId.getBytes(), resourceZip, null);
|
||||
|
||||
return resourceId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getResource(UniqueId resourceId) {
|
||||
return kvStore.get(resourceId.getBytes(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unregisterResource(UniqueId resourceId) {
|
||||
kvStore.delete(resourceId.getBytes(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void registerApp(UniqueId driverId, UniqueId resourceId) {
|
||||
kvStore.set("App2ResMap", resourceId.toString(), driverId.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getAppResourceId(UniqueId driverId) {
|
||||
return UniqueId.fromHexString(kvStore.get("App2ResMap", driverId.toString()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unregisterApp(UniqueId driverId) {
|
||||
kvStore.delete("App2ResMap", driverId.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader loadResource(UniqueId driverId) {
|
||||
ClassLoader classLoader = loadedApps.get(driverId);
|
||||
if (classLoader == null) {
|
||||
synchronized (this) {
|
||||
classLoader = loadedApps.get(driverId);
|
||||
if (classLoader == null) {
|
||||
classLoader = initLoadedApps(driverId);
|
||||
}
|
||||
}
|
||||
}
|
||||
return classLoader;
|
||||
}
|
||||
|
||||
private ClassLoader initLoadedApps(UniqueId driverId) {
|
||||
try {
|
||||
RayLog.core.info("initLoadedApps" + driverId.toString());
|
||||
|
||||
ClassLoader cl = loadedApps.get(driverId);
|
||||
if (cl == null) {
|
||||
UniqueId resId = UniqueId.fromHexString(kvStore.get("App2ResMap", driverId.toString()));
|
||||
byte[] res = getResource(resId);
|
||||
if (res == null) {
|
||||
throw new RuntimeException("get resource null, the resId " + resId.toString());
|
||||
}
|
||||
RayLog.core.info("get resource of " + resId.toString() + ", result len " + res.length);
|
||||
String resPath =
|
||||
appDir + "/" + driverId.toString() + "/" + String.valueOf(SystemUtil.pid());
|
||||
File dir = new File(resPath);
|
||||
if (!dir.exists()) {
|
||||
dir.mkdirs();
|
||||
}
|
||||
String zipPath = resPath + ".zip";
|
||||
RayLog.rapp.info("unzip app file: zipPath " + zipPath + " resPath " + resPath);
|
||||
FileUtil.bytesToFile(res, zipPath);
|
||||
ZipFile zipFile = new ZipFile(zipPath);
|
||||
zipFile.extractAll(resPath);
|
||||
cl = JarLoader.loadJars(resPath, false);
|
||||
loadedApps.put(driverId, cl);
|
||||
}
|
||||
return cl;
|
||||
} catch (Exception e) {
|
||||
RayLog.rapp.error("load function for " + driverId + " failed, ex = " + e.getMessage(), e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void unloadFunctions(UniqueId driverId) {
|
||||
ClassLoader cl = loadedApps.get(driverId);
|
||||
try {
|
||||
JarLoader.unloadJars(cl);
|
||||
} catch (Exception e) {
|
||||
RayLog.rapp.error("unload function for " + driverId + " failed, ex = " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
-58
@@ -1,58 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* mock version of remote function manager using local loaded jars.
|
||||
*/
|
||||
public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
public NopRemoteFunctionManager(UniqueId driverId) {
|
||||
//onLoad(driverId, Agent.hookedMethods);
|
||||
//Agent.consumers.add(m -> { this.onLoad(m); });
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId registerResource(byte[] resourceZip) {
|
||||
return null;
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getResource(UniqueId resourceId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unregisterResource(UniqueId resourceId) {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
@Override
|
||||
public void registerApp(UniqueId driverId, UniqueId resourceId) {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getAppResourceId(UniqueId driverId) {
|
||||
return null;
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unregisterApp(UniqueId driverId) {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader loadResource(UniqueId driverId) {
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
return this.getClass().getClassLoader();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unloadFunctions(UniqueId driverId) {
|
||||
// never
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
}
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
|
||||
public final class RayActorMethods {
|
||||
|
||||
public final Class clazz;
|
||||
public final RayRemote remoteAnnotation;
|
||||
public final Map<UniqueId, RayMethod> functions;
|
||||
/**
|
||||
* the static function in Actor, call as task.
|
||||
*/
|
||||
public final Map<UniqueId, RayMethod> staticFunctions;
|
||||
|
||||
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
|
||||
Map<UniqueId, RayMethod> functions, Map<UniqueId, RayMethod> staticFunctions) {
|
||||
this.clazz = clazz;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
|
||||
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
|
||||
}
|
||||
|
||||
public static RayActorMethods fromClass(String className, ClassLoader classLoader) {
|
||||
try {
|
||||
Class clazz = Class.forName(className, true, classLoader);
|
||||
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
|
||||
Preconditions.checkNotNull(remoteAnnotation,
|
||||
"%s must be annotated with @RayRemote", className);
|
||||
|
||||
List<Executable> executables = new ArrayList<>(Arrays.asList(clazz.getDeclaredMethods()));
|
||||
|
||||
Map<UniqueId, RayMethod> functions = new HashMap<>();
|
||||
Map<UniqueId, RayMethod> staticFunctions = new HashMap<>();
|
||||
|
||||
for (Executable e : executables) {
|
||||
RayMethod rayMethod = RayMethod.from(e, remoteAnnotation);
|
||||
if (Modifier.isStatic(e.getModifiers())) {
|
||||
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
|
||||
} else {
|
||||
functions.put(rayMethod.getFuncId(), rayMethod);
|
||||
}
|
||||
}
|
||||
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to get RayActorMethods from " + className, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String
|
||||
.format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(),
|
||||
functions.values(),
|
||||
staticFunctions.size(), staticFunctions.values());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
|
||||
/**
|
||||
* Represents a Ray function (either a Method or a Constructor in Java) and its metadata.
|
||||
*/
|
||||
public class RayFunction {
|
||||
|
||||
/**
|
||||
* The executor object, can be either a Method or a Constructor.
|
||||
*/
|
||||
public final Executable executable;
|
||||
|
||||
/**
|
||||
* This function's class loader.
|
||||
*/
|
||||
public final ClassLoader classLoader;
|
||||
|
||||
/**
|
||||
* Function's metadata.
|
||||
*/
|
||||
public final FunctionDescriptor functionDescriptor;
|
||||
|
||||
public RayFunction(Executable executable, ClassLoader classLoader,
|
||||
FunctionDescriptor functionDescriptor) {
|
||||
this.executable = executable;
|
||||
this.classLoader = classLoader;
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return True if it's a constructor, otherwise it's a method.
|
||||
*/
|
||||
public boolean isConstructor() {
|
||||
return executable instanceof Constructor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The underlying constructor object.
|
||||
*/
|
||||
public Constructor<?> getConstructor() {
|
||||
return (Constructor<?>) executable;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The underlying method object.
|
||||
*/
|
||||
public Method getMethod() {
|
||||
return (Method) executable;
|
||||
}
|
||||
|
||||
public FunctionDescriptor getFunctionDescriptor() {
|
||||
return functionDescriptor;
|
||||
}
|
||||
|
||||
public RayRemote getRayRemoteAnnotation() {
|
||||
RayRemote rayRemote = executable.getAnnotation(RayRemote.class);
|
||||
if (rayRemote == null) {
|
||||
// If the method doesn't have a annotation, get the annotation from
|
||||
// its wrapping class.
|
||||
rayRemote = executable.getDeclaringClass().getAnnotation(RayRemote.class);
|
||||
}
|
||||
return rayRemote;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return executable.toString();
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.util.MethodId;
|
||||
|
||||
/**
|
||||
* method info.
|
||||
*/
|
||||
public class RayMethod {
|
||||
|
||||
public final Executable invokable;
|
||||
public final String fullName;
|
||||
public final RayRemote remoteAnnotation;
|
||||
private final UniqueId funcId;
|
||||
|
||||
private RayMethod(Executable e, RayRemote remoteAnnotation, UniqueId funcId) {
|
||||
this.invokable = e;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
this.funcId = funcId;
|
||||
fullName = e.getDeclaringClass().getName() + "." + e.getName();
|
||||
}
|
||||
|
||||
public static RayMethod from(Executable e, RayRemote parentRemoteAnnotation) {
|
||||
RayRemote remoteAnnotation = e.getAnnotation(RayRemote.class);
|
||||
MethodId mid = MethodId.fromExecutable(e);
|
||||
UniqueId funcId = new UniqueId(mid.getSha1Hash());
|
||||
RayMethod method = new RayMethod(e,
|
||||
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
|
||||
funcId);
|
||||
return method;
|
||||
}
|
||||
|
||||
public boolean isConstructor() {
|
||||
return invokable instanceof Constructor;
|
||||
}
|
||||
|
||||
public Constructor<?> getConstructor() {
|
||||
return (Constructor<?>) invokable;
|
||||
}
|
||||
|
||||
public Method getMethod() {
|
||||
return (Method) invokable;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return fullName;
|
||||
}
|
||||
|
||||
public UniqueId getFuncId() {
|
||||
return funcId;
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
|
||||
public final class RayTaskMethods {
|
||||
|
||||
public final Class clazz;
|
||||
public final Map<UniqueId, RayMethod> functions;
|
||||
|
||||
public RayTaskMethods(Class clazz,
|
||||
Map<UniqueId, RayMethod> functions) {
|
||||
this.clazz = clazz;
|
||||
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
|
||||
}
|
||||
|
||||
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
|
||||
try {
|
||||
Class clazz = Class.forName(clazzName, true, classLoader);
|
||||
List<Executable> executables = new ArrayList<>();
|
||||
executables.addAll(Arrays.asList(clazz.getDeclaredMethods()));
|
||||
executables.addAll(Arrays.asList(clazz.getConstructors()));
|
||||
Map<UniqueId, RayMethod> functions = new HashMap<>(executables.size());
|
||||
|
||||
for (Executable e : executables) {
|
||||
// This executable must be either a constructor or a static method.
|
||||
if (!(e instanceof Constructor)
|
||||
&& !Modifier.isStatic(e.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
e.setAccessible(true);
|
||||
RayMethod rayMethod = RayMethod.from(e, null);
|
||||
functions.put(rayMethod.getFuncId(), rayMethod);
|
||||
}
|
||||
return new RayTaskMethods(clazz, functions);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String
|
||||
.format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* register and load functions from function table.
|
||||
*/
|
||||
public interface RemoteFunctionManager {
|
||||
|
||||
/*
|
||||
* register <resourceId, resource> mapping, and upload resource.
|
||||
* this function is invoked by app proxy or other stand-alone tools it should detect for
|
||||
* duplication first though
|
||||
*
|
||||
* @param resourceZip a directory zip from @JarRewriter
|
||||
* @return SHA-1 hash of the content
|
||||
*/
|
||||
UniqueId registerResource(byte[] resourceZip);
|
||||
|
||||
/**
|
||||
* download resource content.
|
||||
*
|
||||
* @return resource content
|
||||
*/
|
||||
byte[] getResource(UniqueId resourceId);
|
||||
|
||||
/**
|
||||
* remove resource by its hash id
|
||||
* be careful of invoking this function to make sure it is no longer used.
|
||||
*
|
||||
* @param resourceId SHA-1 hash of the resource zip bytes
|
||||
*/
|
||||
void unregisterResource(UniqueId resourceId);
|
||||
|
||||
/*
|
||||
* register the <driver, resource> mapping to repo,
|
||||
* this function is invoked by whoever initiates the driver id
|
||||
*/
|
||||
void registerApp(UniqueId driverId, UniqueId resourceId);
|
||||
|
||||
/**
|
||||
* get the resourceId of one app.
|
||||
*
|
||||
* @return resourceId of the app driver
|
||||
*/
|
||||
UniqueId getAppResourceId(UniqueId driverId);
|
||||
|
||||
/*
|
||||
* unregister <dirver, resource> mapping
|
||||
* this function is called when the driver exits or detected dead
|
||||
*/
|
||||
void unregisterApp(UniqueId driverId);
|
||||
|
||||
/**
|
||||
* load resource.
|
||||
*/
|
||||
ClassLoader loadResource(UniqueId driverId);
|
||||
|
||||
/**
|
||||
* unload functions for this driver
|
||||
* this function is used by the workers on demand when a driver is dead.
|
||||
*/
|
||||
void unloadFunctions(UniqueId driverId);
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
// automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
package org.ray.runtime.generated;
|
||||
|
||||
import java.nio.*;
|
||||
import java.lang.*;
|
||||
import com.google.flatbuffers.*;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.google.flatbuffers.Table;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
public final class Arg extends Table {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
package org.ray.runtime.generated;
|
||||
|
||||
import java.nio.*;
|
||||
import java.lang.*;
|
||||
import com.google.flatbuffers.*;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.google.flatbuffers.Table;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
public final class TaskInfo extends Table {
|
||||
@@ -49,6 +49,8 @@ public final class TaskInfo extends Table {
|
||||
public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
|
||||
public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; }
|
||||
public int language() { int o = __offset(32); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
|
||||
public String functionDescriptor(int j) { int o = __offset(34); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int functionDescriptorLength() { int o = __offset(34); return o != 0 ? __vector_len(o) : 0; }
|
||||
|
||||
public static int createTaskInfo(FlatBufferBuilder builder,
|
||||
int driver_idOffset,
|
||||
@@ -65,8 +67,10 @@ public final class TaskInfo extends Table {
|
||||
int argsOffset,
|
||||
int returnsOffset,
|
||||
int required_resourcesOffset,
|
||||
int language) {
|
||||
builder.startObject(15);
|
||||
int language,
|
||||
int function_descriptorOffset) {
|
||||
builder.startObject(16);
|
||||
TaskInfo.addFunctionDescriptor(builder, function_descriptorOffset);
|
||||
TaskInfo.addLanguage(builder, language);
|
||||
TaskInfo.addRequiredResources(builder, required_resourcesOffset);
|
||||
TaskInfo.addReturns(builder, returnsOffset);
|
||||
@@ -85,7 +89,7 @@ public final class TaskInfo extends Table {
|
||||
return TaskInfo.endTaskInfo(builder);
|
||||
}
|
||||
|
||||
public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(15); }
|
||||
public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(16); }
|
||||
public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); }
|
||||
public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); }
|
||||
public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); }
|
||||
@@ -107,6 +111,9 @@ public final class TaskInfo extends Table {
|
||||
public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(14, language, 0); }
|
||||
public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(15, functionDescriptorOffset, 0); }
|
||||
public static int createFunctionDescriptorVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startFunctionDescriptorVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static int endTaskInfo(FlatBufferBuilder builder) {
|
||||
int o = builder.endObject();
|
||||
return o;
|
||||
|
||||
@@ -8,7 +8,6 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.functionmanager.LocalFunctionManager;
|
||||
import org.ray.runtime.objectstore.MockObjectStore;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
@@ -20,7 +19,6 @@ public class MockRayletClient implements RayletClient {
|
||||
|
||||
private final Map<UniqueId, Map<UniqueId, TaskSpec>> waitTasks = new ConcurrentHashMap<>();
|
||||
private final MockObjectStore store;
|
||||
private LocalFunctionManager functions = null;
|
||||
private final RayDevRuntime runtime;
|
||||
|
||||
public MockRayletClient(RayDevRuntime runtime, MockObjectStore store) {
|
||||
@@ -29,10 +27,6 @@ public class MockRayletClient implements RayletClient {
|
||||
store.registerScheduler(this);
|
||||
}
|
||||
|
||||
public void setLocalFunctionManager(LocalFunctionManager mgr) {
|
||||
functions = mgr;
|
||||
}
|
||||
|
||||
public void onObjectPut(UniqueId id) {
|
||||
Map<UniqueId, TaskSpec> bucket = waitTasks.get(id);
|
||||
if (bucket != null) {
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.generated.Arg;
|
||||
import org.ray.runtime.generated.ResourcePair;
|
||||
import org.ray.runtime.generated.TaskInfo;
|
||||
@@ -21,12 +23,18 @@ import org.ray.runtime.util.logger.RayLog;
|
||||
|
||||
public class RayletClientImpl implements RayletClient {
|
||||
|
||||
private static ThreadLocal<ByteBuffer> _taskBuffer = ThreadLocal.withInitial(() -> {
|
||||
ByteBuffer bb = ByteBuffer
|
||||
.allocateDirect(AbstractRayRuntime.getParams().max_submit_task_buffer_size_bytes);
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return bb;
|
||||
});
|
||||
private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024;
|
||||
|
||||
/**
|
||||
* Direct buffers that are used to hold flatbuffer-serialized task specs.
|
||||
*/
|
||||
private static ThreadLocal<ByteBuffer> taskSpecBuffer = ThreadLocal.withInitial(() ->
|
||||
ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN)
|
||||
);
|
||||
|
||||
/**
|
||||
* Point to c++'s local scheduler client.
|
||||
*/
|
||||
private long client = 0;
|
||||
|
||||
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
|
||||
@@ -58,15 +66,14 @@ public class RayletClientImpl implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void submitTask(TaskSpec task) {
|
||||
RayLog.core.debug("Submitting task: {}", task);
|
||||
|
||||
ByteBuffer info = taskSpec2Info(task);
|
||||
public void submitTask(TaskSpec spec) {
|
||||
RayLog.core.debug("Submitting task: {}", spec);
|
||||
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
|
||||
byte[] cursorId = null;
|
||||
if (!task.actorId.isNil()) {
|
||||
cursorId = task.cursorId.getBytes();
|
||||
if (!spec.getExecutionDependencies().isEmpty()) {
|
||||
//TODO(hchen): handle more than one dependencies.
|
||||
cursorId = spec.getExecutionDependencies().get(0).getBytes();
|
||||
}
|
||||
|
||||
nativeSubmitTask(client, cursorId, info, info.position(), info.remaining());
|
||||
}
|
||||
|
||||
@@ -75,7 +82,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
byte[] bytes = nativeGetTask(client);
|
||||
assert (null != bytes);
|
||||
ByteBuffer bb = ByteBuffer.wrap(bytes);
|
||||
return taskInfo2Spec(bb);
|
||||
return parseTaskSpecFromFlatbuffer(bb);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -104,105 +111,93 @@ public class RayletClientImpl implements RayletClient {
|
||||
nativeFreePlasmaObjects(client, objectIdsArray, localOnly);
|
||||
}
|
||||
|
||||
public static TaskSpec taskInfo2Spec(ByteBuffer bb) {
|
||||
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
|
||||
|
||||
TaskSpec spec = new TaskSpec();
|
||||
spec.driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer());
|
||||
spec.taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer());
|
||||
spec.parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
|
||||
spec.parentCounter = info.parentCounter();
|
||||
spec.actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
|
||||
spec.actorCounter = info.actorCounter();
|
||||
spec.createActorId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
|
||||
|
||||
spec.functionId = UniqueId.fromByteBuffer(info.functionIdAsByteBuffer());
|
||||
|
||||
List<FunctionArg> args = new ArrayList<>();
|
||||
UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer());
|
||||
UniqueId taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer());
|
||||
UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
|
||||
int parentCounter = info.parentCounter();
|
||||
UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
|
||||
UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
|
||||
UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer());
|
||||
int actorCounter = info.actorCounter();
|
||||
// Deserialize args
|
||||
FunctionArg[] args = new FunctionArg[info.argsLength()];
|
||||
for (int i = 0; i < info.argsLength(); i++) {
|
||||
UniqueId id = null;
|
||||
byte[] data = null;
|
||||
Arg sarg = info.args(i);
|
||||
|
||||
int idCount = sarg.objectIdsLength();
|
||||
if (idCount > 0) {
|
||||
ByteBuffer lbb = sarg.objectIdAsByteBuffer(0);
|
||||
assert (lbb != null && lbb.remaining() > 0);
|
||||
id = UniqueId.fromByteBuffer(lbb);
|
||||
}
|
||||
|
||||
ByteBuffer lbb = sarg.dataAsByteBuffer();
|
||||
if (lbb != null && lbb.remaining() > 0) {
|
||||
// TODO: how to avoid memory copy
|
||||
data = new byte[lbb.remaining()];
|
||||
Arg arg = info.args(i);
|
||||
if (arg.objectIdsLength() > 0) {
|
||||
Preconditions.checkArgument(arg.objectIdsLength() == 1,
|
||||
"This arg has more than one id: {}", arg.objectIdsLength());
|
||||
UniqueId id = UniqueId.fromByteBuffer(arg.objectIdAsByteBuffer(0));
|
||||
args[i] = FunctionArg.passByReference(id);
|
||||
} else {
|
||||
ByteBuffer lbb = arg.dataAsByteBuffer();
|
||||
Preconditions.checkState(lbb != null && lbb.remaining() > 0);
|
||||
byte[] data = new byte[lbb.remaining()];
|
||||
lbb.get(data);
|
||||
args[i] = FunctionArg.passByValue(data);
|
||||
}
|
||||
|
||||
args.add(new FunctionArg(id, data));
|
||||
}
|
||||
spec.args = args.toArray(new FunctionArg[0]);
|
||||
|
||||
List<UniqueId> rids = new ArrayList<>();
|
||||
// Deserialize return ids
|
||||
UniqueId[] returnIds = new UniqueId[info.returnsLength()];
|
||||
for (int i = 0; i < info.returnsLength(); i++) {
|
||||
ByteBuffer lbb = info.returnsAsByteBuffer(i);
|
||||
assert (lbb != null && lbb.remaining() > 0);
|
||||
rids.add(UniqueId.fromByteBuffer(lbb));
|
||||
returnIds[i] = UniqueId.fromByteBuffer(info.returnsAsByteBuffer(i));
|
||||
}
|
||||
spec.returnIds = rids.toArray(new UniqueId[0]);
|
||||
|
||||
return spec;
|
||||
// Deserialize required resources;
|
||||
Map<String, Double> resources = new HashMap<>();
|
||||
for (int i = 0; i < info.requiredResourcesLength(); i++) {
|
||||
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
|
||||
}
|
||||
// Deserialize function descriptor
|
||||
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
|
||||
FunctionDescriptor functionDescriptor = new FunctionDescriptor(
|
||||
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
|
||||
);
|
||||
return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, actorId,
|
||||
actorHandleId, actorCounter, args, returnIds, resources, functionDescriptor);
|
||||
}
|
||||
|
||||
public static ByteBuffer taskSpec2Info(TaskSpec task) {
|
||||
ByteBuffer bb = _taskBuffer.get();
|
||||
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
|
||||
ByteBuffer bb = taskSpecBuffer.get();
|
||||
bb.clear();
|
||||
|
||||
FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
|
||||
|
||||
final int driverIdOffset = fbb.createString(task.driverId.toByteBuffer());
|
||||
final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
|
||||
final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
|
||||
final int parentCounter = task.parentCounter;
|
||||
final int actorCreateIdOffset = fbb.createString(task.createActorId.toByteBuffer());
|
||||
final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer());
|
||||
final int actorCreateDummyIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer());
|
||||
final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
|
||||
final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
|
||||
final int actorCounter = task.actorCounter;
|
||||
final int functionIdOffset = fbb.createString(task.functionId.toByteBuffer());
|
||||
|
||||
// serialize args
|
||||
final int functionIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer());
|
||||
// Serialize args
|
||||
int[] argsOffsets = new int[task.args.length];
|
||||
for (int i = 0; i < argsOffsets.length; i++) {
|
||||
|
||||
int objectIdOffset = 0;
|
||||
int dataOffset = 0;
|
||||
if (task.args[i].id != null) {
|
||||
int[] idOffsets = new int[] {
|
||||
fbb.createString(task.args[i].id.toByteBuffer())
|
||||
};
|
||||
int[] idOffsets = new int[]{fbb.createString(task.args[i].id.toByteBuffer())};
|
||||
objectIdOffset = fbb.createVectorOfTables(idOffsets);
|
||||
} else {
|
||||
objectIdOffset = fbb.createVectorOfTables(new int[0]);
|
||||
}
|
||||
|
||||
if (task.args[i].data != null) {
|
||||
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
|
||||
}
|
||||
|
||||
argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset);
|
||||
}
|
||||
int argsOffset = fbb.createVectorOfTables(argsOffsets);
|
||||
|
||||
// serialize returns
|
||||
// Serialize returns
|
||||
int returnCount = task.returnIds.length;
|
||||
int[] returnsOffsets = new int[returnCount];
|
||||
for (int k = 0; k < returnCount; k++) {
|
||||
returnsOffsets[k] = fbb.createString(task.returnIds[k].toByteBuffer());
|
||||
}
|
||||
int returnsOffset = fbb.createVectorOfTables(returnsOffsets);
|
||||
|
||||
// serialize required resources
|
||||
// Serialize required resources
|
||||
// The required_resources vector indicates the quantities of the different
|
||||
// resources required by this task. The index in this vector corresponds to
|
||||
// the resource type defined in the ResourceIndex enum. For example,
|
||||
@@ -210,12 +205,16 @@ public class RayletClientImpl implements RayletClient {
|
||||
int i = 0;
|
||||
for (Map.Entry<String, Double> entry : task.resources.entrySet()) {
|
||||
int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes()));
|
||||
requiredResourcesOffsets[i] =
|
||||
requiredResourcesOffsets[i++] =
|
||||
ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue());
|
||||
i++;
|
||||
}
|
||||
|
||||
int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets);
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.functionDescriptor.className),
|
||||
fbb.createString(task.functionDescriptor.name),
|
||||
fbb.createString(task.functionDescriptor.typeDescriptor)
|
||||
};
|
||||
int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
|
||||
int root = TaskInfo.createTaskInfo(
|
||||
fbb, driverIdOffset, taskIdOffset,
|
||||
@@ -223,18 +222,17 @@ public class RayletClientImpl implements RayletClient {
|
||||
actorCreateIdOffset, actorCreateDummyIdOffset,
|
||||
actorIdOffset, actorHandleIdOffset, actorCounter,
|
||||
false, functionIdOffset,
|
||||
argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA);
|
||||
|
||||
argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA,
|
||||
functionDescriptorOffset);
|
||||
fbb.finish(root);
|
||||
ByteBuffer buffer = fbb.dataBuffer();
|
||||
|
||||
if (buffer.remaining() > AbstractRayRuntime.getParams().max_submit_task_buffer_size_bytes) {
|
||||
if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
|
||||
RayLog.core.error(
|
||||
"Allocated buffer is not enough to transfer the task specification: " + AbstractRayRuntime
|
||||
.getParams().max_submit_task_buffer_size_bytes + " vs " + buffer.remaining());
|
||||
"Allocated buffer is not enough to transfer the task specification: "
|
||||
+ TASK_SPEC_BUFFER_SIZE + " vs " + buffer.remaining());
|
||||
assert (false);
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@@ -251,7 +249,6 @@ public class RayletClientImpl implements RayletClient {
|
||||
nativeDestroy(client);
|
||||
}
|
||||
|
||||
|
||||
/// Native method declarations.
|
||||
///
|
||||
/// If you change the signature of any native methods, please re-generate
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
@@ -31,10 +33,13 @@ public class ArgumentsBuilder {
|
||||
} else if (checkSimpleValue(arg)) {
|
||||
data = Serializer.encode(arg);
|
||||
} else {
|
||||
RayObject obj = Ray.put(arg);
|
||||
id = obj.getId();
|
||||
id = Ray.put(arg).getId();
|
||||
}
|
||||
if (id != null) {
|
||||
ret[i] = FunctionArg.passByReference(id);
|
||||
} else {
|
||||
ret[i] = FunctionArg.passByValue(data);
|
||||
}
|
||||
ret[i] = new FunctionArg(id, data);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -43,19 +48,24 @@ public class ArgumentsBuilder {
|
||||
* Convert task spec arguments to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) {
|
||||
// Ignore the last arg, which is the class name
|
||||
Object[] realArgs = new Object[task.args.length - 1];
|
||||
for (int i = 0; i < task.args.length - 1; i++) {
|
||||
Object[] realArgs = new Object[task.args.length];
|
||||
List<UniqueId> idsToFetch = new ArrayList<>();
|
||||
List<Integer> indices = new ArrayList<>();
|
||||
for (int i = 0; i < task.args.length; i++) {
|
||||
FunctionArg arg = task.args[i];
|
||||
if (arg.id == null) {
|
||||
// pass by value
|
||||
Object obj = Serializer.decode(arg.data, classLoader);
|
||||
realArgs[i] = obj;
|
||||
} else if (arg.data == null) {
|
||||
if (arg.id != null) {
|
||||
// pass by reference
|
||||
realArgs[i] = Ray.get(arg.id);
|
||||
idsToFetch.add(arg.id);
|
||||
indices.add(i);
|
||||
} else {
|
||||
// pass by value
|
||||
realArgs[i] = Serializer.decode(arg.data, classLoader);
|
||||
}
|
||||
}
|
||||
List<Object> objects = Ray.get(idsToFetch);
|
||||
for (int i = 0; i < objects.size(); i++) {
|
||||
realArgs[indices.get(i)] = objects.get(i);
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,19 +3,47 @@ package org.ray.runtime.task;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* Represents arguments for ray function calls.
|
||||
* Represents a function argument in task spec.
|
||||
*
|
||||
* Either `id` or `data` should be null, when id is not null, this argument will be
|
||||
* passed by reference, otherwise it will be passed by value.
|
||||
*/
|
||||
public class FunctionArg {
|
||||
|
||||
/**
|
||||
* The id of this argument (passed by reference).
|
||||
*/
|
||||
public final UniqueId id;
|
||||
/**
|
||||
* Serialized data of this argument (passed by value).
|
||||
*/
|
||||
public final byte[] data;
|
||||
|
||||
public FunctionArg(UniqueId id, byte[] data) {
|
||||
private FunctionArg(UniqueId id, byte[] data) {
|
||||
this.id = id;
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public void toString(StringBuilder builder) {
|
||||
builder.append("ids: ").append(id).append(", ").append("<data>:").append(data);
|
||||
/**
|
||||
* Create a FunctionArg that will be passed by reference.
|
||||
*/
|
||||
public static FunctionArg passByReference(UniqueId id) {
|
||||
return new FunctionArg(id, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a FunctionArg that will be passed by value.
|
||||
*/
|
||||
public static FunctionArg passByValue(byte[] data) {
|
||||
return new FunctionArg(null, data);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (id != null) {
|
||||
return "<id>: " + id.toString();
|
||||
} else {
|
||||
return "<data>: " + data.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
|
||||
/**
|
||||
@@ -11,102 +14,91 @@ import org.ray.runtime.util.ResourceUtil;
|
||||
public class TaskSpec {
|
||||
|
||||
// ID of the driver that created this task.
|
||||
public UniqueId driverId;
|
||||
public final UniqueId driverId;
|
||||
|
||||
// Task ID of the task.
|
||||
public UniqueId taskId;
|
||||
public final UniqueId taskId;
|
||||
|
||||
// Task ID of the parent task.
|
||||
public UniqueId parentTaskId;
|
||||
public final UniqueId parentTaskId;
|
||||
|
||||
// A count of the number of tasks submitted by the parent task before this one.
|
||||
public int parentCounter;
|
||||
public final int parentCounter;
|
||||
|
||||
// Id for createActor a target actor
|
||||
public final UniqueId actorCreationId;
|
||||
|
||||
// Actor ID of the task. This is the actor that this task is executed on
|
||||
// or NIL_ACTOR_ID if the task is just a normal task.
|
||||
public UniqueId actorId;
|
||||
|
||||
// Number of tasks that have been submitted to this actor so far.
|
||||
public int actorCounter;
|
||||
|
||||
// Function ID of the task.
|
||||
public UniqueId functionId;
|
||||
|
||||
// Task arguments.
|
||||
public FunctionArg[] args;
|
||||
|
||||
// return ids
|
||||
public UniqueId[] returnIds;
|
||||
public final UniqueId actorId;
|
||||
|
||||
// ID per actor client for session consistency
|
||||
public UniqueId actorHandleId;
|
||||
public final UniqueId actorHandleId;
|
||||
|
||||
// Id for createActor a target actor
|
||||
public UniqueId createActorId;
|
||||
// Number of tasks that have been submitted to this actor so far.
|
||||
public final int actorCounter;
|
||||
|
||||
// Task arguments.
|
||||
public final FunctionArg[] args;
|
||||
|
||||
// return ids
|
||||
public final UniqueId[] returnIds;
|
||||
|
||||
// The task's resource demands.
|
||||
public Map<String, Double> resources;
|
||||
public final Map<String, Double> resources;
|
||||
|
||||
public UniqueId cursorId;
|
||||
// Function descriptor is a list of strings that can uniquely identify a function.
|
||||
// It will be sent to worker and used to load the target callable function.
|
||||
public final FunctionDescriptor functionDescriptor;
|
||||
|
||||
public TaskSpec() {}
|
||||
|
||||
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
|
||||
UniqueId actorId, int actorCounter, UniqueId functionId, FunctionArg[] args,
|
||||
UniqueId[] returnIds, UniqueId actorHandleId, UniqueId createActorId,
|
||||
Map<String, Double> resources, UniqueId cursorId) {
|
||||
this.driverId = driverId;
|
||||
this.taskId = taskId;
|
||||
this.parentTaskId = parentTaskId;
|
||||
this.parentCounter = parentCounter;
|
||||
this.actorId = actorId;
|
||||
this.actorCounter = actorCounter;
|
||||
this.functionId = functionId;
|
||||
this.args = args;
|
||||
this.returnIds = returnIds;
|
||||
this.actorHandleId = actorHandleId;
|
||||
this.createActorId = createActorId;
|
||||
this.resources = resources;
|
||||
this.cursorId = cursorId;
|
||||
|
||||
if (!this.resources.containsKey(ResourceUtil.CPU_LITERAL)) {
|
||||
this.resources.put(ResourceUtil.CPU_LITERAL, 0.0);
|
||||
}
|
||||
|
||||
if (!this.resources.containsKey(ResourceUtil.GPU_LITERAL)) {
|
||||
this.resources.put(ResourceUtil.GPU_LITERAL, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("\ttaskId: ").append(taskId).append("\n");
|
||||
builder.append("\tdriverId: ").append(driverId).append("\n");
|
||||
builder.append("\tparentCounter: ").append(parentCounter).append("\n");
|
||||
builder.append("\tactorId: ").append(actorId).append("\n");
|
||||
builder.append("\tactorCounter: ").append(actorCounter).append("\n");
|
||||
builder.append("\tfunctionId: ").append(functionId).append("\n");
|
||||
builder.append("\treturnIds: ").append(Arrays.toString(returnIds)).append("\n");
|
||||
builder.append("\tactorHandleId: ").append(actorHandleId).append("\n");
|
||||
builder.append("\tcreateActorId: ").append(createActorId).append("\n");
|
||||
builder.append("\tresources: ")
|
||||
.append(ResourceUtil.getResourcesFromatStringFromMap(resources)).append("\n");
|
||||
builder.append("\tcursorId: ").append(cursorId).append("\n");
|
||||
builder.append("\targs:\n");
|
||||
for (FunctionArg arg : args) {
|
||||
builder.append("\t\t");
|
||||
arg.toString(builder);
|
||||
builder.append("\n");
|
||||
}
|
||||
return builder.toString();
|
||||
}
|
||||
private List<UniqueId> executionDependencies;
|
||||
|
||||
public boolean isActorTask() {
|
||||
return !actorId.isNil();
|
||||
}
|
||||
|
||||
public boolean isActorCreationTask() {
|
||||
return !createActorId.isNil();
|
||||
return !actorCreationId.isNil();
|
||||
}
|
||||
|
||||
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
|
||||
UniqueId actorCreationId, UniqueId actorId, UniqueId actorHandleId, int actorCounter,
|
||||
FunctionArg[] args, UniqueId[] returnIds,
|
||||
Map<String, Double> resources, FunctionDescriptor functionDescriptor) {
|
||||
this.driverId = driverId;
|
||||
this.taskId = taskId;
|
||||
this.parentTaskId = parentTaskId;
|
||||
this.parentCounter = parentCounter;
|
||||
this.actorCreationId = actorCreationId;
|
||||
this.actorId = actorId;
|
||||
this.actorHandleId = actorHandleId;
|
||||
this.actorCounter = actorCounter;
|
||||
this.args = args;
|
||||
this.returnIds = returnIds;
|
||||
this.resources = resources;
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
this.executionDependencies = new ArrayList<>();
|
||||
}
|
||||
|
||||
public List<UniqueId> getExecutionDependencies() {
|
||||
return executionDependencies;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TaskSpec{" +
|
||||
"driverId=" + driverId +
|
||||
", taskId=" + taskId +
|
||||
", parentTaskId=" + parentTaskId +
|
||||
", parentCounter=" + parentCounter +
|
||||
", actorCreationId=" + actorCreationId +
|
||||
", actorId=" + actorId +
|
||||
", actorHandleId=" + actorHandleId +
|
||||
", actorCounter=" + actorCounter +
|
||||
", args=" + Arrays.toString(args) +
|
||||
", returnIds=" + Arrays.toString(returnIds) +
|
||||
", resources=" + ResourceUtil.getResourcesStringFromMap(resources) +
|
||||
", functionDescriptor=" + functionDescriptor +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
package org.ray.runtime.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.lang.invoke.MethodHandleInfo;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.WeakHashMap;
|
||||
import org.objectweb.asm.Type;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
|
||||
|
||||
/**
|
||||
* An instance of RayFunc is a lambda.
|
||||
* MethodId describe the information of the called function in lambda.<br/>
|
||||
* e.g. Ray.call(Foo::foo), the MethodId of the lambda Foo::foo is:<br/>
|
||||
* MethodId.className = Foo <br/>
|
||||
* MethodId.methodName = foo <br/>
|
||||
* MethodId.methodDesc = describe the types of args and return.
|
||||
* see org.objectweb.asm.Type.getDescriptor.
|
||||
*/
|
||||
public final class MethodId {
|
||||
|
||||
/**
|
||||
* use ThreadLocal to avoid lock.
|
||||
* A cache from the lambda instances to MethodId.
|
||||
* Note: the lambda instances are dynamically created per call site,
|
||||
* we use WeakHashMap to avoid OOM.
|
||||
*/
|
||||
private static final ThreadLocal<WeakHashMap<Class<Serializable>, MethodId>>
|
||||
CACHE = ThreadLocal.withInitial(() -> new WeakHashMap<>());
|
||||
|
||||
public final String className;
|
||||
public final String methodName;
|
||||
|
||||
public final String methodDesc;
|
||||
public final boolean isStatic;
|
||||
/**
|
||||
* encode the className,methodName,methodDesc,isStatic as an uniquel id.
|
||||
*/
|
||||
private final String encoding;
|
||||
|
||||
/**
|
||||
* sha1 from the encoding, used as functionId.
|
||||
*/
|
||||
private final byte[] digest;
|
||||
|
||||
public MethodId(String className, String methodName, String methodDesc, boolean isStatic) {
|
||||
this.className = className;
|
||||
this.methodName = methodName;
|
||||
this.methodDesc = methodDesc;
|
||||
this.isStatic = isStatic;
|
||||
this.encoding = encode(className, methodName, methodDesc, isStatic);
|
||||
this.digest = getSha1Hash0();
|
||||
}
|
||||
|
||||
private static String encode(String className, String methodName, String methodDesc,
|
||||
boolean isStatic) {
|
||||
StringBuilder sb = new StringBuilder(512);
|
||||
sb.append(className).append('/').append(methodName).append("::").append(methodDesc).append("&&")
|
||||
.append(isStatic);
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public static MethodId fromExecutable(Executable method) {
|
||||
final boolean isStatic = Modifier.isStatic(method.getModifiers());
|
||||
final String className = method.getDeclaringClass().getName();
|
||||
final String methodName = method instanceof Method
|
||||
? method.getName() : "<init>";
|
||||
final Type type = method instanceof Method
|
||||
? Type.getType((Method) method) : Type.getType((Constructor) method);
|
||||
final String methodDesc = type.getDescriptor();
|
||||
return new MethodId(className, methodName, methodDesc, isStatic);
|
||||
}
|
||||
|
||||
public static MethodId fromSerializedLambda(Serializable serial) {
|
||||
return fromSerializedLambda(serial, false);
|
||||
}
|
||||
|
||||
public static MethodId fromSerializedLambda(Serializable serial, boolean forceNew) {
|
||||
Preconditions.checkArgument(!(serial instanceof SerializedLambda), "arg could not be "
|
||||
+ "SerializedLambda");
|
||||
Class<Serializable> clazz = (Class<Serializable>) serial.getClass();
|
||||
WeakHashMap<Class<Serializable>, MethodId> map = CACHE.get();
|
||||
MethodId id = map.get(clazz);
|
||||
if (id == null || forceNew) {
|
||||
final SerializedLambda lambda = LambdaUtils.getSerializedLambda(serial);
|
||||
Preconditions.checkArgument(lambda.getCapturedArgCount() == 0, "could not transfer a lambda "
|
||||
+ "which is closure");
|
||||
final boolean isStatic = lambda.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic;
|
||||
final String className = lambda.getImplClass().replace('/', '.');
|
||||
id = new MethodId(className, lambda.getImplMethodName(),
|
||||
lambda.getImplMethodSignature(), isStatic);
|
||||
if (!forceNew) {
|
||||
map.put(clazz, id);
|
||||
}
|
||||
}
|
||||
return id;
|
||||
}
|
||||
|
||||
public Method load() {
|
||||
return load(null);
|
||||
}
|
||||
|
||||
public Method load(ClassLoader loader) {
|
||||
Class<?> cls = null;
|
||||
try {
|
||||
RayLog.core.debug(
|
||||
"load class " + className + " from class loader " + (loader == null ? this.getClass()
|
||||
.getClassLoader() : loader)
|
||||
+ " for method " + toString() + " with ID = " + toHexHashString()
|
||||
);
|
||||
cls = Class
|
||||
.forName(className, true, loader == null ? this.getClass().getClassLoader() : loader);
|
||||
} catch (Throwable e) {
|
||||
RayLog.core.error("Cannot load class {}", className, e);
|
||||
return null;
|
||||
}
|
||||
|
||||
Method[] ms = cls.getDeclaredMethods();
|
||||
ArrayList<Method> methods = new ArrayList<>();
|
||||
Type t = Type.getMethodType(this.methodDesc);
|
||||
Type[] params = t.getArgumentTypes();
|
||||
String rt = t.getReturnType().getDescriptor();
|
||||
|
||||
for (Method m : ms) {
|
||||
if (m.getName().equals(methodName)) {
|
||||
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String mrt = Type.getDescriptor(m.getReturnType());
|
||||
if (!rt.equals(mrt)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isStatic != Modifier.isStatic(m.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
methods.add(m);
|
||||
}
|
||||
}
|
||||
|
||||
if (methods.size() != 1) {
|
||||
RayLog.core.error(
|
||||
"Load method {} failed as there are {} definitions.", toString(), methods.size());
|
||||
return null;
|
||||
}
|
||||
|
||||
return methods.get(0);
|
||||
}
|
||||
|
||||
private byte[] getSha1Hash0() {
|
||||
byte[] digests = Sha1Digestor.digest(encoding);
|
||||
ByteBuffer bb = ByteBuffer.wrap(digests);
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
if (methodName.contains("createActorStage1")) {
|
||||
bb.putLong(Long.BYTES, 1);
|
||||
} else {
|
||||
bb.putLong(Long.BYTES, 0);
|
||||
}
|
||||
return digests;
|
||||
}
|
||||
|
||||
public byte[] getSha1Hash() {
|
||||
return digest;
|
||||
}
|
||||
|
||||
private String toHexHashString() {
|
||||
byte[] id = this.getSha1Hash();
|
||||
return StringUtil.toHexHashString(id);
|
||||
}
|
||||
|
||||
public String toEncodingString() {
|
||||
return encoding;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return encoding.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
MethodId other = (MethodId) obj;
|
||||
return className.equals(other.className)
|
||||
&& methodName.equals(other.methodName)
|
||||
&& methodDesc.equals(other.methodDesc)
|
||||
&& isStatic == other.isStatic;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return encoding;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.ray.api.function.RayFunc1;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionManager.DriverFunctionTable;
|
||||
|
||||
/**
|
||||
* Tests for {@link FunctionManager}
|
||||
*/
|
||||
public class FunctionManagerTest {
|
||||
|
||||
@RayRemote
|
||||
public static Object foo() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static class Bar {
|
||||
|
||||
public Bar() {
|
||||
}
|
||||
|
||||
public Object bar() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static RayFunc0<Object> fooFunc;
|
||||
private static RayFunc1<Bar, Object> barFunc;
|
||||
private static RayFunc0<Bar> barConstructor;
|
||||
private static FunctionDescriptor fooDescriptor;
|
||||
private static FunctionDescriptor barDescriptor;
|
||||
private static FunctionDescriptor barConstructorDescriptor;
|
||||
|
||||
private FunctionManager functionManager;
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() {
|
||||
fooFunc = FunctionManagerTest::foo;
|
||||
barConstructor = Bar::new;
|
||||
barFunc = Bar::bar;
|
||||
fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
|
||||
"()Ljava/lang/Object;");
|
||||
barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar",
|
||||
"()Ljava/lang/Object;");
|
||||
barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(),
|
||||
FunctionManager.CONSTRUCTOR_NAME,
|
||||
"()V");
|
||||
}
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
functionManager = new FunctionManager();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFunctionFromRayFunc() {
|
||||
// Test normal function.
|
||||
RayFunction func = functionManager.getFunction(UniqueId.NIL, fooFunc);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
|
||||
// Test actor method
|
||||
func = functionManager.getFunction(UniqueId.NIL, barFunc);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
|
||||
// Test actor constructor
|
||||
func = functionManager.getFunction(UniqueId.NIL, barConstructor);
|
||||
Assert.assertTrue(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFunctionFromFunctionDescriptor() {
|
||||
// Test normal function.
|
||||
RayFunction func = functionManager.getFunction(UniqueId.NIL, fooDescriptor);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
|
||||
// Test actor method
|
||||
func = functionManager.getFunction(UniqueId.NIL, barDescriptor);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
|
||||
// Test actor constructor
|
||||
func = functionManager.getFunction(UniqueId.NIL, barConstructorDescriptor);
|
||||
Assert.assertTrue(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor);
|
||||
Assert.assertNotNull(func.getRayRemoteAnnotation());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadFunctionTableForClass() {
|
||||
DriverFunctionTable functionTable = new DriverFunctionTable(getClass().getClassLoader());
|
||||
Map<Pair<String, String>, RayFunction> res = functionTable
|
||||
.loadFunctionsForClass(Bar.class.getName());
|
||||
// The result should 2 entries, one for the constructor, the other for bar.
|
||||
Assert.assertEquals(res.size(), 2);
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(barDescriptor.name, barDescriptor.typeDescriptor)));
|
||||
Assert.assertTrue(res.containsKey(
|
||||
ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor)));
|
||||
}
|
||||
}
|
||||
@@ -35,8 +35,6 @@
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.11</version>
|
||||
<!-- <scope>test</scope> -->
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/commons-collections/commons-collections -->
|
||||
|
||||
@@ -1,211 +0,0 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.ray.api.function.RayFunc1;
|
||||
import org.ray.api.function.RayFunc3;
|
||||
import org.ray.runtime.util.MethodId;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
|
||||
public class LambdaUtilsTest {
|
||||
|
||||
static final String CLASS_NAME = LambdaUtilsTest.class.getName();
|
||||
static final Method CALL0;
|
||||
static final Method CALL1;
|
||||
static final Method CALL2;
|
||||
static final Method CALL3;
|
||||
|
||||
static {
|
||||
try {
|
||||
CALL0 = LambdaUtilsTest.class.getDeclaredMethod("call0", new Class[0]);
|
||||
CALL1 = LambdaUtilsTest.class.getDeclaredMethod("call1", new Class[]{Long.class});
|
||||
CALL2 = LambdaUtilsTest.class.getDeclaredMethod("call2", new Class[0]);
|
||||
CALL3 = LambdaUtilsTest.class
|
||||
.getDeclaredMethod("call3", new Class[]{Long.class, String.class});
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static <T0, T1, T2, R0> void testRemoteLambdaParse(RayFunc3<T0, T1, T2, R0> f, int n,
|
||||
boolean forceNew, boolean debug)
|
||||
throws Exception {
|
||||
if (debug) {
|
||||
RayLog.core.info("parse#" + f.getClass().getName());
|
||||
}
|
||||
long start = System.nanoTime();
|
||||
for (int i = 0; i < n; i++) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f, forceNew);
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
RayLog.core.info(String.format("remoteLambdaParse(new=%s):total=%sms, one=%s", forceNew,
|
||||
TimeUnit.NANOSECONDS.toMillis(end - start),
|
||||
(end - start) / n));
|
||||
}
|
||||
|
||||
public static <T0, T1, T2, R0> void testRemoteLambdaSerde(RayFunc3<T0, T1, T2, R0> f, int n,
|
||||
boolean de, boolean debug)
|
||||
throws Exception {
|
||||
if (debug) {
|
||||
RayLog.core.info("se#" + f.getClass().getName());
|
||||
}
|
||||
long start = System.nanoTime();
|
||||
for (int i = 0; i < n; i++) {
|
||||
ByteArrayOutputStream bytes = new ByteArrayOutputStream(1024);
|
||||
ObjectOutputStream out = new ObjectOutputStream(bytes);
|
||||
out.writeObject(f);
|
||||
out.close();
|
||||
if (de) {
|
||||
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()));
|
||||
RayFunc3 def = (RayFunc3) in.readObject();
|
||||
in.close();
|
||||
if (debug) {
|
||||
RayLog.core.info("de#" + def.getClass().getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
RayLog.core.info(
|
||||
String.format("remoteLambdaSer(de=%s):total=%sms,one=%s", de,
|
||||
TimeUnit.NANOSECONDS.toMillis(end - start),
|
||||
(end - start) / n));
|
||||
}
|
||||
|
||||
public static void testCall0(RayFunc0 f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL0);
|
||||
Assert.assertTrue(mid.isStatic);
|
||||
}
|
||||
|
||||
public static <T, R> void testCall1(RayFunc1<T, R> f, T t) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL1);
|
||||
Assert.assertTrue(mid.isStatic);
|
||||
}
|
||||
|
||||
public static <T, R> void testCall2(RayFunc1<T, R> f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL2);
|
||||
Assert.assertTrue(!mid.isStatic);
|
||||
}
|
||||
|
||||
public static <T0, T1, T2, R0> void testCall3(RayFunc3<T0, T1, T2, R0> f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL3);
|
||||
Assert.assertTrue(!mid.isStatic);
|
||||
}
|
||||
|
||||
public static String call0() {
|
||||
long t = System.currentTimeMillis();
|
||||
RayLog.core.info("call0:" + t);
|
||||
return String.valueOf(t);
|
||||
}
|
||||
|
||||
public static String call1(Long v) {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
v += i;
|
||||
}
|
||||
RayLog.core.info("call1:" + v);
|
||||
return String.valueOf(v);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLambdaSer() throws Exception {
|
||||
testCall0(LambdaUtilsTest::call0);
|
||||
testCall1(LambdaUtilsTest::call1, Long.valueOf(System.currentTimeMillis()));
|
||||
testCall2(LambdaUtilsTest::call2);
|
||||
testCall3(LambdaUtilsTest::call3);
|
||||
}
|
||||
|
||||
/**
|
||||
* to test the serdeLambda's perf.
|
||||
*/
|
||||
public void testBenchmark() throws Exception {
|
||||
//test serde
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
|
||||
//warmup
|
||||
|
||||
RayLog.core.info("warmup:serde################");
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
RayLog.core.info("benchmark:serde################");
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
RayLog.core.info("benchmark:ser################");
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
long start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
|
||||
long end = System.nanoTime();
|
||||
RayLog.core.info("one sertime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one sertime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one sertime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one serdetime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one serdetime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one serdetime:" + (end - start));
|
||||
|
||||
//test lambda
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, true, true);
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, false, true);
|
||||
//warmup
|
||||
RayLog.core.info("warmup:parse################");
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
RayLog.core.info("benchmark:parseNew################");
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
RayLog.core.info("benchmark:parseCache################");
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
}
|
||||
|
||||
public String call2() {
|
||||
long t = System.currentTimeMillis();
|
||||
RayLog.core.info("call2:" + t);
|
||||
return "call2:" + t;
|
||||
}
|
||||
|
||||
public String call3(Long v, String s) {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
v += i;
|
||||
}
|
||||
RayLog.core.info("call3:" + v);
|
||||
return String.valueOf(v);
|
||||
}
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
|
||||
import java.lang.reflect.Executable;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.api.function.RayFunc2;
|
||||
import org.ray.runtime.util.MethodId;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class MethodIdTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MethodIdTest.class);
|
||||
|
||||
@Test
|
||||
public void testNormalMethod() throws Exception {
|
||||
RayFunc2<Integer, String, String> f = MethodIdTest::foo;
|
||||
MethodId m1 = MethodId.fromSerializedLambda(f);
|
||||
Executable e = MethodIdTest.class.getDeclaredMethod("foo", int.class, String.class);
|
||||
MethodId m2 = MethodId.fromExecutable(e);
|
||||
LOGGER.info("{}, {}", m1, m2);
|
||||
Assert.assertEquals(m1, m2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConstructor() throws Exception {
|
||||
RayFunc2<Integer, String, Foo> f = Foo::new;
|
||||
MethodId m1 = MethodId.fromSerializedLambda(f);
|
||||
Executable e = Foo.class.getConstructor(int.class, String.class);
|
||||
MethodId m2 = MethodId.fromExecutable(e);
|
||||
LOGGER.info("{}, {}", m1, m2);
|
||||
Assert.assertEquals(m1, m2);
|
||||
}
|
||||
|
||||
public static String foo(int a, String b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public static class Foo {
|
||||
public Foo(int a, String b) {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.runtime.functionmanager.RayActorMethods;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class RayActorMethodsTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayActorMethodsTest.class);
|
||||
|
||||
@RayRemote
|
||||
public static class ExampleActor {
|
||||
|
||||
public void func1() {}
|
||||
|
||||
public void func2() {}
|
||||
|
||||
public static void func3() {}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActorMethods() {
|
||||
RayActorMethods methods = RayActorMethods
|
||||
.fromClass(ExampleActor.class.getName(), RayActorMethodsTest.class.getClassLoader());
|
||||
LOGGER.info(methods.toString());
|
||||
Assert.assertEquals(methods.functions.size(), 2);
|
||||
Assert.assertEquals(methods.staticFunctions.size(), 1);
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.runtime.functionmanager.RayMethod;
|
||||
import org.ray.runtime.functionmanager.RayTaskMethods;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
||||
public class RayTaskMethodsTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayTaskMethodsTest.class);
|
||||
|
||||
private static class Foo {
|
||||
|
||||
public Foo() {}
|
||||
|
||||
public Foo(int x) {}
|
||||
|
||||
public static void f1() {}
|
||||
|
||||
public void f2() {}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTask() {
|
||||
RayTaskMethods methods = RayTaskMethods
|
||||
.fromClass(Foo.class.getName(), Foo.class.getClassLoader());
|
||||
LOGGER.info(methods.toString());
|
||||
int numMethods = 0;
|
||||
int numConstructors = 0;
|
||||
for (RayMethod m : methods.functions.values()) {
|
||||
if (m.isConstructor()) {
|
||||
numConstructors += 1;
|
||||
} else {
|
||||
numMethods += 1;
|
||||
}
|
||||
}
|
||||
Assert.assertEquals(numMethods, 1);
|
||||
Assert.assertEquals(numConstructors, 2);
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,13 @@ table TaskInfo {
|
||||
required_resources: [ResourcePair];
|
||||
// The language that this task belongs to
|
||||
language: TaskLanguage;
|
||||
// Function descriptor, which is a list of strings that can
|
||||
// uniquely describe a function.
|
||||
// For a Python function, it should be: [module_name, class_name, function_name]
|
||||
// For a Java function, it should be: [class_name, method_name, type_descriptor]
|
||||
// TODO(hchen): after changing Python worker to use function_descriptor,
|
||||
// function_id can be removed.
|
||||
function_descriptor: [string];
|
||||
}
|
||||
|
||||
// Object information data structure.
|
||||
|
||||
Reference in New Issue
Block a user