diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml index c32146e18..619c24e14 100644 --- a/java/checkstyle-suppressions.xml +++ b/java/checkstyle-suppressions.xml @@ -3,6 +3,7 @@ "http://www.puppycrawl.com/dtds/suppressions_1_1.dtd"> + diff --git a/java/cli/src/main/java/org/ray/cli/RayCli.java b/java/cli/src/main/java/org/ray/cli/RayCli.java index e683682a7..7c39030fd 100644 --- a/java/cli/src/main/java/org/ray/cli/RayCli.java +++ b/java/cli/src/main/java/org/ray/cli/RayCli.java @@ -1,24 +1,19 @@ package org.ray.cli; import com.beust.jcommander.JCommander; -import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; -import net.lingala.zip4j.core.ZipFile; import org.ray.api.id.UniqueId; import org.ray.runtime.config.PathConfig; import org.ray.runtime.config.RayParameters; import org.ray.runtime.config.RunMode; -import org.ray.runtime.functionmanager.NativeRemoteFunctionManager; -import org.ray.runtime.functionmanager.RemoteFunctionManager; import org.ray.runtime.gcs.KeyValueStoreLink; import org.ray.runtime.gcs.RedisClient; import org.ray.runtime.gcs.StateStoreProxy; import org.ray.runtime.gcs.StateStoreProxyImpl; import org.ray.runtime.runner.RunManager; import org.ray.runtime.runner.worker.DefaultDriver; -import org.ray.runtime.util.FileUtil; import org.ray.runtime.util.config.ConfigReader; import org.ray.runtime.util.logger.RayLog; @@ -140,64 +135,18 @@ public class RayCli { ConfigReader config = new ConfigReader(configPath, "ray.java.start.deploy=true"); PathConfig paths = new PathConfig(config); RayParameters params = new RayParameters(config); - params.redis_address = cmdSubmit.redisAddress; params.run_mode = RunMode.CLUSTER; - KeyValueStoreLink kvStore = new RedisClient(); kvStore.setAddr(cmdSubmit.redisAddress); StateStoreProxy stateStoreProxy = new StateStoreProxyImpl(kvStore); stateStoreProxy.initializeGlobalState(); - RemoteFunctionManager functionManager = new NativeRemoteFunctionManager(kvStore); - - // Register app to Redis. - byte[] zip = FileUtil.fileToBytes(cmdSubmit.packageZip); - - String packageName = cmdSubmit.packageZip.substring( - cmdSubmit.packageZip.lastIndexOf('/') + 1, - cmdSubmit.packageZip.lastIndexOf('.')); - - UniqueId resourceId = functionManager.registerResource(zip); - // Init RayLog before using it. RayLog.init(params.log_dir); - - RayLog.rapp.debug( - "registerResource " + resourceId + " for package " + packageName + " done"); - UniqueId appId = params.driver_id; - functionManager.registerApp(appId, resourceId); - RayLog.rapp.debug("registerApp " + appId + " for resouorce " + resourceId + " done"); - - // Unzip the package file. String appDir = "/tmp/" + cmdSubmit.className; - String extPath = appDir + "/" + packageName; - if (!FileUtil.createDir(extPath, false)) { - throw new RuntimeException("create dir " + extPath + " failed "); - } - - ZipFile zipFile = new ZipFile(cmdSubmit.packageZip); - zipFile.extractAll(extPath); - - // Build the args for driver process. - File originDirFile = new File(extPath); - File[] topFiles = originDirFile.listFiles(); - String topDir = null; - for (File file : topFiles) { - if (file.isDirectory()) { - topDir = file.getName(); - } - } - RayLog.rapp.debug("topDir of app classes: " + topDir); - if (topDir == null) { - RayLog.rapp.error("Can't find topDir of app classes, the app directory " + appDir); - return; - } - - String additionalClassPath = appDir + "/" + packageName + "/" + topDir + "/*"; - RayLog.rapp.debug("Find app class path " + additionalClassPath); // Start driver process. RunManager runManager = new RunManager(params, paths, config); @@ -209,18 +158,15 @@ public class RayCli { params.node_ip_address, cmdSubmit.className, cmdSubmit.classArgs, - additionalClassPath, + "", null); - if (null == proc) { - RayLog.rapp.error( - "Create process for app " + packageName + " in local directory " + appDir - + " failed"); + if (null == proc) { + RayLog.rapp.error("Failed to start driver."); return; } - RayLog.rapp - .info("Create app " + appDir + " for package " + packageName + " succeeded"); + RayLog.rapp.info("Driver started."); } private static String getConfigPath(String config) { diff --git a/java/pom.xml b/java/pom.xml index 6859502ac..28b745da6 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -105,6 +105,11 @@ mockito-all 1.10.19 + + junit + junit + 4.11 + diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index f709839d9..6b0c69e7b 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -66,5 +66,12 @@ org.ow2.asm asm + + + + junit + junit + test + \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 8032e9642..50189b82b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -17,15 +17,13 @@ import org.ray.api.id.UniqueId; import org.ray.api.runtime.RayRuntime; import org.ray.runtime.config.PathConfig; import org.ray.runtime.config.RayParameters; -import org.ray.runtime.functionmanager.LocalFunctionManager; -import org.ray.runtime.functionmanager.RayMethod; -import org.ray.runtime.functionmanager.RemoteFunctionManager; +import org.ray.runtime.functionmanager.FunctionManager; +import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.objectstore.ObjectStoreProxy; import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus; import org.ray.runtime.raylet.RayletClient; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.MethodId; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.UniqueIdHelper; import org.ray.runtime.util.config.ConfigReader; @@ -44,8 +42,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected Worker worker; protected RayletClient rayletClient; protected ObjectStoreProxy objectStoreProxy; - protected LocalFunctionManager functions; - protected RemoteFunctionManager remoteFunctionManager; + protected FunctionManager functionManager; protected PathConfig pathConfig; /** @@ -121,13 +118,11 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected void init( RayletClient slink, ObjectStoreLink plink, - RemoteFunctionManager remoteLoader, PathConfig pathManager ) { - remoteFunctionManager = remoteLoader; pathConfig = pathManager; - functions = new LocalFunctionManager(remoteLoader); + functionManager = new FunctionManager(); rayletClient = slink; objectStoreProxy = new ObjectStoreProxy(plink); @@ -308,6 +303,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { } RayActorImpl actorImpl = (RayActorImpl)actor; TaskSpec spec = createTaskSpec(func, actorImpl, args, false); + spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor()); actorImpl.setTaskCursor(spec.returnIds[1]); rayletClient.submitTask(spec); return new RayObjectImpl(spec.returnIds[0]); @@ -357,33 +353,20 @@ public abstract class AbstractRayRuntime implements RayRuntime { actorCreationId = returnIds[0]; } - MethodId methodId = MethodId.fromSerializedLambda(func); - - // NOTE: we append the class name at the end of arguments, - // so that we can look up the method based on the class name. - // TODO(hchen): move class name to task spec. - args = Arrays.copyOf(args, args.length + 1); - args[args.length - 1] = methodId.className; - - RayMethod rayMethod = functions.getMethod( - current.driverId, actor.getId(), new UniqueId(methodId.getSha1Hash()), methodId.className - ).getRight(); - UniqueId funcId = rayMethod.getFuncId(); - + RayFunction rayFunction = functionManager.getFunction(current.driverId, func); return new TaskSpec( current.driverId, taskId, current.taskId, -1, + actorCreationId, actor.getId(), + actor.getHandleId(), actor.increaseTaskCounter(), - funcId, ArgumentsBuilder.wrap(args), returnIds, - actor.getHandleId(), - actorCreationId, - ResourceUtil.getResourcesMapFromArray(rayMethod.remoteAnnotation), - actor.getTaskCursor() + ResourceUtil.getResourcesMapFromArray(rayFunction.getRayRemoteAnnotation()), + rayFunction.getFunctionDescriptor() ); } @@ -399,8 +382,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { return rayletClient; } - public LocalFunctionManager getLocalFunctionManager() { - return functions; + public FunctionManager getFunctionManager() { + return functionManager; } } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index 7fd685146..21714e0f6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -2,8 +2,6 @@ package org.ray.runtime; import org.ray.runtime.config.PathConfig; import org.ray.runtime.config.RayParameters; -import org.ray.runtime.functionmanager.NopRemoteFunctionManager; -import org.ray.runtime.functionmanager.RemoteFunctionManager; import org.ray.runtime.objectstore.MockObjectStore; import org.ray.runtime.raylet.MockRayletClient; @@ -12,11 +10,9 @@ public class RayDevRuntime extends AbstractRayRuntime { @Override public void start(RayParameters params) { PathConfig pathConfig = new PathConfig(configReader); - RemoteFunctionManager rfm = new NopRemoteFunctionManager(params.driver_id); MockObjectStore store = new MockObjectStore(); MockRayletClient scheduler = new MockRayletClient(this, store); - init(scheduler, store, rfm, pathConfig); - scheduler.setLocalFunctionManager(this.functions); + init(scheduler, store, pathConfig); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 19eaa4124..04bbdac26 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -8,9 +8,6 @@ import org.apache.arrow.plasma.PlasmaClient; import org.ray.runtime.config.PathConfig; import org.ray.runtime.config.RayParameters; import org.ray.runtime.config.WorkerMode; -import org.ray.runtime.functionmanager.NativeRemoteFunctionManager; -import org.ray.runtime.functionmanager.NopRemoteFunctionManager; -import org.ray.runtime.functionmanager.RemoteFunctionManager; import org.ray.runtime.gcs.AddressInfo; import org.ray.runtime.gcs.KeyValueStoreLink; import org.ray.runtime.gcs.RedisClient; @@ -61,10 +58,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } } - // initialize remote function manager - RemoteFunctionManager funcMgr = params.run_mode.isDevPathManager() - ? new NopRemoteFunctionManager(params.driver_id) : new NativeRemoteFunctionManager(kvStore); - // initialize worker context if (params.worker_mode == WorkerMode.DRIVER) { // TODO: The relationship between workerID, driver_id and dummy_task.driver_id should be @@ -96,7 +89,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { WorkerContext.currentTask().taskId ); - init(rayletClient, plink, funcMgr, pathConfig); + init(rayletClient, plink, pathConfig); // register registerWorker(isWorker, params.node_ip_address, params.object_store_name, diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index cb6bb17c2..d9215e33f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -1,19 +1,22 @@ package org.ray.runtime; -import org.apache.commons.lang3.tuple.Pair; import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; -import org.ray.runtime.functionmanager.RayMethod; +import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * The worker, which pulls tasks from {@code org.ray.spi.LocalSchedulerProxy} and executes them + * The worker, which pulls tasks from {@link org.ray.runtime.raylet.RayletClient} and executes them * continuously. */ public class Worker { + private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class); + private final AbstractRayRuntime runtime; public Worker(AbstractRayRuntime runtime) { @@ -22,7 +25,7 @@ public class Worker { public void loop() { while (true) { - RayLog.core.info(Thread.currentThread().getName() + ":fetching new task..."); + LOGGER.info("Fetching new task in thread {}.", Thread.currentThread().getName()); TaskSpec task = runtime.getRayletClient().getTask(); execute(task); } @@ -32,27 +35,26 @@ public class Worker { * Execute a task. */ public void execute(TaskSpec spec) { - RayLog.core.info("Executing task {}", spec.taskId); + LOGGER.info("Executing task {}", spec.taskId); + LOGGER.debug("Executing task {}", spec); UniqueId returnId = spec.returnIds[0]; ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); try { // Get method - Pair pair = runtime.getLocalFunctionManager().getMethod( - spec.driverId, spec.actorId, spec.functionId, spec.args); - ClassLoader classLoader = pair.getLeft(); - RayMethod method = pair.getRight(); + RayFunction rayFunction = runtime.getFunctionManager() + .getFunction(spec.driverId, spec.functionDescriptor); // Set context - WorkerContext.prepare(spec, classLoader); - Thread.currentThread().setContextClassLoader(classLoader); + WorkerContext.prepare(spec, rayFunction.classLoader); + Thread.currentThread().setContextClassLoader(rayFunction.classLoader); // Get local actor object and arguments. Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null; - Object[] args = ArgumentsBuilder.unwrap(spec, classLoader); + Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader); // Execute the task. Object result; - if (!method.isConstructor()) { - result = method.getMethod().invoke(actor, args); + if (!rayFunction.isConstructor()) { + result = rayFunction.getMethod().invoke(actor, args); } else { - result = method.getConstructor().newInstance(args); + result = rayFunction.getConstructor().newInstance(args); } // Set result if (!spec.isActorCreationTask()) { diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index dc08c562c..55099fb0d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -34,15 +34,20 @@ public class WorkerContext { WorkerContext ctx = new WorkerContext(); currentWorkerCtx.set(ctx); - TaskSpec dummy = new TaskSpec(); - dummy.parentTaskId = UniqueId.NIL; - if (params.worker_mode == WorkerMode.DRIVER) { - dummy.taskId = UniqueId.randomId(); - } else { - dummy.taskId = UniqueId.NIL; - } - dummy.actorId = UniqueId.NIL; - dummy.driverId = params.driver_id; + TaskSpec dummy = new TaskSpec( + params.driver_id, + params.worker_mode == WorkerMode.DRIVER ? UniqueId.randomId() : UniqueId.NIL, + UniqueId.NIL, + 0, + UniqueId.NIL, + UniqueId.NIL, + UniqueId.NIL, + 0, + null, + null, + null, + null + ); prepare(dummy, null); return ctx; diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java new file mode 100644 index 000000000..70be2f3e9 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionDescriptor.java @@ -0,0 +1,52 @@ +package org.ray.runtime.functionmanager; + +import com.google.common.base.Objects; + +/** + * Represents the function's metadata. + */ +public final class FunctionDescriptor { + + /** + * Function's class name. + */ + public final String className; + /** + * Function's name. + */ + public final String name; + /** + * Function's type descriptor. + */ + public final String typeDescriptor; + + public FunctionDescriptor(String className, String name, String typeDescriptor) { + this.className = className; + this.name = name; + this.typeDescriptor = typeDescriptor; + } + + @Override + public String toString() { + return className + "." + name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FunctionDescriptor that = (FunctionDescriptor) o; + return Objects.equal(className, that.className) && + Objects.equal(name, that.name) && + Objects.equal(typeDescriptor, that.typeDescriptor); + } + + @Override + public int hashCode() { + return Objects.hashCode(className, name, typeDescriptor); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java new file mode 100644 index 000000000..e58674164 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -0,0 +1,132 @@ +package org.ray.runtime.functionmanager; + +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.WeakHashMap; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.objectweb.asm.Type; +import org.ray.api.function.RayFunc; +import org.ray.api.id.UniqueId; +import org.ray.runtime.util.LambdaUtils; + +/** + * Manages functions by driver id. + */ +public class FunctionManager { + + static final String CONSTRUCTOR_NAME = ""; + + /** + * Cache from a RayFunc object to its corresponding FunctionDescriptor. Because + * `LambdaUtils.getSerializedLambda` is expensive. + */ + private static final ThreadLocal, FunctionDescriptor>> + RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new); + + /** + * Mapping from the driver id to the functions that belong to this driver. + */ + private Map driverFunctionTables = new HashMap<>(); + + /** + * Get the RayFunction from a RayFunc instance (a lambda). + * + * @param driverId current driver id. + * @param func The lambda. + * @return A RayFunction object. + */ + public RayFunction getFunction(UniqueId driverId, RayFunc func) { + FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); + if (functionDescriptor == null) { + SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func); + final String className = serializedLambda.getImplClass().replace('/', '.'); + final String methodName = serializedLambda.getImplMethodName(); + final String typeDescriptor = serializedLambda.getImplMethodSignature(); + functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor); + } + return getFunction(driverId, functionDescriptor); + } + + /** + * Get the RayFunction from a function descriptor. + * + * @param driverId Current driver id. + * @param functionDescriptor The function descriptor. + * @return A RayFunction object. + */ + public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) { + DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId); + if (driverFunctionTable == null) { + //TODO(hchen): distinguish class loader by driver id. + ClassLoader classLoader = getClass().getClassLoader(); + driverFunctionTable = new DriverFunctionTable(classLoader); + driverFunctionTables.put(driverId, driverFunctionTable); + } + return driverFunctionTable.getFunction(functionDescriptor); + } + + /** + * Manages all functions that belong to one driver. + */ + static class DriverFunctionTable { + + /** + * The driver's corresponding class loader. + */ + ClassLoader classLoader; + /** + * Functions per class, per function name + type descriptor. + */ + Map, RayFunction>> functions; + + DriverFunctionTable(ClassLoader classLoader) { + this.classLoader = classLoader; + this.functions = new HashMap<>(); + } + + RayFunction getFunction(FunctionDescriptor descriptor) { + Map, RayFunction> classFunctions = functions.get(descriptor.className); + if (classFunctions == null) { + classFunctions = loadFunctionsForClass(descriptor.className); + functions.put(descriptor.className, classFunctions); + } + return classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.typeDescriptor)); + } + + /** + * Load all functions from a class. + */ + Map, RayFunction> loadFunctionsForClass(String className) { + Map, RayFunction> map = new HashMap<>(); + try { + Class clazz = Class.forName(className, true, classLoader); + + List executables = new ArrayList<>(); + executables.addAll(Arrays.asList(clazz.getDeclaredMethods())); + executables.addAll(Arrays.asList(clazz.getConstructors())); + + for (Executable e : executables) { + e.setAccessible(true); + final String methodName = e instanceof Method ? e.getName() : CONSTRUCTOR_NAME; + final Type type = + e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e); + final String typeDescriptor = type.getDescriptor(); + RayFunction rayFunction = new RayFunction(e, classLoader, + new FunctionDescriptor(className, methodName, typeDescriptor)); + map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction); + } + } catch (Exception e) { + throw new RuntimeException("Failed to load functions from class " + className, e); + } + return map; + } + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/LocalFunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/LocalFunctionManager.java deleted file mode 100644 index 304f80566..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/LocalFunctionManager.java +++ /dev/null @@ -1,119 +0,0 @@ -package org.ray.runtime.functionmanager; - -import com.google.common.base.Preconditions; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.commons.lang3.tuple.Pair; -import org.ray.api.id.UniqueId; -import org.ray.runtime.task.FunctionArg; -import org.ray.runtime.util.Serializer; -import org.ray.runtime.util.logger.RayLog; - -/** - * local function manager which pulls remote functions on demand. - */ -public class LocalFunctionManager { - - private final RemoteFunctionManager remoteLoader; - - private final ConcurrentHashMap functionTables - = new ConcurrentHashMap<>(); - - /** - * initialize load function manager using remote function manager to pull remote functions on - * demand. - */ - public LocalFunctionManager(RemoteFunctionManager remoteLoader) { - this.remoteLoader = remoteLoader; - } - - private FunctionTable loadDriverFunctions(UniqueId driverId) { - FunctionTable functionTable = functionTables.get(driverId); - if (functionTable == null) { - RayLog.core.info("DriverId " + driverId + " Try to load functions"); - ClassLoader classLoader = remoteLoader.loadResource(driverId); - if (classLoader == null) { - throw new RuntimeException( - "Cannot find resource' classLoader for app " + driverId.toString()); - } - functionTable = new FunctionTable(classLoader); - functionTables.put(driverId, functionTable); - } - return functionTable; - } - - public Pair getMethod(UniqueId driverId, UniqueId actorId, - UniqueId methodId, String className) { - // assert the driver's resource is load. - FunctionTable functionTable = loadDriverFunctions(driverId); - Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId); - RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className) - : functionTable.getActorMethod(methodId, className); - Preconditions - .checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className, - methodId, driverId); - return Pair.of(functionTable.classLoader, method); - } - - /** - * get local method for executing, which pulls information from remote repo on-demand, therefore - * it may block for a while if the related resources (e.g., jars) are not ready on local machine - */ - public Pair getMethod(UniqueId driverId, UniqueId actorId, - UniqueId methodId, FunctionArg[] args) { - Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length); - String className = (String) Serializer.decode(args[args.length - 1].data); - return getMethod(driverId, actorId, methodId, className); - } - - /** - * unload the functions when the driver is declared dead. - */ - public synchronized void removeApp(UniqueId driverId) { - FunctionTable funcs = functionTables.get(driverId); - if (funcs != null) { - functionTables.remove(driverId); - remoteLoader.unloadFunctions(driverId); - } - } - - private static class FunctionTable { - - final ClassLoader classLoader; - final ConcurrentHashMap taskMethods = new ConcurrentHashMap<>(); - final ConcurrentHashMap actorMethods = new ConcurrentHashMap<>(); - - FunctionTable(ClassLoader classLoader) { - this.classLoader = classLoader; - } - - RayMethod getTaskMethod(UniqueId methodId, String className) { - RayTaskMethods taskMethods = this.taskMethods.get(className); - if (taskMethods == null) { - taskMethods = RayTaskMethods.fromClass(className, classLoader); - RayLog.core.info("create RayTaskMethods: {}", taskMethods); - this.taskMethods.put(className, taskMethods); - } - RayMethod m = taskMethods.functions.get(methodId); - if (m != null) { - return m; - } - // it is a actor static func. - return getActorMethod(methodId, className, true); - } - - RayMethod getActorMethod(UniqueId methodId, String className) { - return getActorMethod(methodId, className, false); - } - - private RayMethod getActorMethod(UniqueId methodId, String className, boolean isStatic) { - RayActorMethods actorMethods = this.actorMethods.get(className); - if (actorMethods == null) { - actorMethods = RayActorMethods.fromClass(className, classLoader); - RayLog.core.info("create RayActorMethods: {}", actorMethods); - this.actorMethods.put(className, actorMethods); - } - return isStatic ? actorMethods.staticFunctions.get(methodId) - : actorMethods.functions.get(methodId); - } - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/NativeRemoteFunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/NativeRemoteFunctionManager.java deleted file mode 100644 index 0adf3b0ce..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/NativeRemoteFunctionManager.java +++ /dev/null @@ -1,131 +0,0 @@ -package org.ray.runtime.functionmanager; - -import java.io.File; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.concurrent.ConcurrentHashMap; -import net.lingala.zip4j.core.ZipFile; -import org.ray.api.id.UniqueId; -import org.ray.runtime.gcs.KeyValueStoreLink; -import org.ray.runtime.util.FileUtil; -import org.ray.runtime.util.JarLoader; -import org.ray.runtime.util.Sha1Digestor; -import org.ray.runtime.util.SystemUtil; -import org.ray.runtime.util.logger.RayLog; - -/** - * native implementation of remote function manager. - */ -public class NativeRemoteFunctionManager implements RemoteFunctionManager { - - private final ConcurrentHashMap loadedApps = new ConcurrentHashMap<>(); - private MessageDigest md; - private final String appDir = System.getProperty("user.dir") + "/apps"; - private final KeyValueStoreLink kvStore; - - - public NativeRemoteFunctionManager(KeyValueStoreLink kvStore) throws NoSuchAlgorithmException { - this.kvStore = kvStore; - md = MessageDigest.getInstance("SHA-1"); - File appDir = new File(this.appDir); - if (!appDir.exists()) { - appDir.mkdirs(); - } - - } - - @Override - public UniqueId registerResource(byte[] resourceZip) { - byte[] digest = Sha1Digestor.digest(resourceZip); - assert (digest.length == UniqueId.LENGTH); - - UniqueId resourceId = new UniqueId(digest); - - // TODO: resources must be saved in persistent store - kvStore.set(resourceId.getBytes(), resourceZip, null); - - return resourceId; - } - - @Override - public byte[] getResource(UniqueId resourceId) { - return kvStore.get(resourceId.getBytes(), null); - } - - @Override - public void unregisterResource(UniqueId resourceId) { - kvStore.delete(resourceId.getBytes(), null); - } - - @Override - public void registerApp(UniqueId driverId, UniqueId resourceId) { - kvStore.set("App2ResMap", resourceId.toString(), driverId.toString()); - } - - @Override - public UniqueId getAppResourceId(UniqueId driverId) { - return UniqueId.fromHexString(kvStore.get("App2ResMap", driverId.toString())); - } - - @Override - public void unregisterApp(UniqueId driverId) { - kvStore.delete("App2ResMap", driverId.toString()); - } - - @Override - public ClassLoader loadResource(UniqueId driverId) { - ClassLoader classLoader = loadedApps.get(driverId); - if (classLoader == null) { - synchronized (this) { - classLoader = loadedApps.get(driverId); - if (classLoader == null) { - classLoader = initLoadedApps(driverId); - } - } - } - return classLoader; - } - - private ClassLoader initLoadedApps(UniqueId driverId) { - try { - RayLog.core.info("initLoadedApps" + driverId.toString()); - - ClassLoader cl = loadedApps.get(driverId); - if (cl == null) { - UniqueId resId = UniqueId.fromHexString(kvStore.get("App2ResMap", driverId.toString())); - byte[] res = getResource(resId); - if (res == null) { - throw new RuntimeException("get resource null, the resId " + resId.toString()); - } - RayLog.core.info("get resource of " + resId.toString() + ", result len " + res.length); - String resPath = - appDir + "/" + driverId.toString() + "/" + String.valueOf(SystemUtil.pid()); - File dir = new File(resPath); - if (!dir.exists()) { - dir.mkdirs(); - } - String zipPath = resPath + ".zip"; - RayLog.rapp.info("unzip app file: zipPath " + zipPath + " resPath " + resPath); - FileUtil.bytesToFile(res, zipPath); - ZipFile zipFile = new ZipFile(zipPath); - zipFile.extractAll(resPath); - cl = JarLoader.loadJars(resPath, false); - loadedApps.put(driverId, cl); - } - return cl; - } catch (Exception e) { - RayLog.rapp.error("load function for " + driverId + " failed, ex = " + e.getMessage(), e); - return null; - } - } - - @Override - public synchronized void unloadFunctions(UniqueId driverId) { - ClassLoader cl = loadedApps.get(driverId); - try { - JarLoader.unloadJars(cl); - } catch (Exception e) { - RayLog.rapp.error("unload function for " + driverId + " failed, ex = " + e.getMessage(), e); - } - } -} \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/NopRemoteFunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/NopRemoteFunctionManager.java deleted file mode 100644 index ca488bac9..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/NopRemoteFunctionManager.java +++ /dev/null @@ -1,58 +0,0 @@ -package org.ray.runtime.functionmanager; - -import org.ray.api.id.UniqueId; - -/** - * mock version of remote function manager using local loaded jars. - */ -public class NopRemoteFunctionManager implements RemoteFunctionManager { - - public NopRemoteFunctionManager(UniqueId driverId) { - //onLoad(driverId, Agent.hookedMethods); - //Agent.consumers.add(m -> { this.onLoad(m); }); - } - - @Override - public UniqueId registerResource(byte[] resourceZip) { - return null; - // nothing to do - } - - @Override - public byte[] getResource(UniqueId resourceId) { - return null; - } - - @Override - public void unregisterResource(UniqueId resourceId) { - // nothing to do - } - - @Override - public void registerApp(UniqueId driverId, UniqueId resourceId) { - // nothing to do - } - - @Override - public UniqueId getAppResourceId(UniqueId driverId) { - return null; - // nothing to do - } - - @Override - public void unregisterApp(UniqueId driverId) { - // nothing to do - } - - @Override - public ClassLoader loadResource(UniqueId driverId) { - //assert (startupDriverId().equals(driverId)); - return this.getClass().getClassLoader(); - } - - @Override - public void unloadFunctions(UniqueId driverId) { - // never - //assert (startupDriverId().equals(driverId)); - } -} \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayActorMethods.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayActorMethods.java deleted file mode 100644 index ee8fa35bd..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayActorMethods.java +++ /dev/null @@ -1,68 +0,0 @@ -package org.ray.runtime.functionmanager; - -import com.google.common.base.Preconditions; -import java.lang.reflect.Executable; -import java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.ray.api.annotation.RayRemote; -import org.ray.api.id.UniqueId; - - -public final class RayActorMethods { - - public final Class clazz; - public final RayRemote remoteAnnotation; - public final Map functions; - /** - * the static function in Actor, call as task. - */ - public final Map staticFunctions; - - private RayActorMethods(Class clazz, RayRemote remoteAnnotation, - Map functions, Map staticFunctions) { - this.clazz = clazz; - this.remoteAnnotation = remoteAnnotation; - this.functions = Collections.unmodifiableMap(new HashMap<>(functions)); - this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions)); - } - - public static RayActorMethods fromClass(String className, ClassLoader classLoader) { - try { - Class clazz = Class.forName(className, true, classLoader); - RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class); - Preconditions.checkNotNull(remoteAnnotation, - "%s must be annotated with @RayRemote", className); - - List executables = new ArrayList<>(Arrays.asList(clazz.getDeclaredMethods())); - - Map functions = new HashMap<>(); - Map staticFunctions = new HashMap<>(); - - for (Executable e : executables) { - RayMethod rayMethod = RayMethod.from(e, remoteAnnotation); - if (Modifier.isStatic(e.getModifiers())) { - staticFunctions.put(rayMethod.getFuncId(), rayMethod); - } else { - functions.put(rayMethod.getFuncId(), rayMethod); - } - } - return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions); - } catch (Exception e) { - throw new RuntimeException("failed to get RayActorMethods from " + className, e); - } - } - - @Override - public String toString() { - return String - .format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(), - functions.values(), - staticFunctions.size(), staticFunctions.values()); - } - -} \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java new file mode 100644 index 000000000..3d0704c6b --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java @@ -0,0 +1,74 @@ +package org.ray.runtime.functionmanager; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import org.ray.api.annotation.RayRemote; + +/** + * Represents a Ray function (either a Method or a Constructor in Java) and its metadata. + */ +public class RayFunction { + + /** + * The executor object, can be either a Method or a Constructor. + */ + public final Executable executable; + + /** + * This function's class loader. + */ + public final ClassLoader classLoader; + + /** + * Function's metadata. + */ + public final FunctionDescriptor functionDescriptor; + + public RayFunction(Executable executable, ClassLoader classLoader, + FunctionDescriptor functionDescriptor) { + this.executable = executable; + this.classLoader = classLoader; + this.functionDescriptor = functionDescriptor; + } + + /** + * @return True if it's a constructor, otherwise it's a method. + */ + public boolean isConstructor() { + return executable instanceof Constructor; + } + + /** + * @return The underlying constructor object. + */ + public Constructor getConstructor() { + return (Constructor) executable; + } + + /** + * @return The underlying method object. + */ + public Method getMethod() { + return (Method) executable; + } + + public FunctionDescriptor getFunctionDescriptor() { + return functionDescriptor; + } + + public RayRemote getRayRemoteAnnotation() { + RayRemote rayRemote = executable.getAnnotation(RayRemote.class); + if (rayRemote == null) { + // If the method doesn't have a annotation, get the annotation from + // its wrapping class. + rayRemote = executable.getDeclaringClass().getAnnotation(RayRemote.class); + } + return rayRemote; + } + + @Override + public String toString() { + return executable.toString(); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayMethod.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayMethod.java deleted file mode 100644 index af935eaaf..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayMethod.java +++ /dev/null @@ -1,57 +0,0 @@ -package org.ray.runtime.functionmanager; - -import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; -import java.lang.reflect.Method; -import org.ray.api.annotation.RayRemote; -import org.ray.api.id.UniqueId; -import org.ray.runtime.util.MethodId; - -/** - * method info. - */ -public class RayMethod { - - public final Executable invokable; - public final String fullName; - public final RayRemote remoteAnnotation; - private final UniqueId funcId; - - private RayMethod(Executable e, RayRemote remoteAnnotation, UniqueId funcId) { - this.invokable = e; - this.remoteAnnotation = remoteAnnotation; - this.funcId = funcId; - fullName = e.getDeclaringClass().getName() + "." + e.getName(); - } - - public static RayMethod from(Executable e, RayRemote parentRemoteAnnotation) { - RayRemote remoteAnnotation = e.getAnnotation(RayRemote.class); - MethodId mid = MethodId.fromExecutable(e); - UniqueId funcId = new UniqueId(mid.getSha1Hash()); - RayMethod method = new RayMethod(e, - remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation, - funcId); - return method; - } - - public boolean isConstructor() { - return invokable instanceof Constructor; - } - - public Constructor getConstructor() { - return (Constructor) invokable; - } - - public Method getMethod() { - return (Method) invokable; - } - - @Override - public String toString() { - return fullName; - } - - public UniqueId getFuncId() { - return funcId; - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayTaskMethods.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayTaskMethods.java deleted file mode 100644 index 43c9e8bd0..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayTaskMethods.java +++ /dev/null @@ -1,56 +0,0 @@ -package org.ray.runtime.functionmanager; - -import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; -import java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.ray.api.id.UniqueId; - - -public final class RayTaskMethods { - - public final Class clazz; - public final Map functions; - - public RayTaskMethods(Class clazz, - Map functions) { - this.clazz = clazz; - this.functions = Collections.unmodifiableMap(new HashMap<>(functions)); - } - - public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) { - try { - Class clazz = Class.forName(clazzName, true, classLoader); - List executables = new ArrayList<>(); - executables.addAll(Arrays.asList(clazz.getDeclaredMethods())); - executables.addAll(Arrays.asList(clazz.getConstructors())); - Map functions = new HashMap<>(executables.size()); - - for (Executable e : executables) { - // This executable must be either a constructor or a static method. - if (!(e instanceof Constructor) - && !Modifier.isStatic(e.getModifiers())) { - continue; - } - e.setAccessible(true); - RayMethod rayMethod = RayMethod.from(e, null); - functions.put(rayMethod.getFuncId(), rayMethod); - } - return new RayTaskMethods(clazz, functions); - } catch (Exception e) { - throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e); - } - } - - @Override - public String toString() { - return String - .format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values()); - } - -} \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RemoteFunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RemoteFunctionManager.java deleted file mode 100644 index 66d99d9f2..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RemoteFunctionManager.java +++ /dev/null @@ -1,64 +0,0 @@ -package org.ray.runtime.functionmanager; - -import org.ray.api.id.UniqueId; - -/** - * register and load functions from function table. - */ -public interface RemoteFunctionManager { - - /* - * register mapping, and upload resource. - * this function is invoked by app proxy or other stand-alone tools it should detect for - * duplication first though - * - * @param resourceZip a directory zip from @JarRewriter - * @return SHA-1 hash of the content - */ - UniqueId registerResource(byte[] resourceZip); - - /** - * download resource content. - * - * @return resource content - */ - byte[] getResource(UniqueId resourceId); - - /** - * remove resource by its hash id - * be careful of invoking this function to make sure it is no longer used. - * - * @param resourceId SHA-1 hash of the resource zip bytes - */ - void unregisterResource(UniqueId resourceId); - - /* - * register the mapping to repo, - * this function is invoked by whoever initiates the driver id - */ - void registerApp(UniqueId driverId, UniqueId resourceId); - - /** - * get the resourceId of one app. - * - * @return resourceId of the app driver - */ - UniqueId getAppResourceId(UniqueId driverId); - - /* - * unregister mapping - * this function is called when the driver exits or detected dead - */ - void unregisterApp(UniqueId driverId); - - /** - * load resource. - */ - ClassLoader loadResource(UniqueId driverId); - - /** - * unload functions for this driver - * this function is used by the workers on demand when a driver is dead. - */ - void unloadFunctions(UniqueId driverId); -} \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/Arg.java b/java/runtime/src/main/java/org/ray/runtime/generated/Arg.java index 927ea2941..158321052 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/Arg.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/Arg.java @@ -1,10 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - package org.ray.runtime.generated; -import java.nio.*; -import java.lang.*; -import com.google.flatbuffers.*; +import com.google.flatbuffers.FlatBufferBuilder; +import com.google.flatbuffers.Table; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; @SuppressWarnings("unused") public final class Arg extends Table { diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java index 00377426f..8c0512afb 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java @@ -1,10 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - package org.ray.runtime.generated; -import java.nio.*; -import java.lang.*; -import com.google.flatbuffers.*; +import com.google.flatbuffers.FlatBufferBuilder; +import com.google.flatbuffers.Table; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; @SuppressWarnings("unused") public final class TaskInfo extends Table { @@ -49,6 +49,8 @@ public final class TaskInfo extends Table { public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; } public int language() { int o = __offset(32); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public String functionDescriptor(int j) { int o = __offset(34); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int functionDescriptorLength() { int o = __offset(34); return o != 0 ? __vector_len(o) : 0; } public static int createTaskInfo(FlatBufferBuilder builder, int driver_idOffset, @@ -65,8 +67,10 @@ public final class TaskInfo extends Table { int argsOffset, int returnsOffset, int required_resourcesOffset, - int language) { - builder.startObject(15); + int language, + int function_descriptorOffset) { + builder.startObject(16); + TaskInfo.addFunctionDescriptor(builder, function_descriptorOffset); TaskInfo.addLanguage(builder, language); TaskInfo.addRequiredResources(builder, required_resourcesOffset); TaskInfo.addReturns(builder, returnsOffset); @@ -85,7 +89,7 @@ public final class TaskInfo extends Table { return TaskInfo.endTaskInfo(builder); } - public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(15); } + public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(16); } public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); } public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); } public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); } @@ -107,6 +111,9 @@ public final class TaskInfo extends Table { public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(14, language, 0); } + public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(15, functionDescriptorOffset, 0); } + public static int createFunctionDescriptorVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startFunctionDescriptorVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endTaskInfo(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 389f43c44..95a8abdf4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -8,7 +8,6 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; -import org.ray.runtime.functionmanager.LocalFunctionManager; import org.ray.runtime.objectstore.MockObjectStore; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskSpec; @@ -20,7 +19,6 @@ public class MockRayletClient implements RayletClient { private final Map> waitTasks = new ConcurrentHashMap<>(); private final MockObjectStore store; - private LocalFunctionManager functions = null; private final RayDevRuntime runtime; public MockRayletClient(RayDevRuntime runtime, MockObjectStore store) { @@ -29,10 +27,6 @@ public class MockRayletClient implements RayletClient { store.registerScheduler(this); } - public void setLocalFunctionManager(LocalFunctionManager mgr) { - functions = mgr; - } - public void onObjectPut(UniqueId id) { Map bucket = waitTasks.get(id); if (bucket != null) { diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 1ba2277c8..1a78f22de 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -1,15 +1,17 @@ package org.ray.runtime.raylet; +import com.google.common.base.Preconditions; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.id.UniqueId; -import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.generated.Arg; import org.ray.runtime.generated.ResourcePair; import org.ray.runtime.generated.TaskInfo; @@ -21,12 +23,18 @@ import org.ray.runtime.util.logger.RayLog; public class RayletClientImpl implements RayletClient { - private static ThreadLocal _taskBuffer = ThreadLocal.withInitial(() -> { - ByteBuffer bb = ByteBuffer - .allocateDirect(AbstractRayRuntime.getParams().max_submit_task_buffer_size_bytes); - bb.order(ByteOrder.LITTLE_ENDIAN); - return bb; - }); + private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024; + + /** + * Direct buffers that are used to hold flatbuffer-serialized task specs. + */ + private static ThreadLocal taskSpecBuffer = ThreadLocal.withInitial(() -> + ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN) + ); + + /** + * Point to c++'s local scheduler client. + */ private long client = 0; public RayletClientImpl(String schedulerSockName, UniqueId clientId, @@ -58,15 +66,14 @@ public class RayletClientImpl implements RayletClient { } @Override - public void submitTask(TaskSpec task) { - RayLog.core.debug("Submitting task: {}", task); - - ByteBuffer info = taskSpec2Info(task); + public void submitTask(TaskSpec spec) { + RayLog.core.debug("Submitting task: {}", spec); + ByteBuffer info = convertTaskSpecToFlatbuffer(spec); byte[] cursorId = null; - if (!task.actorId.isNil()) { - cursorId = task.cursorId.getBytes(); + if (!spec.getExecutionDependencies().isEmpty()) { + //TODO(hchen): handle more than one dependencies. + cursorId = spec.getExecutionDependencies().get(0).getBytes(); } - nativeSubmitTask(client, cursorId, info, info.position(), info.remaining()); } @@ -75,7 +82,7 @@ public class RayletClientImpl implements RayletClient { byte[] bytes = nativeGetTask(client); assert (null != bytes); ByteBuffer bb = ByteBuffer.wrap(bytes); - return taskInfo2Spec(bb); + return parseTaskSpecFromFlatbuffer(bb); } @Override @@ -104,105 +111,93 @@ public class RayletClientImpl implements RayletClient { nativeFreePlasmaObjects(client, objectIdsArray, localOnly); } - public static TaskSpec taskInfo2Spec(ByteBuffer bb) { + private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { bb.order(ByteOrder.LITTLE_ENDIAN); TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); - - TaskSpec spec = new TaskSpec(); - spec.driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer()); - spec.taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer()); - spec.parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); - spec.parentCounter = info.parentCounter(); - spec.actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer()); - spec.actorCounter = info.actorCounter(); - spec.createActorId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); - - spec.functionId = UniqueId.fromByteBuffer(info.functionIdAsByteBuffer()); - - List args = new ArrayList<>(); + UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer()); + UniqueId taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer()); + UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); + int parentCounter = info.parentCounter(); + UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); + UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer()); + UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer()); + int actorCounter = info.actorCounter(); + // Deserialize args + FunctionArg[] args = new FunctionArg[info.argsLength()]; for (int i = 0; i < info.argsLength(); i++) { - UniqueId id = null; - byte[] data = null; - Arg sarg = info.args(i); - - int idCount = sarg.objectIdsLength(); - if (idCount > 0) { - ByteBuffer lbb = sarg.objectIdAsByteBuffer(0); - assert (lbb != null && lbb.remaining() > 0); - id = UniqueId.fromByteBuffer(lbb); - } - - ByteBuffer lbb = sarg.dataAsByteBuffer(); - if (lbb != null && lbb.remaining() > 0) { - // TODO: how to avoid memory copy - data = new byte[lbb.remaining()]; + Arg arg = info.args(i); + if (arg.objectIdsLength() > 0) { + Preconditions.checkArgument(arg.objectIdsLength() == 1, + "This arg has more than one id: {}", arg.objectIdsLength()); + UniqueId id = UniqueId.fromByteBuffer(arg.objectIdAsByteBuffer(0)); + args[i] = FunctionArg.passByReference(id); + } else { + ByteBuffer lbb = arg.dataAsByteBuffer(); + Preconditions.checkState(lbb != null && lbb.remaining() > 0); + byte[] data = new byte[lbb.remaining()]; lbb.get(data); + args[i] = FunctionArg.passByValue(data); } - - args.add(new FunctionArg(id, data)); } - spec.args = args.toArray(new FunctionArg[0]); - - List rids = new ArrayList<>(); + // Deserialize return ids + UniqueId[] returnIds = new UniqueId[info.returnsLength()]; for (int i = 0; i < info.returnsLength(); i++) { - ByteBuffer lbb = info.returnsAsByteBuffer(i); - assert (lbb != null && lbb.remaining() > 0); - rids.add(UniqueId.fromByteBuffer(lbb)); + returnIds[i] = UniqueId.fromByteBuffer(info.returnsAsByteBuffer(i)); } - spec.returnIds = rids.toArray(new UniqueId[0]); - - return spec; + // Deserialize required resources; + Map resources = new HashMap<>(); + for (int i = 0; i < info.requiredResourcesLength(); i++) { + resources.put(info.requiredResources(i).key(), info.requiredResources(i).value()); + } + // Deserialize function descriptor + Preconditions.checkArgument(info.functionDescriptorLength() == 3); + FunctionDescriptor functionDescriptor = new FunctionDescriptor( + info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) + ); + return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, actorId, + actorHandleId, actorCounter, args, returnIds, resources, functionDescriptor); } - public static ByteBuffer taskSpec2Info(TaskSpec task) { - ByteBuffer bb = _taskBuffer.get(); + private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { + ByteBuffer bb = taskSpecBuffer.get(); bb.clear(); FlatBufferBuilder fbb = new FlatBufferBuilder(bb); - final int driverIdOffset = fbb.createString(task.driverId.toByteBuffer()); final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer()); final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer()); final int parentCounter = task.parentCounter; - final int actorCreateIdOffset = fbb.createString(task.createActorId.toByteBuffer()); + final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer()); final int actorCreateDummyIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer()); final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); final int actorCounter = task.actorCounter; - final int functionIdOffset = fbb.createString(task.functionId.toByteBuffer()); - - // serialize args + final int functionIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer()); + // Serialize args int[] argsOffsets = new int[task.args.length]; for (int i = 0; i < argsOffsets.length; i++) { - int objectIdOffset = 0; int dataOffset = 0; if (task.args[i].id != null) { - int[] idOffsets = new int[] { - fbb.createString(task.args[i].id.toByteBuffer()) - }; + int[] idOffsets = new int[]{fbb.createString(task.args[i].id.toByteBuffer())}; objectIdOffset = fbb.createVectorOfTables(idOffsets); } else { objectIdOffset = fbb.createVectorOfTables(new int[0]); } - if (task.args[i].data != null) { dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data)); } - argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset); } int argsOffset = fbb.createVectorOfTables(argsOffsets); - - // serialize returns + // Serialize returns int returnCount = task.returnIds.length; int[] returnsOffsets = new int[returnCount]; for (int k = 0; k < returnCount; k++) { returnsOffsets[k] = fbb.createString(task.returnIds[k].toByteBuffer()); } int returnsOffset = fbb.createVectorOfTables(returnsOffsets); - - // serialize required resources + // Serialize required resources // The required_resources vector indicates the quantities of the different // resources required by this task. The index in this vector corresponds to // the resource type defined in the ResourceIndex enum. For example, @@ -210,12 +205,16 @@ public class RayletClientImpl implements RayletClient { int i = 0; for (Map.Entry entry : task.resources.entrySet()) { int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes())); - requiredResourcesOffsets[i] = + requiredResourcesOffsets[i++] = ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue()); - i++; } - int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets); + int[] functionDescriptorOffsets = new int[]{ + fbb.createString(task.functionDescriptor.className), + fbb.createString(task.functionDescriptor.name), + fbb.createString(task.functionDescriptor.typeDescriptor) + }; + int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); int root = TaskInfo.createTaskInfo( fbb, driverIdOffset, taskIdOffset, @@ -223,18 +222,17 @@ public class RayletClientImpl implements RayletClient { actorCreateIdOffset, actorCreateDummyIdOffset, actorIdOffset, actorHandleIdOffset, actorCounter, false, functionIdOffset, - argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA); - + argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA, + functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); - if (buffer.remaining() > AbstractRayRuntime.getParams().max_submit_task_buffer_size_bytes) { + if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) { RayLog.core.error( - "Allocated buffer is not enough to transfer the task specification: " + AbstractRayRuntime - .getParams().max_submit_task_buffer_size_bytes + " vs " + buffer.remaining()); + "Allocated buffer is not enough to transfer the task specification: " + + TASK_SPEC_BUFFER_SIZE + " vs " + buffer.remaining()); assert (false); } - return buffer; } @@ -251,7 +249,6 @@ public class RayletClientImpl implements RayletClient { nativeDestroy(client); } - /// Native method declarations. /// /// If you change the signature of any native methods, please re-generate diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 7ff5a95a3..58dc3d803 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -1,5 +1,7 @@ package org.ray.runtime.task; +import java.util.ArrayList; +import java.util.List; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; @@ -31,10 +33,13 @@ public class ArgumentsBuilder { } else if (checkSimpleValue(arg)) { data = Serializer.encode(arg); } else { - RayObject obj = Ray.put(arg); - id = obj.getId(); + id = Ray.put(arg).getId(); + } + if (id != null) { + ret[i] = FunctionArg.passByReference(id); + } else { + ret[i] = FunctionArg.passByValue(data); } - ret[i] = new FunctionArg(id, data); } return ret; } @@ -43,19 +48,24 @@ public class ArgumentsBuilder { * Convert task spec arguments to real function arguments. */ public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) { - // Ignore the last arg, which is the class name - Object[] realArgs = new Object[task.args.length - 1]; - for (int i = 0; i < task.args.length - 1; i++) { + Object[] realArgs = new Object[task.args.length]; + List idsToFetch = new ArrayList<>(); + List indices = new ArrayList<>(); + for (int i = 0; i < task.args.length; i++) { FunctionArg arg = task.args[i]; - if (arg.id == null) { - // pass by value - Object obj = Serializer.decode(arg.data, classLoader); - realArgs[i] = obj; - } else if (arg.data == null) { + if (arg.id != null) { // pass by reference - realArgs[i] = Ray.get(arg.id); + idsToFetch.add(arg.id); + indices.add(i); + } else { + // pass by value + realArgs[i] = Serializer.decode(arg.data, classLoader); } } + List objects = Ray.get(idsToFetch); + for (int i = 0; i < objects.size(); i++) { + realArgs[indices.get(i)] = objects.get(i); + } return realArgs; } } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java index e042c4e23..9d7502bca 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java @@ -3,19 +3,47 @@ package org.ray.runtime.task; import org.ray.api.id.UniqueId; /** - * Represents arguments for ray function calls. + * Represents a function argument in task spec. + * + * Either `id` or `data` should be null, when id is not null, this argument will be + * passed by reference, otherwise it will be passed by value. */ public class FunctionArg { + /** + * The id of this argument (passed by reference). + */ public final UniqueId id; + /** + * Serialized data of this argument (passed by value). + */ public final byte[] data; - public FunctionArg(UniqueId id, byte[] data) { + private FunctionArg(UniqueId id, byte[] data) { this.id = id; this.data = data; } - public void toString(StringBuilder builder) { - builder.append("ids: ").append(id).append(", ").append(":").append(data); + /** + * Create a FunctionArg that will be passed by reference. + */ + public static FunctionArg passByReference(UniqueId id) { + return new FunctionArg(id, null); + } + + /** + * Create a FunctionArg that will be passed by value. + */ + public static FunctionArg passByValue(byte[] data) { + return new FunctionArg(null, data); + } + + @Override + public String toString() { + if (id != null) { + return ": " + id.toString(); + } else { + return ": " + data.length; + } } } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 7bfcbd557..864be3754 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -1,8 +1,11 @@ package org.ray.runtime.task; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Map; import org.ray.api.id.UniqueId; +import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.util.ResourceUtil; /** @@ -11,102 +14,91 @@ import org.ray.runtime.util.ResourceUtil; public class TaskSpec { // ID of the driver that created this task. - public UniqueId driverId; + public final UniqueId driverId; // Task ID of the task. - public UniqueId taskId; + public final UniqueId taskId; // Task ID of the parent task. - public UniqueId parentTaskId; + public final UniqueId parentTaskId; // A count of the number of tasks submitted by the parent task before this one. - public int parentCounter; + public final int parentCounter; + + // Id for createActor a target actor + public final UniqueId actorCreationId; // Actor ID of the task. This is the actor that this task is executed on // or NIL_ACTOR_ID if the task is just a normal task. - public UniqueId actorId; - - // Number of tasks that have been submitted to this actor so far. - public int actorCounter; - - // Function ID of the task. - public UniqueId functionId; - - // Task arguments. - public FunctionArg[] args; - - // return ids - public UniqueId[] returnIds; + public final UniqueId actorId; // ID per actor client for session consistency - public UniqueId actorHandleId; + public final UniqueId actorHandleId; - // Id for createActor a target actor - public UniqueId createActorId; + // Number of tasks that have been submitted to this actor so far. + public final int actorCounter; + + // Task arguments. + public final FunctionArg[] args; + + // return ids + public final UniqueId[] returnIds; // The task's resource demands. - public Map resources; + public final Map resources; - public UniqueId cursorId; + // Function descriptor is a list of strings that can uniquely identify a function. + // It will be sent to worker and used to load the target callable function. + public final FunctionDescriptor functionDescriptor; - public TaskSpec() {} - - public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter, - UniqueId actorId, int actorCounter, UniqueId functionId, FunctionArg[] args, - UniqueId[] returnIds, UniqueId actorHandleId, UniqueId createActorId, - Map resources, UniqueId cursorId) { - this.driverId = driverId; - this.taskId = taskId; - this.parentTaskId = parentTaskId; - this.parentCounter = parentCounter; - this.actorId = actorId; - this.actorCounter = actorCounter; - this.functionId = functionId; - this.args = args; - this.returnIds = returnIds; - this.actorHandleId = actorHandleId; - this.createActorId = createActorId; - this.resources = resources; - this.cursorId = cursorId; - - if (!this.resources.containsKey(ResourceUtil.CPU_LITERAL)) { - this.resources.put(ResourceUtil.CPU_LITERAL, 0.0); - } - - if (!this.resources.containsKey(ResourceUtil.GPU_LITERAL)) { - this.resources.put(ResourceUtil.GPU_LITERAL, 0.0); - } - } - - @Override - public String toString() { - StringBuilder builder = new StringBuilder(); - builder.append("\ttaskId: ").append(taskId).append("\n"); - builder.append("\tdriverId: ").append(driverId).append("\n"); - builder.append("\tparentCounter: ").append(parentCounter).append("\n"); - builder.append("\tactorId: ").append(actorId).append("\n"); - builder.append("\tactorCounter: ").append(actorCounter).append("\n"); - builder.append("\tfunctionId: ").append(functionId).append("\n"); - builder.append("\treturnIds: ").append(Arrays.toString(returnIds)).append("\n"); - builder.append("\tactorHandleId: ").append(actorHandleId).append("\n"); - builder.append("\tcreateActorId: ").append(createActorId).append("\n"); - builder.append("\tresources: ") - .append(ResourceUtil.getResourcesFromatStringFromMap(resources)).append("\n"); - builder.append("\tcursorId: ").append(cursorId).append("\n"); - builder.append("\targs:\n"); - for (FunctionArg arg : args) { - builder.append("\t\t"); - arg.toString(builder); - builder.append("\n"); - } - return builder.toString(); - } + private List executionDependencies; public boolean isActorTask() { return !actorId.isNil(); } public boolean isActorCreationTask() { - return !createActorId.isNil(); + return !actorCreationId.isNil(); + } + + public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter, + UniqueId actorCreationId, UniqueId actorId, UniqueId actorHandleId, int actorCounter, + FunctionArg[] args, UniqueId[] returnIds, + Map resources, FunctionDescriptor functionDescriptor) { + this.driverId = driverId; + this.taskId = taskId; + this.parentTaskId = parentTaskId; + this.parentCounter = parentCounter; + this.actorCreationId = actorCreationId; + this.actorId = actorId; + this.actorHandleId = actorHandleId; + this.actorCounter = actorCounter; + this.args = args; + this.returnIds = returnIds; + this.resources = resources; + this.functionDescriptor = functionDescriptor; + this.executionDependencies = new ArrayList<>(); + } + + public List getExecutionDependencies() { + return executionDependencies; + } + + @Override + public String toString() { + return "TaskSpec{" + + "driverId=" + driverId + + ", taskId=" + taskId + + ", parentTaskId=" + parentTaskId + + ", parentCounter=" + parentCounter + + ", actorCreationId=" + actorCreationId + + ", actorId=" + actorId + + ", actorHandleId=" + actorHandleId + + ", actorCounter=" + actorCounter + + ", args=" + Arrays.toString(args) + + ", returnIds=" + Arrays.toString(returnIds) + + ", resources=" + ResourceUtil.getResourcesStringFromMap(resources) + + ", functionDescriptor=" + functionDescriptor + + '}'; } } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/MethodId.java b/java/runtime/src/main/java/org/ray/runtime/util/MethodId.java deleted file mode 100644 index f13eb189d..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/util/MethodId.java +++ /dev/null @@ -1,215 +0,0 @@ -package org.ray.runtime.util; - -import com.google.common.base.Preconditions; -import java.io.Serializable; -import java.lang.invoke.MethodHandleInfo; -import java.lang.invoke.SerializedLambda; -import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.WeakHashMap; -import org.objectweb.asm.Type; -import org.ray.runtime.util.logger.RayLog; - - -/** - * An instance of RayFunc is a lambda. - * MethodId describe the information of the called function in lambda.
- * e.g. Ray.call(Foo::foo), the MethodId of the lambda Foo::foo is:
- * MethodId.className = Foo
- * MethodId.methodName = foo
- * MethodId.methodDesc = describe the types of args and return. - * see org.objectweb.asm.Type.getDescriptor. - */ -public final class MethodId { - - /** - * use ThreadLocal to avoid lock. - * A cache from the lambda instances to MethodId. - * Note: the lambda instances are dynamically created per call site, - * we use WeakHashMap to avoid OOM. - */ - private static final ThreadLocal, MethodId>> - CACHE = ThreadLocal.withInitial(() -> new WeakHashMap<>()); - - public final String className; - public final String methodName; - - public final String methodDesc; - public final boolean isStatic; - /** - * encode the className,methodName,methodDesc,isStatic as an uniquel id. - */ - private final String encoding; - - /** - * sha1 from the encoding, used as functionId. - */ - private final byte[] digest; - - public MethodId(String className, String methodName, String methodDesc, boolean isStatic) { - this.className = className; - this.methodName = methodName; - this.methodDesc = methodDesc; - this.isStatic = isStatic; - this.encoding = encode(className, methodName, methodDesc, isStatic); - this.digest = getSha1Hash0(); - } - - private static String encode(String className, String methodName, String methodDesc, - boolean isStatic) { - StringBuilder sb = new StringBuilder(512); - sb.append(className).append('/').append(methodName).append("::").append(methodDesc).append("&&") - .append(isStatic); - return sb.toString(); - } - - public static MethodId fromExecutable(Executable method) { - final boolean isStatic = Modifier.isStatic(method.getModifiers()); - final String className = method.getDeclaringClass().getName(); - final String methodName = method instanceof Method - ? method.getName() : ""; - final Type type = method instanceof Method - ? Type.getType((Method) method) : Type.getType((Constructor) method); - final String methodDesc = type.getDescriptor(); - return new MethodId(className, methodName, methodDesc, isStatic); - } - - public static MethodId fromSerializedLambda(Serializable serial) { - return fromSerializedLambda(serial, false); - } - - public static MethodId fromSerializedLambda(Serializable serial, boolean forceNew) { - Preconditions.checkArgument(!(serial instanceof SerializedLambda), "arg could not be " - + "SerializedLambda"); - Class clazz = (Class) serial.getClass(); - WeakHashMap, MethodId> map = CACHE.get(); - MethodId id = map.get(clazz); - if (id == null || forceNew) { - final SerializedLambda lambda = LambdaUtils.getSerializedLambda(serial); - Preconditions.checkArgument(lambda.getCapturedArgCount() == 0, "could not transfer a lambda " - + "which is closure"); - final boolean isStatic = lambda.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic; - final String className = lambda.getImplClass().replace('/', '.'); - id = new MethodId(className, lambda.getImplMethodName(), - lambda.getImplMethodSignature(), isStatic); - if (!forceNew) { - map.put(clazz, id); - } - } - return id; - } - - public Method load() { - return load(null); - } - - public Method load(ClassLoader loader) { - Class cls = null; - try { - RayLog.core.debug( - "load class " + className + " from class loader " + (loader == null ? this.getClass() - .getClassLoader() : loader) - + " for method " + toString() + " with ID = " + toHexHashString() - ); - cls = Class - .forName(className, true, loader == null ? this.getClass().getClassLoader() : loader); - } catch (Throwable e) { - RayLog.core.error("Cannot load class {}", className, e); - return null; - } - - Method[] ms = cls.getDeclaredMethods(); - ArrayList methods = new ArrayList<>(); - Type t = Type.getMethodType(this.methodDesc); - Type[] params = t.getArgumentTypes(); - String rt = t.getReturnType().getDescriptor(); - - for (Method m : ms) { - if (m.getName().equals(methodName)) { - if (!Arrays.equals(params, Type.getArgumentTypes(m))) { - continue; - } - - String mrt = Type.getDescriptor(m.getReturnType()); - if (!rt.equals(mrt)) { - continue; - } - - if (isStatic != Modifier.isStatic(m.getModifiers())) { - continue; - } - - methods.add(m); - } - } - - if (methods.size() != 1) { - RayLog.core.error( - "Load method {} failed as there are {} definitions.", toString(), methods.size()); - return null; - } - - return methods.get(0); - } - - private byte[] getSha1Hash0() { - byte[] digests = Sha1Digestor.digest(encoding); - ByteBuffer bb = ByteBuffer.wrap(digests); - bb.order(ByteOrder.LITTLE_ENDIAN); - if (methodName.contains("createActorStage1")) { - bb.putLong(Long.BYTES, 1); - } else { - bb.putLong(Long.BYTES, 0); - } - return digests; - } - - public byte[] getSha1Hash() { - return digest; - } - - private String toHexHashString() { - byte[] id = this.getSha1Hash(); - return StringUtil.toHexHashString(id); - } - - public String toEncodingString() { - return encoding; - } - - - @Override - public int hashCode() { - return encoding.hashCode(); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - MethodId other = (MethodId) obj; - return className.equals(other.className) - && methodName.equals(other.methodName) - && methodDesc.equals(other.methodDesc) - && isStatic == other.isStatic; - } - - @Override - public String toString() { - return encoding; - } - -} \ No newline at end of file diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java new file mode 100644 index 000000000..85f482544 --- /dev/null +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -0,0 +1,119 @@ +package org.ray.runtime.functionmanager; + +import java.util.Map; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.ray.api.annotation.RayRemote; +import org.ray.api.function.RayFunc0; +import org.ray.api.function.RayFunc1; +import org.ray.api.id.UniqueId; +import org.ray.runtime.functionmanager.FunctionManager.DriverFunctionTable; + +/** + * Tests for {@link FunctionManager} + */ +public class FunctionManagerTest { + + @RayRemote + public static Object foo() { + return null; + } + + @RayRemote + public static class Bar { + + public Bar() { + } + + public Object bar() { + return null; + } + } + + private static RayFunc0 fooFunc; + private static RayFunc1 barFunc; + private static RayFunc0 barConstructor; + private static FunctionDescriptor fooDescriptor; + private static FunctionDescriptor barDescriptor; + private static FunctionDescriptor barConstructorDescriptor; + + private FunctionManager functionManager; + + @BeforeClass + public static void beforeClass() { + fooFunc = FunctionManagerTest::foo; + barConstructor = Bar::new; + barFunc = Bar::bar; + fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo", + "()Ljava/lang/Object;"); + barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar", + "()Ljava/lang/Object;"); + barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(), + FunctionManager.CONSTRUCTOR_NAME, + "()V"); + } + + @Before + public void before() { + functionManager = new FunctionManager(); + } + + @Test + public void testGetFunctionFromRayFunc() { + // Test normal function. + RayFunction func = functionManager.getFunction(UniqueId.NIL, fooFunc); + Assert.assertFalse(func.isConstructor()); + Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); + Assert.assertNotNull(func.getRayRemoteAnnotation()); + + // Test actor method + func = functionManager.getFunction(UniqueId.NIL, barFunc); + Assert.assertFalse(func.isConstructor()); + Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); + Assert.assertNotNull(func.getRayRemoteAnnotation()); + + // Test actor constructor + func = functionManager.getFunction(UniqueId.NIL, barConstructor); + Assert.assertTrue(func.isConstructor()); + Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); + Assert.assertNotNull(func.getRayRemoteAnnotation()); + } + + @Test + public void testGetFunctionFromFunctionDescriptor() { + // Test normal function. + RayFunction func = functionManager.getFunction(UniqueId.NIL, fooDescriptor); + Assert.assertFalse(func.isConstructor()); + Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); + Assert.assertNotNull(func.getRayRemoteAnnotation()); + + // Test actor method + func = functionManager.getFunction(UniqueId.NIL, barDescriptor); + Assert.assertFalse(func.isConstructor()); + Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); + Assert.assertNotNull(func.getRayRemoteAnnotation()); + + // Test actor constructor + func = functionManager.getFunction(UniqueId.NIL, barConstructorDescriptor); + Assert.assertTrue(func.isConstructor()); + Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); + Assert.assertNotNull(func.getRayRemoteAnnotation()); + } + + @Test + public void testLoadFunctionTableForClass() { + DriverFunctionTable functionTable = new DriverFunctionTable(getClass().getClassLoader()); + Map, RayFunction> res = functionTable + .loadFunctionsForClass(Bar.class.getName()); + // The result should 2 entries, one for the constructor, the other for bar. + Assert.assertEquals(res.size(), 2); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(barDescriptor.name, barDescriptor.typeDescriptor))); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor))); + } +} diff --git a/java/test/pom.xml b/java/test/pom.xml index bc5483808..db7fcc140 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -35,8 +35,6 @@ junit junit - 4.11 - diff --git a/java/test/src/main/java/org/ray/api/test/LambdaUtilsTest.java b/java/test/src/main/java/org/ray/api/test/LambdaUtilsTest.java deleted file mode 100644 index 210241ecf..000000000 --- a/java/test/src/main/java/org/ray/api/test/LambdaUtilsTest.java +++ /dev/null @@ -1,211 +0,0 @@ -package org.ray.api.test; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.lang.reflect.Method; -import java.util.concurrent.TimeUnit; -import org.junit.Assert; -import org.junit.Test; -import org.ray.api.function.RayFunc0; -import org.ray.api.function.RayFunc1; -import org.ray.api.function.RayFunc3; -import org.ray.runtime.util.MethodId; -import org.ray.runtime.util.logger.RayLog; - -public class LambdaUtilsTest { - - static final String CLASS_NAME = LambdaUtilsTest.class.getName(); - static final Method CALL0; - static final Method CALL1; - static final Method CALL2; - static final Method CALL3; - - static { - try { - CALL0 = LambdaUtilsTest.class.getDeclaredMethod("call0", new Class[0]); - CALL1 = LambdaUtilsTest.class.getDeclaredMethod("call1", new Class[]{Long.class}); - CALL2 = LambdaUtilsTest.class.getDeclaredMethod("call2", new Class[0]); - CALL3 = LambdaUtilsTest.class - .getDeclaredMethod("call3", new Class[]{Long.class, String.class}); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public static void testRemoteLambdaParse(RayFunc3 f, int n, - boolean forceNew, boolean debug) - throws Exception { - if (debug) { - RayLog.core.info("parse#" + f.getClass().getName()); - } - long start = System.nanoTime(); - for (int i = 0; i < n; i++) { - MethodId mid = MethodId.fromSerializedLambda(f, forceNew); - } - long end = System.nanoTime(); - RayLog.core.info(String.format("remoteLambdaParse(new=%s):total=%sms, one=%s", forceNew, - TimeUnit.NANOSECONDS.toMillis(end - start), - (end - start) / n)); - } - - public static void testRemoteLambdaSerde(RayFunc3 f, int n, - boolean de, boolean debug) - throws Exception { - if (debug) { - RayLog.core.info("se#" + f.getClass().getName()); - } - long start = System.nanoTime(); - for (int i = 0; i < n; i++) { - ByteArrayOutputStream bytes = new ByteArrayOutputStream(1024); - ObjectOutputStream out = new ObjectOutputStream(bytes); - out.writeObject(f); - out.close(); - if (de) { - ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())); - RayFunc3 def = (RayFunc3) in.readObject(); - in.close(); - if (debug) { - RayLog.core.info("de#" + def.getClass().getName()); - } - } - } - long end = System.nanoTime(); - RayLog.core.info( - String.format("remoteLambdaSer(de=%s):total=%sms,one=%s", de, - TimeUnit.NANOSECONDS.toMillis(end - start), - (end - start) / n)); - } - - public static void testCall0(RayFunc0 f) { - MethodId mid = MethodId.fromSerializedLambda(f); - RayLog.core.info(mid.toString()); - Assert.assertEquals(mid.load(), CALL0); - Assert.assertTrue(mid.isStatic); - } - - public static void testCall1(RayFunc1 f, T t) { - MethodId mid = MethodId.fromSerializedLambda(f); - RayLog.core.info(mid.toString()); - Assert.assertEquals(mid.load(), CALL1); - Assert.assertTrue(mid.isStatic); - } - - public static void testCall2(RayFunc1 f) { - MethodId mid = MethodId.fromSerializedLambda(f); - RayLog.core.info(mid.toString()); - Assert.assertEquals(mid.load(), CALL2); - Assert.assertTrue(!mid.isStatic); - } - - public static void testCall3(RayFunc3 f) { - MethodId mid = MethodId.fromSerializedLambda(f); - RayLog.core.info(mid.toString()); - Assert.assertEquals(mid.load(), CALL3); - Assert.assertTrue(!mid.isStatic); - } - - public static String call0() { - long t = System.currentTimeMillis(); - RayLog.core.info("call0:" + t); - return String.valueOf(t); - } - - public static String call1(Long v) { - for (int i = 0; i < 100; i++) { - v += i; - } - RayLog.core.info("call1:" + v); - return String.valueOf(v); - } - - @Test - public void testLambdaSer() throws Exception { - testCall0(LambdaUtilsTest::call0); - testCall1(LambdaUtilsTest::call1, Long.valueOf(System.currentTimeMillis())); - testCall2(LambdaUtilsTest::call2); - testCall3(LambdaUtilsTest::call3); - } - - /** - * to test the serdeLambda's perf. - */ - public void testBenchmark() throws Exception { - //test serde - testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true); - //warmup - - RayLog.core.info("warmup:serde################"); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false); - RayLog.core.info("benchmark:serde################"); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false); - RayLog.core.info("benchmark:ser################"); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false); - - //test serde one new call's time, no class cache - long start = System.nanoTime(); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false); - long end = System.nanoTime(); - RayLog.core.info("one sertime:" + (end - start)); - - //test serde one new call's time, no class cache - start = System.nanoTime(); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false); - end = System.nanoTime(); - RayLog.core.info("one sertime:" + (end - start)); - - //test serde one new call's time, no class cache - start = System.nanoTime(); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false); - end = System.nanoTime(); - RayLog.core.info("one sertime:" + (end - start)); - - //test serde one new call's time, no class cache - start = System.nanoTime(); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false); - end = System.nanoTime(); - RayLog.core.info("one serdetime:" + (end - start)); - - //test serde one new call's time, no class cache - start = System.nanoTime(); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false); - end = System.nanoTime(); - RayLog.core.info("one serdetime:" + (end - start)); - - //test serde one new call's time, no class cache - start = System.nanoTime(); - testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false); - end = System.nanoTime(); - RayLog.core.info("one serdetime:" + (end - start)); - - //test lambda - testRemoteLambdaParse(LambdaUtilsTest::call3, 2, true, true); - testRemoteLambdaParse(LambdaUtilsTest::call3, 2, false, true); - //warmup - RayLog.core.info("warmup:parse################"); - testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false); - testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false); - RayLog.core.info("benchmark:parseNew################"); - testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false); - RayLog.core.info("benchmark:parseCache################"); - testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false); - } - - public String call2() { - long t = System.currentTimeMillis(); - RayLog.core.info("call2:" + t); - return "call2:" + t; - } - - public String call3(Long v, String s) { - for (int i = 0; i < 100; i++) { - v += i; - } - RayLog.core.info("call3:" + v); - return String.valueOf(v); - } -} \ No newline at end of file diff --git a/java/test/src/main/java/org/ray/api/test/MethodIdTest.java b/java/test/src/main/java/org/ray/api/test/MethodIdTest.java deleted file mode 100644 index 81bd24cfb..000000000 --- a/java/test/src/main/java/org/ray/api/test/MethodIdTest.java +++ /dev/null @@ -1,44 +0,0 @@ -package org.ray.api.test; - - -import java.lang.reflect.Executable; -import org.junit.Assert; -import org.junit.Test; -import org.ray.api.function.RayFunc2; -import org.ray.runtime.util.MethodId; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class MethodIdTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(MethodIdTest.class); - - @Test - public void testNormalMethod() throws Exception { - RayFunc2 f = MethodIdTest::foo; - MethodId m1 = MethodId.fromSerializedLambda(f); - Executable e = MethodIdTest.class.getDeclaredMethod("foo", int.class, String.class); - MethodId m2 = MethodId.fromExecutable(e); - LOGGER.info("{}, {}", m1, m2); - Assert.assertEquals(m1, m2); - } - - @Test - public void testConstructor() throws Exception { - RayFunc2 f = Foo::new; - MethodId m1 = MethodId.fromSerializedLambda(f); - Executable e = Foo.class.getConstructor(int.class, String.class); - MethodId m2 = MethodId.fromExecutable(e); - LOGGER.info("{}, {}", m1, m2); - Assert.assertEquals(m1, m2); - } - - public static String foo(int a, String b) { - return a + b; - } - - public static class Foo { - public Foo(int a, String b) {} - } -} - diff --git a/java/test/src/main/java/org/ray/api/test/RayActorMethodsTest.java b/java/test/src/main/java/org/ray/api/test/RayActorMethodsTest.java deleted file mode 100644 index c91501217..000000000 --- a/java/test/src/main/java/org/ray/api/test/RayActorMethodsTest.java +++ /dev/null @@ -1,32 +0,0 @@ -package org.ray.api.test; - -import org.junit.Assert; -import org.junit.Test; -import org.ray.api.annotation.RayRemote; -import org.ray.runtime.functionmanager.RayActorMethods; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class RayActorMethodsTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(RayActorMethodsTest.class); - - @RayRemote - public static class ExampleActor { - - public void func1() {} - - public void func2() {} - - public static void func3() {} - } - - @Test - public void testActorMethods() { - RayActorMethods methods = RayActorMethods - .fromClass(ExampleActor.class.getName(), RayActorMethodsTest.class.getClassLoader()); - LOGGER.info(methods.toString()); - Assert.assertEquals(methods.functions.size(), 2); - Assert.assertEquals(methods.staticFunctions.size(), 1); - } -} diff --git a/java/test/src/main/java/org/ray/api/test/RayTaskMethodsTest.java b/java/test/src/main/java/org/ray/api/test/RayTaskMethodsTest.java deleted file mode 100644 index e731ccb49..000000000 --- a/java/test/src/main/java/org/ray/api/test/RayTaskMethodsTest.java +++ /dev/null @@ -1,43 +0,0 @@ -package org.ray.api.test; - -import org.junit.Assert; -import org.junit.Test; -import org.ray.runtime.functionmanager.RayMethod; -import org.ray.runtime.functionmanager.RayTaskMethods; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -public class RayTaskMethodsTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(RayTaskMethodsTest.class); - - private static class Foo { - - public Foo() {} - - public Foo(int x) {} - - public static void f1() {} - - public void f2() {} - } - - @Test - public void testTask() { - RayTaskMethods methods = RayTaskMethods - .fromClass(Foo.class.getName(), Foo.class.getClassLoader()); - LOGGER.info(methods.toString()); - int numMethods = 0; - int numConstructors = 0; - for (RayMethod m : methods.functions.values()) { - if (m.isConstructor()) { - numConstructors += 1; - } else { - numMethods += 1; - } - } - Assert.assertEquals(numMethods, 1); - Assert.assertEquals(numConstructors, 2); - } -} diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs index e0988472c..9dc9f651a 100644 --- a/src/common/format/common.fbs +++ b/src/common/format/common.fbs @@ -62,6 +62,13 @@ table TaskInfo { required_resources: [ResourcePair]; // The language that this task belongs to language: TaskLanguage; + // Function descriptor, which is a list of strings that can + // uniquely describe a function. + // For a Python function, it should be: [module_name, class_name, function_name] + // For a Java function, it should be: [class_name, method_name, type_descriptor] + // TODO(hchen): after changing Python worker to use function_descriptor, + // function_id can be removed. + function_descriptor: [string]; } // Object information data structure.