[java] support creating an actor with parameters (#2817)

Previously `Ray.createActor` only support creating an actor without any parameter. This PR adds the support for creating an actor with parameters. Moreover, besides using a constructor, it's now also allowed to create an actor with a factory method. For more usage, prefer refer to `ActorTest.java`.
This commit is contained in:
Hao Chen
2018-09-04 00:53:03 +08:00
committed by Robert Nishihara
parent b37a283053
commit 9d655721e5
27 changed files with 1131 additions and 680 deletions
@@ -105,16 +105,6 @@ public final class Ray extends RayCall {
return runtime.wait(waitList, waitList.size(), Integer.MAX_VALUE);
}
/**
* Create an actor on a remote node.
*
* @param actorClass the class of the actor to be created.
* @return A handle to the newly created actor.
*/
public static <T> RayActor<T> createActor(Class<T> actorClass) {
return runtime.createActor(actorClass);
}
/**
* Get the underlying runtime instance.
*/
File diff suppressed because it is too large Load Diff
@@ -52,14 +52,6 @@ public interface RayRuntime {
*/
<T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs);
/**
* Create an actor on a remote node.
*
* @param actorClass the class of the actor to be created.
* @return A handle to the newly created actor.
*/
<T> RayActor<T> createActor(Class<T> actorClass);
/**
* Invoke a remote function.
*
@@ -78,4 +70,14 @@ public interface RayRuntime {
* @return The result object.
*/
RayObject call(RayFunc func, RayActor actor, Object[] args);
/**
* Create an actor on a remote node.
*
* @param actorFactoryFunc A remote function whose return value is the actor object.
* @param args The arguments for the remote function.
* @param <T> The type of the actor object.
* @return A handle to the actor.
*/
<T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args);
}
@@ -4,6 +4,8 @@ import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
@@ -67,13 +69,15 @@ public final class MethodId {
return sb.toString();
}
public static MethodId fromMethod(Method method) {
final boolean isstatic = Modifier.isStatic(method.getModifiers());
public static MethodId fromExecutable(Executable method) {
final boolean isStatic = Modifier.isStatic(method.getModifiers());
final String className = method.getDeclaringClass().getName();
final String methodName = method.getName();
final Type type = Type.getType(method);
final String methodName = method instanceof Method
? method.getName() : "<init>";
final Type type = method instanceof Method
? Type.getType((Method) method) : Type.getType((Constructor) method);
final String methodDesc = type.getDescriptor();
return new MethodId(className, methodName, methodDesc, isstatic);
return new MethodId(className, methodName, methodDesc, isStatic);
}
public static MethodId fromSerializedLambda(Serializable serial) {
@@ -101,7 +105,6 @@ public final class MethodId {
return id;
}
public Method load() {
return load(null);
}
@@ -2,6 +2,7 @@ package org.ray.util;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.annotation.RayRemote;
import org.ray.api.annotation.ResourceItem;
public class ResourceUtil {
@@ -11,17 +12,18 @@ public class ResourceUtil {
/**
* Convert the array that contains resource items to a map.
*
* @param resourceArray The resources list to be converted.
* @param remoteAnnotation The RayRemote annotation that contains the resource items.
* @return The map whose key represents the resource name
* and the value represents the resource quantity.
*/
public static Map<String, Double> getResourcesMapFromArray(ResourceItem[] resourceArray) {
public static Map<String, Double> getResourcesMapFromArray(RayRemote remoteAnnotation) {
Map<String, Double> resourceMap = new HashMap<>();
if (resourceArray != null) {
for (ResourceItem item : resourceArray) {
if (!item.name().isEmpty()) {
resourceMap.put(item.name(), item.value());
}
if (remoteAnnotation == null) {
return resourceMap;
}
for (ResourceItem item : remoteAnnotation.resources()) {
if (!item.name().isEmpty()) {
resourceMap.put(item.name(), item.value());
}
}
@@ -10,6 +10,11 @@ public abstract class BaseGenerator {
sb.append(line).append("\n");
}
protected void newLine(int numIndents, String line) {
indents(numIndents);
newLine(line);
}
protected void indents(int numIndents) {
for (int i = 0; i < numIndents; i++) {
sb.append(" ");
@@ -6,7 +6,8 @@ import java.util.List;
import org.ray.util.FileUtil;
/**
* A util class that generates `RayCall.java`
* A util class that generates `RayCall.java`,
* which provides type-safe interfaces for `Ray.call` and `Ray.createActor`.
*/
public class RayCallGenerator extends BaseGenerator {
@@ -24,64 +25,76 @@ public class RayCallGenerator extends BaseGenerator {
newLine("");
newLine("/**");
newLine(" * This class provides type-safe interfaces for Ray.call.");
newLine(" * This class provides type-safe interfaces for `Ray.call` and `Ray.createActor`.");
newLine(" **/");
newLine("@SuppressWarnings({\"rawtypes\", \"unchecked\"})");
newLine("class RayCall {");
for (int i = 0; i <= 6; i++) {
if (i > 0) {
buildCalls(i, true);
}
buildCalls(i, false);
newLine(1, "// =======================================");
newLine(1, "// Methods for remote function invocation.");
newLine(1, "// =======================================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildCalls(i, false, false);
}
newLine(1, "// ===========================================");
newLine(1, "// Methods for remote actor method invocation.");
newLine(1, "// ===========================================");
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildCalls(i, true, false);
}
newLine(1, "// ===========================");
newLine(1, "// Methods for actor creation.");
newLine(1, "// ===========================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildCalls(i, false, true);
}
newLine("}");
return sb.toString();
}
/**
* Build the `Ray.call` methods for given number of parameters.
* @param numParameters the number of parameters, including the actor parameter.
* Build the `Ray.call` or `Ray.createActor` methods with the given number of parameters.
* @param numParameters the number of parameters
* @param forActor build actor api when true, otherwise build task api.
* @param forActorCreation build `Ray.createActor` when true, otherwise build `Ray.call`.
*/
private void buildCalls(int numParameters, boolean forActor) {
private void buildCalls(int numParameters, boolean forActor, boolean forActorCreation) {
String genericTypes = "";
String argList = "";
for (int i = 0; i < numParameters; i++) {
genericTypes += "T" + i + ", ";
if (!forActor || i > 0) {
argList += "t" + i + ", ";
}
argList += "t" + i + ", ";
}
if (forActor) {
genericTypes = "A, " + genericTypes;
}
genericTypes += forActorCreation ? "A" : "R";
if (argList.endsWith(", ")) {
argList = argList.substring(0, argList.length() - 2);
}
String funcParam = String.format("RayFunc%d<%sR> f%s",
numParameters,
genericTypes,
numParameters > 0 ? ", " : "");
String actorParam = "";
String paramPrefix = String.format("RayFunc%d<%s> f",
!forActor ? numParameters : numParameters + 1,
genericTypes);
if (forActor) {
actorParam = "RayActor<T0> actor";
if (numParameters > 1) {
actorParam += ", ";
}
paramPrefix += ", RayActor<A> actor";
}
if (numParameters > 0) {
paramPrefix += ", ";
}
for (String param : generateParameters(forActor ? 1 : 0, numParameters)) {
String returnType = !forActorCreation ? "RayObject<R>" : "RayActor<A>";
String funcName = !forActorCreation ? "call" : "createActor";
String funcArgs = !forActor ? "f, args" : "f, actor, args";
for (String param : generateParameters(0, numParameters)) {
// method signature
indents(1);
newLine(String.format(
"public static <%sR> RayObject<R> call(%s%s%s) {",
genericTypes, funcParam, actorParam, param
newLine(1, String.format(
"public static <%s> %s %s(%s) {",
genericTypes, returnType, funcName, paramPrefix + param
));
// method body
indents(2);
newLine(String.format("Object[] args = new Object[]{%s};", argList));
indents(2);
newLine(String.format("return Ray.internal().call(f%s, args);", forActor ? ", actor" : ""));
indents(1);
newLine("}");
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s);", funcName, funcArgs));
newLine(1, "}");
}
}
@@ -14,9 +14,7 @@ import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFunc2;
import org.ray.api.id.UniqueId;
import org.ray.api.runtime.RayRuntime;
import org.ray.core.model.RayParameters;
@@ -50,9 +48,9 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected PathConfig pathConfig;
/**
* Actor ID -> actor instance.
* Actor ID -> local actor instance.
*/
private Map<UniqueId, Object> actors = new HashMap<>();
Map<UniqueId, Object> localActors = new HashMap<>();
// app level Ray.init()
// make it private so there is no direct usage but only from Ray.init
@@ -137,7 +135,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
objectStoreProxy = new ObjectStoreProxy(plink, slink);
}
worker = new Worker(localSchedulerClient, functions);
worker = new Worker(this);
}
private static AbstractRayRuntime instantiate(RayParameters params) {
@@ -195,13 +193,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> RayObject<T> put(T obj) {
UniqueId objectId = getCurrentTaskNextPutId();
UniqueId objectId = UniqueIdHelper.computePutId(
WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex());
put(objectId, obj);
return new RayObjectImpl<>(objectId);
}
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = getCurrentTaskId();
UniqueId taskId = WorkerContext.currentTask().taskId;
RayLog.core.info("Putting object {}, for task {} ", objectId, taskId);
if (!params.use_raylet) {
localSchedulerClient.markTaskPutDependency(taskId, objectId);
@@ -209,21 +209,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
objectStoreProxy.put(objectId, obj, null);
}
/**
* get the task identity of the currently running task, UniqueId.NIL if not inside any
*/
public UniqueId getCurrentTaskId() {
return worker.getCurrentTaskId();
}
/**
* get the to-be-returned objects identities of the currently running task, empty array if not
* inside any.
*/
public UniqueId getCurrentTaskNextPutId() {
return worker.getCurrentTaskNextPutId();
}
@Override
public <T> T get(UniqueId objectId) throws TaskExecutionException {
List<T> ret = get(ImmutableList.of(objectId));
@@ -233,7 +218,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
UniqueId taskId = getCurrentTaskId();
UniqueId taskId = WorkerContext.currentTask().taskId;
try {
int numObjectIds = objectIds.size();
@@ -343,7 +328,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public RayObject call(RayFunc func, Object[] args) {
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, null);
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false);
localSchedulerClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@@ -354,7 +339,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
}
RayActorImpl actorImpl = (RayActorImpl)actor;
TaskSpec spec = createTaskSpec(func, actorImpl, args, null);
TaskSpec spec = createTaskSpec(func, actorImpl, args, false);
actorImpl.setTaskCursor(spec.returnIds[1]);
localSchedulerClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
@@ -362,30 +347,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
@SuppressWarnings("unchecked")
public <T> RayActor<T> createActor(Class<T> actorClass) {
RayFunc2<UniqueId, String, Object> func = AbstractRayRuntime::createLocalActor;
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, null, actorClass);
RayActorImpl actor = new RayActorImpl(spec.returnIds[0]);
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args) {
TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, args, true);
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
localSchedulerClient.submitTask(spec);
return actor;
}
@RayRemote
private static Object createLocalActor(UniqueId actorId, String className) {
try {
Class<?> cls = Class.forName(className, true, Thread.currentThread().getContextClassLoader());
Object actor = cls.getConstructor().newInstance();
getInstance().actors.put(actorId, actor);
RayLog.core.info("Created actor: {}, actor id: {}", className, actorId);
return null;
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException
| IllegalArgumentException | InvocationTargetException | NoSuchMethodException
| SecurityException e) {
RayLog.core.error("Failed to create actor {}", className, e);
throw new TaskExecutionException(e);
}
return (RayActor<T>) actor;
}
/**
@@ -404,12 +372,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
* @param func The target remote function.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param actorClassForCreation If the task is a actor creation task, the argument should be
* the actor class. Otherwise, it should be null.
* @param isActorCreationTask Whether this task is an actor creation task.
* @return A TaskSpec object.
*/
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args,
Class actorClassForCreation) {
boolean isActorCreationTask) {
final TaskSpec current = WorkerContext.currentTask();
UniqueId taskId = localSchedulerClient.generateTaskId(current.driverId,
current.taskId,
@@ -418,8 +385,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
UniqueId[] returnIds = genReturnIds(taskId, numReturns);
UniqueId actorCreationId = UniqueId.NIL;
if (actorClassForCreation != null) {
args = new Object[] {returnIds[0], actorClassForCreation.getName()};
if (isActorCreationTask) {
actorCreationId = returnIds[0];
}
@@ -448,24 +414,25 @@ public abstract class AbstractRayRuntime implements RayRuntime {
returnIds,
actor.getHandleId(),
actorCreationId,
ResourceUtil.getResourcesMapFromArray(rayMethod.remoteAnnotation.resources()),
ResourceUtil.getResourcesMapFromArray(rayMethod.remoteAnnotation),
actor.getTaskCursor()
);
}
/***********
* Internal Methods.
***********/
public void loop() {
worker.loop();
}
/**
* get actor with given id.
*/
public Object getLocalActor(UniqueId id) {
return actors.get(id);
public Worker getWorker() {
return worker;
}
}
public LocalSchedulerLink getLocalSchedulerClient() {
return localSchedulerClient;
}
public LocalFunctionManager getLocalFunctionManager() {
return functions;
}
}
@@ -9,12 +9,16 @@ import org.ray.api.id.UniqueId;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.TaskSpec;
/**
* arguments wrap and unwrap.
*/
public class ArgumentsBuilder {
@SuppressWarnings({"rawtypes", "unchecked"})
private static boolean checkSimpleValue(Object o) {
// TODO(raulchen): implement this.
return true;
}
/**
* Convert real function arguments to task spec arguments.
*/
public static FunctionArg[] wrap(Object[] args) {
FunctionArg[] ret = new FunctionArg[args.length];
for (int i = 0; i < ret.length; i++) {
@@ -38,13 +42,11 @@ public class ArgumentsBuilder {
return ret;
}
private static boolean checkSimpleValue(Object o) {
return true;//TODO I think Ray don't want to pass big parameter
}
@SuppressWarnings({"rawtypes", "unchecked"})
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader) {
// the last arg is className
/**
* Convert task spec arguments to real function arguments.
*/
public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) {
// Ignore the last arg, which is the class name
Object[] realArgs = new Object[task.args.length - 1];
for (int i = 0; i < task.args.length - 1; i++) {
FunctionArg arg = task.args[i];
@@ -57,8 +59,6 @@ public class ArgumentsBuilder {
realArgs[i] = Ray.get(arg.id);
}
}
Object actor = task.actorId.isNil()
? null : AbstractRayRuntime.getInstance().getLocalActor(task.actorId);
return Pair.of(actor, realArgs);
return realArgs;
}
}
@@ -1,94 +0,0 @@
package org.ray.core;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.id.UniqueId;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
/**
* how to execute a invocation.
*/
public class InvocationExecutor {
public static void execute(TaskSpec task, Pair<ClassLoader, RayMethod> pr)
throws TaskExecutionException {
String taskdesc =
"[" + pr.getRight().fullName + "_" + task.taskId.toString() + " actorId = " + task.actorId
+ "]";
TaskExecutionException ex = null;
// switch to current driver's loader
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
if (pr.getLeft() != null) {
Thread.currentThread().setContextClassLoader(pr.getLeft());
}
// execute
try {
executeInternal(task, pr);
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
if (!task.actorId.isNil()
&& AbstractRayRuntime.getInstance().getLocalActor(task.actorId) == null) {
ex = new TaskExecutionException("Task " + taskdesc + " execution on actor " + task.actorId
+ " failed as the actor is not present ", e);
RayLog.core.error("Task " + taskdesc + " execution on actor " + task.actorId
+ " failed as the actor is not present ", e);
} else {
ex = new TaskExecutionException(
formatTaskExecutionExceptionMsg(task, pr.getRight().fullName), e);
RayLog.core.error("Task " + taskdesc + " execution failed ", e);
}
RayLog.core.error(e.getMessage());
RayLog.core.error("task info: \n" + task.toString());
} catch (Throwable e) {
ex = new TaskExecutionException(formatTaskExecutionExceptionMsg(task, pr.getRight().fullName),
e);
RayLog.core.error("Task " + taskdesc + " execution with unknown error ", e);
RayLog.core.error(e.getMessage());
}
// recover loader
if (pr.getLeft() != null) {
Thread.currentThread().setContextClassLoader(oldLoader);
}
// set exception as the output results
if (ex != null) {
throw ex;
}
}
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr)
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
Method m = pr.getRight().invokable;
Pair<Object, Object[]> realArgs = ArgumentsBuilder.unwrap(task, m, pr.getLeft());
// execute
Object result = null;
try {
result = m.invoke(realArgs.getLeft(), realArgs.getRight());
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
RayLog.core.error("invoke failed:" + m);
throw e;
}
if (task.returnIds == null || task.returnIds.length == 0) {
return;
}
AbstractRayRuntime.getInstance().put(task.returnIds[0], result);
}
private static String formatTaskExecutionExceptionMsg(TaskSpec task, String funcName) {
return "Execute task " + task.taskId
+ " failed with function name = " + funcName;
}
private static void safePut(UniqueId objectId, Object obj) {
AbstractRayRuntime.getInstance().put(objectId, obj);
}
}
@@ -62,8 +62,7 @@ public class LocalFunctionManager {
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
*/
public Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
UniqueId methodId,
FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
UniqueId methodId, FunctionArg[] args) {
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
String className = (String) Serializer.decode(args[args.length - 1].data);
return getMethod(driverId, actorId, methodId, className);
@@ -84,20 +83,20 @@ public class LocalFunctionManager {
final ClassLoader classLoader;
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
final ConcurrentHashMap<String, RayActorMethods> actors = new ConcurrentHashMap<>();
final ConcurrentHashMap<String, RayActorMethods> actorMethods = new ConcurrentHashMap<>();
FunctionTable(ClassLoader classLoader) {
this.classLoader = classLoader;
}
RayMethod getTaskMethod(UniqueId methodId, String className) {
RayTaskMethods tasks = taskMethods.get(className);
if (tasks == null) {
tasks = RayTaskMethods.fromClass(className, classLoader);
RayLog.core.info("create RayTaskMethods:" + tasks);
taskMethods.put(className, tasks);
RayTaskMethods taskMethods = this.taskMethods.get(className);
if (taskMethods == null) {
taskMethods = RayTaskMethods.fromClass(className, classLoader);
RayLog.core.info("create RayTaskMethods: {}", taskMethods);
this.taskMethods.put(className, taskMethods);
}
RayMethod m = tasks.functions.get(methodId);
RayMethod m = taskMethods.functions.get(methodId);
if (m != null) {
return m;
}
@@ -110,13 +109,14 @@ public class LocalFunctionManager {
}
private RayMethod getActorMethod(UniqueId methodId, String className, boolean isStatic) {
RayActorMethods actor = actors.get(className);
if (actor == null) {
actor = RayActorMethods.fromClass(className, classLoader);
RayLog.core.info("create RayActorMethods:" + actor);
actors.put(className, actor);
RayActorMethods actorMethods = this.actorMethods.get(className);
if (actorMethods == null) {
actorMethods = RayActorMethods.fromClass(className, classLoader);
RayLog.core.info("create RayActorMethods: {}", actorMethods);
this.actorMethods.put(className, actorMethods);
}
return isStatic ? actor.staticFunctions.get(methodId) : actor.functions.get(methodId);
return isStatic ? actorMethods.staticFunctions.get(methodId)
: actorMethods.functions.get(methodId);
}
}
}
}
@@ -68,5 +68,4 @@ public class UniqueIdHelper {
return new UniqueId(taskId);
}
}
@@ -1,9 +1,8 @@
package org.ray.core;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.logger.RayLog;
@@ -14,51 +13,58 @@ import org.ray.util.logger.RayLog;
*/
public class Worker {
private final LocalSchedulerLink scheduler;
private final LocalFunctionManager functions;
private final AbstractRayRuntime runtime;
public Worker(LocalSchedulerLink scheduler, LocalFunctionManager functions) {
this.scheduler = scheduler;
this.functions = functions;
public Worker(AbstractRayRuntime runtime) {
this.runtime = runtime;
}
public void loop() {
while (true) {
RayLog.core.info(Thread.currentThread().getName() + ":fetching new task...");
TaskSpec task = scheduler.getTask();
execute(task, functions);
TaskSpec task = runtime.getLocalSchedulerClient().getTask();
execute(task);
}
}
public static void execute(TaskSpec task, LocalFunctionManager funcs) {
RayLog.core.info("Executing task {}", task.taskId);
if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
}
/**
* Execute a task.
*/
public void execute(TaskSpec spec) {
RayLog.core.info("Executing task {}", spec.taskId);
UniqueId returnId = spec.returnIds[0];
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
try {
Pair<ClassLoader, RayMethod> pr = funcs
.getMethod(task.driverId, task.actorId, task.functionId, task.args);
WorkerContext.prepare(task, pr.getLeft());
InvocationExecutor.execute(task, pr);
RayLog.core.info("Finished executing task {}", task.taskId);
// Get method
Pair<ClassLoader, RayMethod> pair = runtime.getLocalFunctionManager().getMethod(
spec.driverId, spec.actorId, spec.functionId, spec.args);
ClassLoader classLoader = pair.getLeft();
RayMethod method = pair.getRight();
// Set context
WorkerContext.prepare(spec, classLoader);
Thread.currentThread().setContextClassLoader(classLoader);
// Get local actor object and arguments.
Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null;
Object[] args = ArgumentsBuilder.unwrap(spec, classLoader);
// Execute the task.
Object result;
if (!method.isConstructor()) {
result = method.getMethod().invoke(actor, args);
} else {
result = method.getConstructor().newInstance(args);
}
// Set result
if (!spec.isActorCreationTask()) {
runtime.put(returnId, result);
} else {
runtime.localActors.put(returnId, result);
}
RayLog.core.info("Finished executing task {}", spec.taskId);
} catch (Exception e) {
RayLog.core.error("Failed to execute task " + task.taskId, e);
AbstractRayRuntime.getInstance().put(task.returnIds[0], e);
RayLog.core.error("Error executing task " + spec, e);
runtime.put(returnId, new RayException("Error executing task " + spec, e));
} finally {
Thread.currentThread().setContextClassLoader(oldLoader);
}
}
public UniqueId getCurrentTaskId() {
return WorkerContext.currentTask().taskId;
}
public UniqueId getCurrentTaskNextPutId() {
return UniqueIdHelper.computePutId(
WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex());
}
public UniqueId[] getCurrentTaskReturnIDs() {
return WorkerContext.currentTask().returnIds;
}
}
@@ -1,10 +1,15 @@
package org.ray.spi.model;
import com.google.common.base.Preconditions;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
@@ -28,22 +33,21 @@ public final class RayActorMethods {
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
}
public static RayActorMethods fromClass(String clazzName, ClassLoader classLoader) {
public static RayActorMethods fromClass(String className, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
Class clazz = Class.forName(className, true, classLoader);
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
Preconditions
.checkNotNull(remoteAnnotation, "%s must declare @RayRemote", clazzName);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueId, RayMethod> functions = new HashMap<>(methods.length * 2);
Map<UniqueId, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
Preconditions.checkNotNull(remoteAnnotation,
"%s must be annotated with @RayRemote", className);
for (Method m : methods) {
if (!Modifier.isPublic(m.getModifiers())) {
continue;
}
RayMethod rayMethod = RayMethod.from(m, remoteAnnotation);
if (Modifier.isStatic(m.getModifiers())) {
List<Executable> executables = new ArrayList<>(Arrays.asList(clazz.getDeclaredMethods()));
Map<UniqueId, RayMethod> functions = new HashMap<>();
Map<UniqueId, RayMethod> staticFunctions = new HashMap<>();
for (Executable e : executables) {
RayMethod rayMethod = RayMethod.from(e, remoteAnnotation);
if (Modifier.isStatic(e.getModifiers())) {
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
} else {
functions.put(rayMethod.getFuncId(), rayMethod);
@@ -51,7 +55,7 @@ public final class RayActorMethods {
}
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayActorMethods from " + clazzName, e);
throw new RuntimeException("failed to get RayActorMethods from " + className, e);
}
}
@@ -1,5 +1,7 @@
package org.ray.spi.model;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
@@ -10,29 +12,40 @@ import org.ray.util.MethodId;
*/
public class RayMethod {
public final Method invokable;
public final Executable invokable;
public final String fullName;
public final RayRemote remoteAnnotation;
private final UniqueId funcId;
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueId funcId) {
this.invokable = m;
private RayMethod(Executable e, RayRemote remoteAnnotation, UniqueId funcId) {
this.invokable = e;
this.remoteAnnotation = remoteAnnotation;
this.funcId = funcId;
fullName = m.getDeclaringClass().getName() + "." + m.getName();
fullName = e.getDeclaringClass().getName() + "." + e.getName();
}
public static RayMethod from(Method m, RayRemote parentRemoteAnnotation) {
Class<?> clazz = m.getDeclaringClass();
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
MethodId mid = MethodId.fromMethod(m);
public static RayMethod from(Executable e, RayRemote parentRemoteAnnotation) {
RayRemote remoteAnnotation = e.getAnnotation(RayRemote.class);
MethodId mid = MethodId.fromExecutable(e);
UniqueId funcId = new UniqueId(mid.getSha1Hash());
RayMethod method = new RayMethod(m,
RayMethod method = new RayMethod(e,
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
funcId);
return method;
}
public boolean isConstructor() {
return invokable instanceof Constructor;
}
public Constructor<?> getConstructor() {
return (Constructor<?>) invokable;
}
public Method getMethod() {
return (Method) invokable;
}
@Override
public String toString() {
return fullName;
@@ -1,9 +1,14 @@
package org.ray.spi.model;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
@@ -23,20 +28,19 @@ public final class RayTaskMethods {
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueId, RayMethod> functions = new HashMap<>(methods.length * 2);
List<Executable> executables = new ArrayList<>();
executables.addAll(Arrays.asList(clazz.getDeclaredMethods()));
executables.addAll(Arrays.asList(clazz.getConstructors()));
Map<UniqueId, RayMethod> functions = new HashMap<>(executables.size());
for (Method m : methods) {
if (!Modifier.isStatic(m.getModifiers())) {
for (Executable e : executables) {
// This executable must be either a constructor or a static method.
if (!(e instanceof Constructor)
&& !Modifier.isStatic(e.getModifiers())) {
continue;
}
//task method only for static.
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
if (remoteAnnotation == null) {
continue;
}
m.setAccessible(true);
RayMethod rayMethod = RayMethod.from(m, null);
e.setAccessible(true);
RayMethod rayMethod = RayMethod.from(e, null);
functions.put(rayMethod.getFuncId(), rayMethod);
}
return new RayTaskMethods(clazz, functions);
@@ -94,4 +94,11 @@ public class TaskSpec {
return builder.toString();
}
public boolean isActorTask() {
return !actorId.isNil();
}
public boolean isActorCreationTask() {
return !createActorId.isNil();
}
}
@@ -15,7 +15,7 @@ public class RayDevRuntime extends AbstractRayRuntime {
PathConfig pathConfig = new PathConfig(configReader);
RemoteFunctionManager rfm = new NopRemoteFunctionManager(params.driver_id);
MockObjectStore store = new MockObjectStore();
MockLocalScheduler scheduler = new MockLocalScheduler(store);
MockLocalScheduler scheduler = new MockLocalScheduler(this, store);
init(scheduler, store, rfm, pathConfig);
scheduler.setLocalFunctionManager(this.functions);
}
@@ -6,6 +6,7 @@ import java.util.concurrent.ConcurrentHashMap;
import org.ray.api.id.UniqueId;
import org.ray.core.LocalFunctionManager;
import org.ray.core.Worker;
import org.ray.core.impl.RayDevRuntime;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.TaskSpec;
@@ -19,8 +20,10 @@ public class MockLocalScheduler implements LocalSchedulerLink {
private final Map<UniqueId, Map<UniqueId, TaskSpec>> waitTasks = new ConcurrentHashMap<>();
private final MockObjectStore store;
private LocalFunctionManager functions = null;
private final RayDevRuntime runtime;
public MockLocalScheduler(MockObjectStore store) {
public MockLocalScheduler(RayDevRuntime runtime, MockObjectStore store) {
this.runtime = runtime;
this.store = store;
store.registerScheduler(this);
}
@@ -43,7 +46,7 @@ public class MockLocalScheduler implements LocalSchedulerLink {
public void submitTask(TaskSpec task) {
UniqueId id = isTaskReady(task);
if (id == null) {
Worker.execute(task, functions);
runtime.getWorker().execute(task);
} else {
Map<UniqueId, TaskSpec> bucket = waitTasks
.computeIfAbsent(id, id_ -> new ConcurrentHashMap<>());
@@ -14,7 +14,7 @@ public class ActorPressTest extends RayBenchmarkTest {
@Test
public void singleLatencyTest() {
int times = 10;
RayActor<ActorPressTest.Adder> adder = Ray.createActor(ActorPressTest.Adder.class);
RayActor<ActorPressTest.Adder> adder = Ray.createActor(ActorPressTest.Adder::new);
super.singleLatencyTest(times, adder);
}
@@ -22,7 +22,7 @@ public class ActorPressTest extends RayBenchmarkTest {
public void maxTest() {
int clientNum = 2;
int totalNum = 20;
RayActor<ActorPressTest.Adder> adder = Ray.createActor(ActorPressTest.Adder.class);
RayActor<ActorPressTest.Adder> adder = Ray.createActor(ActorPressTest.Adder::new);
PressureTestParameter pressureTestParameter = new PressureTestParameter();
pressureTestParameter.setClientNum(clientNum);
pressureTestParameter.setTotalNum(totalNum);
@@ -36,7 +36,7 @@ public class ActorPressTest extends RayBenchmarkTest {
int clientNum = 2;
int totalQps = 2;
int duration = 3;
RayActor<ActorPressTest.Adder> adder = Ray.createActor(ActorPressTest.Adder.class);
RayActor<ActorPressTest.Adder> adder = Ray.createActor(ActorPressTest.Adder::new);
PressureTestParameter pressureTestParameter = new PressureTestParameter();
pressureTestParameter.setClientNum(clientNum);
pressureTestParameter.setTotalQps(totalQps);
@@ -16,40 +16,61 @@ public class ActorTest {
@RayRemote
public static class Counter {
private int value = 0;
private int value;
public int incr(int delta) {
public Counter(int initValue) {
this.value = initValue;
}
public int getValue() {
return value;
}
public int increase(int delta) {
value += delta;
return value;
}
}
@Test
public void testCreateAndCallActor() {
// Test creating an actor from a constructor
RayActor<Counter> actor = Ray.createActor(Counter::new, 1);
Assert.assertNotEquals(actor.getId(), UniqueId.NIL);
// Test calling an actor
Assert.assertEquals(Integer.valueOf(1), Ray.call(Counter::getValue, actor).get());
Assert.assertEquals(Integer.valueOf(11), Ray.call(Counter::increase, actor, 10).get());
}
@RayRemote
public static Counter factory(int initValue) {
return new Counter(initValue);
}
@Test
public void testCreateActorFromFactory() {
// Test creating an actor from a factory method
RayActor<Counter> actor = Ray.createActor(ActorTest::factory, 1);
Assert.assertNotEquals(actor.getId(), UniqueId.NIL);
// Test calling an actor
Assert.assertEquals(Integer.valueOf(1), Ray.call(Counter::getValue, actor).get());
}
@RayRemote
public static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
RayObject<Integer> res = Ray.call(Counter::incr, actor, delta);
RayObject<Integer> res = Ray.call(Counter::increase, actor, delta);
return res.get();
}
@RayRemote
public static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
RayObject<Integer> res = Ray.call(Counter::incr, actor, delta);
RayObject<Integer> res = Ray.call(Counter::increase, actor, delta);
return res.get();
}
@Test
public void testCreateAndCallActor() {
// Test creating an actor
RayActor<Counter> actor = Ray.createActor(Counter.class);
Assert.assertNotEquals(actor.getId(), UniqueId.NIL);
// Test calling an actor
RayFunc2<Counter, Integer, Integer> f = Counter::incr;
Assert.assertEquals(Integer.valueOf(1), Ray.call(f, actor, 1).get());
Assert.assertEquals(Integer.valueOf(11), Ray.call(Counter::incr, actor, 10).get());
}
@Test
public void testPassActorAsParameter() {
RayActor<Counter> actor = Ray.createActor(Counter.class);
RayActor<Counter> actor = Ray.createActor(Counter::new, 0);
RayFunc2<RayActor, Integer, Integer> f = ActorTest::testActorAsFirstParameter;
Assert.assertEquals(Integer.valueOf(1),
Ray.call(ActorTest::testActorAsFirstParameter, actor, 1).get());
@@ -1,46 +0,0 @@
package org.ray.api.test;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.annotation.RayRemote;
@RunWith(MyRunner.class)
public class EchoTest {
@RayRemote
public static String hi() {
return "hi";
}
@RayRemote
public static String who(String who) {
return who;
}
@RayRemote
public static String recho(String pre, String who) {
return pre + ", " + who + "!";
}
@Test
public void test() {
long startTime = 0;
long endTime = 0;
for (int i = 0; i < 100; i++) {
startTime = System.nanoTime();
String ret = echo("Ray++" + i);
endTime = System.nanoTime();
System.out.println("echo: " + ret + " , total time is " + (endTime - startTime));
}
}
public String echo(String who) {
return Ray.call(
EchoTest::recho,
Ray.call(EchoTest::hi),
Ray.call(EchoTest::who, who)
).get();
}
}
@@ -1,38 +1,44 @@
package org.ray.api.test;
import java.lang.reflect.Method;
import java.lang.reflect.Executable;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.function.RayFunc3;
import org.ray.api.function.RayFunc2;
import org.ray.util.MethodId;
import org.ray.util.logger.RayLog;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MethodIdTest {
public static <T0, T1, T2, R0> MethodId fromLambda(RayFunc3<T0, T1, T2, R0> f) {
MethodId mid = MethodId.fromSerializedLambda(f, true);
return mid;
}
public static MethodId fromClass(Method method) {
return MethodId.fromMethod(method);
}
private static final Logger LOGGER = LoggerFactory.getLogger(MethodIdTest.class);
@Test
public void testMethodId2From() throws Exception {
MethodId m1 = fromLambda(MethodIdTest::call);
Method m = MethodIdTest.class.getDeclaredMethod("call", new Class[]{long.class, String.class});
MethodId m2 = fromClass(m);
RayLog.core.info(m1.toString());
public void testNormalMethod() throws Exception {
RayFunc2<Integer, String, String> f = MethodIdTest::foo;
MethodId m1 = MethodId.fromSerializedLambda(f);
Executable e = MethodIdTest.class.getDeclaredMethod("foo", int.class, String.class);
MethodId m2 = MethodId.fromExecutable(e);
LOGGER.info("{}, {}", m1, m2);
Assert.assertEquals(m1, m2);
}
public String call(long v, String s) {
for (int i = 0; i < 100; i++) {
v += i;
}
RayLog.core.info("call:" + v);
return String.valueOf(v);
@Test
public void testConstructor() throws Exception {
RayFunc2<Integer, String, Foo> f = Foo::new;
MethodId m1 = MethodId.fromSerializedLambda(f);
Executable e = Foo.class.getConstructor(int.class, String.class);
MethodId m2 = MethodId.fromExecutable(e);
LOGGER.info("{}, {}", m1, m2);
Assert.assertEquals(m1, m2);
}
}
public static String foo(int a, String b) {
return a + b;
}
public static class Foo {
public Foo(int a, String b) {}
}
}
@@ -4,10 +4,13 @@ import org.junit.Assert;
import org.junit.Test;
import org.ray.api.annotation.RayRemote;
import org.ray.spi.model.RayActorMethods;
import org.ray.util.logger.RayLog;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RayActorMethodsTest {
private static final Logger LOGGER = LoggerFactory.getLogger(RayActorMethodsTest.class);
@RayRemote
public static class ExampleActor {
@@ -22,7 +25,7 @@ public class RayActorMethodsTest {
public void testActorMethods() {
RayActorMethods methods = RayActorMethods
.fromClass(ExampleActor.class.getName(), RayActorMethodsTest.class.getClassLoader());
RayLog.core.info(methods.toString());
LOGGER.info(methods.toString());
Assert.assertEquals(methods.functions.size(), 2);
Assert.assertEquals(methods.staticFunctions.size(), 1);
}
@@ -1,18 +1,44 @@
package org.ray.api.test;
import java.lang.reflect.Constructor;
import org.junit.Assert;
import org.junit.Test;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.RayTaskMethods;
import org.ray.util.logger.RayLog;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RayTaskMethodsTest {
@Test
public void testTask() throws Exception {
RayTaskMethods methods = RayTaskMethods
.fromClass(EchoTest.class.getName(), RayTaskMethodsTest.class.getClassLoader());
RayLog.core.info(methods.toString());
Assert.assertEquals(methods.functions.size(), 3);
private static final Logger LOGGER = LoggerFactory.getLogger(RayTaskMethodsTest.class);
private static class Foo {
public Foo() {}
public Foo(int x) {}
public static void f1() {}
public void f2() {}
}
}
@Test
public void testTask() {
RayTaskMethods methods = RayTaskMethods
.fromClass(Foo.class.getName(), Foo.class.getClassLoader());
LOGGER.info(methods.toString());
int numMethods = 0;
int numConstructors = 0;
for (RayMethod m : methods.functions.values()) {
if (m.isConstructor()) {
numConstructors += 1;
} else {
numMethods += 1;
}
}
Assert.assertEquals(numMethods, 1);
Assert.assertEquals(numConstructors, 2);
}
}
@@ -66,12 +66,12 @@ public class ResourcesManagementTest {
public void testActors() {
Assume.assumeTrue(AbstractRayRuntime.getParams().use_raylet);
// This is a case that can satisfy required resources.
RayActor<ResourcesManagementTest.Echo1> echo1 = Ray.createActor(Echo1.class);
RayActor<ResourcesManagementTest.Echo1> echo1 = Ray.createActor(Echo1::new);
final RayObject<Integer> result1 = Ray.call(Echo1::echo, echo1, 100);
Assert.assertEquals(100, (int) result1.get());
// This is a case that can't satisfy required resources.
RayActor<ResourcesManagementTest.Echo2> echo2 = Ray.createActor(Echo2.class);
RayActor<ResourcesManagementTest.Echo2> echo2 = Ray.createActor(Echo2::new);
final RayObject<Integer> result2 = Ray.call(Echo2::echo, echo2, 100);
WaitResult<Integer> waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000);
@@ -14,7 +14,7 @@ public class Exercise05 {
try {
Ray.init();
// `Ray.createActor` creates an actor instance.
RayActor<Adder> adder = Ray.createActor(Adder.class);
RayActor<Adder> adder = Ray.createActor(Adder::new, 0);
// Use `Ray.call(actor, parameters)` to call an actor method.
RayObject<Integer> result1 = Ray.call(Adder::add, adder, 1);
System.out.println(result1.get());
@@ -34,8 +34,8 @@ public class Exercise05 {
@RayRemote
public static class Adder {
public Adder() {
sum = 0;
public Adder(int initValue) {
sum = initValue;
}
public int add(int n) {