mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[Java] Support calling functions returning void (#5494)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc0<R> extends RayFunc {
|
||||
R apply();
|
||||
|
||||
R apply() throws Exception;
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc1<T0, R> extends RayFunc {
|
||||
R apply(T0 t0);
|
||||
|
||||
R apply(T0 t0) throws Exception;
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc2<T0, T1, R> extends RayFunc {
|
||||
R apply(T0 t0, T1 t1);
|
||||
|
||||
R apply(T0 t0, T1 t1) throws Exception;
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc3<T0, T1, T2, R> extends RayFunc {
|
||||
R apply(T0 t0, T1 t1, T2 t2);
|
||||
|
||||
R apply(T0 t0, T1 t1, T2 t2) throws Exception;
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc4<T0, T1, T2, T3, R> extends RayFunc {
|
||||
R apply(T0 t0, T1 t1, T2 t2, T3 t3);
|
||||
|
||||
R apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Exception;
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc5<T0, T1, T2, T3, T4, R> extends RayFunc {
|
||||
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4);
|
||||
|
||||
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
|
||||
}
|
||||
|
||||
@@ -7,5 +7,6 @@ package org.ray.api.function;
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFunc6<T0, T1, T2, T3, T4, T5, R> extends RayFunc {
|
||||
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5);
|
||||
|
||||
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Interface of all `RayFuncVoidX` classes.
|
||||
*/
|
||||
public interface RayFuncVoid extends RayFunc {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 0 parameter.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid0 extends RayFuncVoid {
|
||||
|
||||
void apply() throws Exception;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 1 parameter.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid1<T0> extends RayFuncVoid {
|
||||
|
||||
void apply(T0 t0) throws Exception;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 2 parameters.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid2<T0, T1> extends RayFuncVoid {
|
||||
|
||||
void apply(T0 t0, T1 t1) throws Exception;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 3 parameters.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid3<T0, T1, T2> extends RayFuncVoid {
|
||||
|
||||
void apply(T0 t0, T1 t1, T2 t2) throws Exception;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 4 parameters.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid4<T0, T1, T2, T3> extends RayFuncVoid {
|
||||
|
||||
void apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Exception;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 5 parameters.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid5<T0, T1, T2, T3, T4> extends RayFuncVoid {
|
||||
|
||||
void apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// generated automatically, do not modify.
|
||||
|
||||
package org.ray.api.function;
|
||||
|
||||
/**
|
||||
* Functional interface for a remote function that has 6 parameters.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface RayFuncVoid6<T0, T1, T2, T3, T4, T5> extends RayFuncVoid {
|
||||
|
||||
void apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception;
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
@@ -107,7 +108,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
return callNormalFunction(functionDescriptor, args, options);
|
||||
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
|
||||
return callNormalFunction(functionDescriptor, args, numReturns, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -115,7 +117,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
return callActorFunction(actor, functionDescriptor, args);
|
||||
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
|
||||
return callActorFunction(actor, functionDescriptor, args, numReturns);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -143,7 +146,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, "",
|
||||
functionName);
|
||||
return callNormalFunction(functionDescriptor, args, options);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -151,7 +155,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), functionName);
|
||||
return callActorFunction(pyActor, functionDescriptor, args);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -164,21 +169,31 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, CallOptions options) {
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, 1, options);
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
if (returnIds.isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
private RayObject callActorFunction(RayActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args) {
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, 1, null);
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
functionDescriptor, functionArgs, numReturns, null);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
if (returnIds.isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
|
||||
|
||||
@@ -72,6 +72,17 @@ public class RayFunction {
|
||||
return rayRemote;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Whether this function has a return value.
|
||||
*/
|
||||
public boolean hasReturn() {
|
||||
if (isConstructor()) {
|
||||
return true;
|
||||
} else {
|
||||
return !getMethod().getReturnType().equals(void.class);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return executable.toString();
|
||||
|
||||
@@ -68,6 +68,18 @@ public abstract class ObjectStore {
|
||||
return putRaw(serialize(object));
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize and put an object to the object store, with the given object id.
|
||||
*
|
||||
* This method is only used for testing.
|
||||
*
|
||||
* @param object The object to put.
|
||||
* @param objectId Object id.
|
||||
*/
|
||||
public void put(Object object, ObjectId objectId) {
|
||||
putRaw(serialize(object), objectId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of raw objects from the object store.
|
||||
*
|
||||
@@ -156,7 +168,8 @@ public abstract class ObjectStore {
|
||||
* Delete a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to delete.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the
|
||||
* cluster.
|
||||
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
|
||||
*/
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayDeque;
|
||||
@@ -160,7 +161,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
@Override
|
||||
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(numReturns == 1);
|
||||
Preconditions.checkState(numReturns <= 1);
|
||||
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.NORMAL_TASK, functionDescriptor, args)
|
||||
.setNumReturns(numReturns)
|
||||
.build();
|
||||
@@ -185,7 +186,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
@Override
|
||||
public List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(numReturns == 1);
|
||||
Preconditions.checkState(numReturns <= 1);
|
||||
TaskSpec.Builder builder = getTaskSpecBuilder(TaskType.ACTOR_TASK, functionDescriptor, args);
|
||||
List<ObjectId> returnIds = getReturnIds(
|
||||
TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
|
||||
@@ -200,7 +201,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
.build())
|
||||
.build();
|
||||
submitTaskSpec(taskSpec);
|
||||
return Collections.singletonList(returnIds.get(0));
|
||||
if (numReturns == 0) {
|
||||
return ImmutableList.of();
|
||||
} else {
|
||||
return ImmutableList.of(returnIds.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
public static ActorId getActorId(TaskSpec taskSpec) {
|
||||
|
||||
@@ -70,10 +70,11 @@ public final class TaskExecutor {
|
||||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
// Find the executable object.
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
|
||||
Preconditions.checkNotNull(rayFunction);
|
||||
try {
|
||||
// Get method
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
|
||||
@@ -100,7 +101,9 @@ public final class TaskExecutor {
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId());
|
||||
}
|
||||
returnObjects.add(runtime.getObjectStore().serialize(result));
|
||||
if (rayFunction.hasReturn()) {
|
||||
returnObjects.add(runtime.getObjectStore().serialize(result));
|
||||
}
|
||||
} else {
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
|
||||
@@ -110,8 +113,10 @@ public final class TaskExecutor {
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
returnObjects.add(runtime.getObjectStore()
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
if(rayFunction.hasReturn()) {
|
||||
returnObjects.add(runtime.getObjectStore()
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
}
|
||||
} else {
|
||||
actorCreationException = e;
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import java.util.List;
|
||||
import org.ray.runtime.util.FileUtil;
|
||||
|
||||
/**
|
||||
* A util class that generates `RayCall.java`,
|
||||
* which provides type-safe interfaces for `Ray.call` and `Ray.createActor`.
|
||||
* A util class that generates `RayCall.java`, which provides type-safe interfaces for `Ray.call`
|
||||
* and `Ray.createActor`.
|
||||
*/
|
||||
public class RayCallGenerator extends BaseGenerator {
|
||||
|
||||
@@ -21,13 +21,12 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("");
|
||||
newLine("package org.ray.api;");
|
||||
newLine("");
|
||||
newLine("import org.ray.api.function.RayFunc0;");
|
||||
newLine("import org.ray.api.function.RayFunc1;");
|
||||
newLine("import org.ray.api.function.RayFunc2;");
|
||||
newLine("import org.ray.api.function.RayFunc3;");
|
||||
newLine("import org.ray.api.function.RayFunc4;");
|
||||
newLine("import org.ray.api.function.RayFunc5;");
|
||||
newLine("import org.ray.api.function.RayFunc6;");
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
newLine("import org.ray.api.function.RayFunc" + i + ";");
|
||||
}
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
newLine("import org.ray.api.function.RayFuncVoid" + i + ";");
|
||||
}
|
||||
newLine("import org.ray.api.options.ActorCreationOptions;");
|
||||
newLine("import org.ray.api.options.CallOptions;");
|
||||
newLine("");
|
||||
@@ -41,22 +40,25 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine(1, "// Methods for remote function invocation.");
|
||||
newLine(1, "// =======================================");
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildCalls(i, false, false, false);
|
||||
buildCalls(i, false, false, true);
|
||||
buildCalls(i, false, false, true, false);
|
||||
buildCalls(i, false, false, true, true);
|
||||
buildCalls(i, false, false, false, false);
|
||||
buildCalls(i, false, false, false, true);
|
||||
}
|
||||
|
||||
newLine(1, "// ===========================================");
|
||||
newLine(1, "// Methods for remote actor method invocation.");
|
||||
newLine(1, "// ===========================================");
|
||||
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
|
||||
buildCalls(i, true, false, false);
|
||||
buildCalls(i, true, false, true, false);
|
||||
buildCalls(i, true, false, false, false);
|
||||
}
|
||||
newLine(1, "// ===========================");
|
||||
newLine(1, "// Methods for actor creation.");
|
||||
newLine(1, "// ===========================");
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildCalls(i, false, true, false);
|
||||
buildCalls(i, false, true, true);
|
||||
buildCalls(i, false, true, true, false);
|
||||
buildCalls(i, false, true, true, true);
|
||||
}
|
||||
|
||||
newLine(1, "// ===========================");
|
||||
@@ -71,7 +73,7 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
}
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildPyCalls(i, false, true, false);
|
||||
buildPyCalls(i,false, true, true);
|
||||
buildPyCalls(i, false, true, true);
|
||||
}
|
||||
newLine("}");
|
||||
return sb.toString();
|
||||
@@ -79,72 +81,106 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
|
||||
/**
|
||||
* Build the `Ray.call` or `Ray.createActor` methods with the given number of parameters.
|
||||
*
|
||||
* @param numParameters the number of parameters
|
||||
* @param forActor build actor api when true, otherwise build task api.
|
||||
* @param hasReturn if true, build api for functions with return.
|
||||
* @param forActorCreation build `Ray.createActor` when true, otherwise build `Ray.call`.
|
||||
*/
|
||||
private void buildCalls(int numParameters, boolean forActor,
|
||||
boolean forActorCreation, boolean hasOptionsParam) {
|
||||
boolean forActorCreation, boolean hasReturn, boolean hasOptionsParam) {
|
||||
// Template of the generated function:
|
||||
// public static [genericTypes] [returnType] [callFunc]([argsDeclaration]) {
|
||||
// Objects[] args = new Object[]{[args]};
|
||||
// return Ray.internal().[callFunc](f[, actor], args[, options]);
|
||||
// }
|
||||
|
||||
// 1) Construct the `genericTypes` part, e.g. `<T0, T1, T2, R>`.
|
||||
String genericTypes = "";
|
||||
String argList = "";
|
||||
for (int i = 0; i < numParameters; i++) {
|
||||
genericTypes += "T" + i + ", ";
|
||||
argList += "t" + i + ", ";
|
||||
}
|
||||
if (forActor) {
|
||||
// Actor generic type.
|
||||
genericTypes = "A, " + genericTypes;
|
||||
}
|
||||
genericTypes += forActorCreation ? "A" : "R";
|
||||
if (argList.endsWith(", ")) {
|
||||
argList = argList.substring(0, argList.length() - 2);
|
||||
// Return generic type.
|
||||
if (forActorCreation) {
|
||||
genericTypes += "A, ";
|
||||
} else {
|
||||
if (hasReturn) {
|
||||
genericTypes += "R, ";
|
||||
}
|
||||
}
|
||||
if (!genericTypes.isEmpty()) {
|
||||
// Trim trailing ", ";
|
||||
genericTypes = genericTypes.substring(0, genericTypes.length() - 2);
|
||||
genericTypes = "<" + genericTypes + ">";
|
||||
}
|
||||
|
||||
String paramPrefix = String.format("RayFunc%d<%s> f",
|
||||
// 2) Construct the `returnType` part.
|
||||
String returnType;
|
||||
if (forActorCreation) {
|
||||
returnType = "RayActor<A>";
|
||||
} else {
|
||||
returnType = hasReturn ? "RayObject<R>" : "void";
|
||||
}
|
||||
|
||||
// 3) Construct the `argsDeclaration` part.
|
||||
String argsDeclarationPrefix = String.format("RayFunc%s%d%s f, ",
|
||||
hasReturn ? "" : "Void",
|
||||
!forActor ? numParameters : numParameters + 1,
|
||||
genericTypes);
|
||||
if (forActor) {
|
||||
paramPrefix += ", RayActor<A> actor";
|
||||
}
|
||||
if (numParameters > 0) {
|
||||
paramPrefix += ", ";
|
||||
argsDeclarationPrefix += "RayActor<A> actor, ";
|
||||
}
|
||||
|
||||
String optionsParam;
|
||||
if (hasOptionsParam) {
|
||||
optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options";
|
||||
} else {
|
||||
optionsParam = "";
|
||||
}
|
||||
String callFunc = forActorCreation ? "createActor" : "call";
|
||||
|
||||
String optionsArg;
|
||||
if (forActor) {
|
||||
optionsArg = "";
|
||||
} else {
|
||||
// Enumerate all combinations of the parameters.
|
||||
for (String param : generateParameters(numParameters)) {
|
||||
String argsDeclaration = argsDeclarationPrefix + param;
|
||||
if (hasOptionsParam) {
|
||||
optionsArg = ", options";
|
||||
} else {
|
||||
optionsArg = ", null";
|
||||
argsDeclaration +=
|
||||
forActorCreation ? "ActorCreationOptions options, " : "CallOptions options, ";
|
||||
}
|
||||
}
|
||||
|
||||
String returnType = !forActorCreation ? "RayObject<R>" : "RayActor<A>";
|
||||
String funcName = !forActorCreation ? "call" : "createActor";
|
||||
String funcArgs = !forActor ? "f, args" : "f, actor, args";
|
||||
for (String param : generateParameters(0, numParameters)) {
|
||||
// Method signature.
|
||||
// Trim trailing ", ";
|
||||
argsDeclaration = argsDeclaration.substring(0, argsDeclaration.length() - 2);
|
||||
// Print the first line (method signature).
|
||||
newLine(1, String.format(
|
||||
"public static <%s> %s %s(%s%s) {",
|
||||
genericTypes, returnType, funcName, paramPrefix + param, optionsParam
|
||||
"public static%s %s %s(%s) {",
|
||||
genericTypes.isEmpty() ? "" : " " + genericTypes, returnType, callFunc, argsDeclaration
|
||||
));
|
||||
// Method body.
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
|
||||
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
|
||||
|
||||
// 4) Construct the `args` part.
|
||||
String args = "";
|
||||
for (int i = 0; i < numParameters; i++) {
|
||||
args += "t" + i + ", ";
|
||||
}
|
||||
// Trim trailing ", ";
|
||||
if (!args.isEmpty()) {
|
||||
args = args.substring(0, args.length() - 2);
|
||||
}
|
||||
// Print the second line (local args declaration).
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", args));
|
||||
|
||||
// 5) Construct the third line.
|
||||
String callFuncArgs = "f, ";
|
||||
if (forActor) {
|
||||
callFuncArgs += "actor, ";
|
||||
}
|
||||
callFuncArgs += "args, ";
|
||||
callFuncArgs += forActor ? "" : hasOptionsParam ? "options, " : "null, ";
|
||||
callFuncArgs = callFuncArgs.substring(0, callFuncArgs.length() - 2);
|
||||
newLine(2, String.format("%sRay.internal().%s(%s);",
|
||||
hasReturn ? "return " : "", callFunc, callFuncArgs));
|
||||
newLine(1, "}");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
|
||||
*
|
||||
* @param forActor build actor api when true, otherwise build task api.
|
||||
* @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
|
||||
*/
|
||||
@@ -211,24 +247,21 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine(1, "}");
|
||||
}
|
||||
|
||||
private List<String> generateParameters(int from, int to) {
|
||||
private List<String> generateParameters(int numParams) {
|
||||
List<String> res = new ArrayList<>();
|
||||
dfs(from, from, to, "", res);
|
||||
dfs(0, numParams, "", res);
|
||||
return res;
|
||||
}
|
||||
|
||||
private void dfs(int pos, int from, int to, String cur, List<String> res) {
|
||||
if (pos >= to) {
|
||||
private void dfs(int pos, int numParams, String cur, List<String> res) {
|
||||
if (pos >= numParams) {
|
||||
res.add(cur);
|
||||
return;
|
||||
}
|
||||
if (pos > from) {
|
||||
cur += ", ";
|
||||
}
|
||||
String nextParameter = String.format("T%d t%d", pos, pos);
|
||||
dfs(pos + 1, from, to, cur + nextParameter, res);
|
||||
nextParameter = String.format("RayObject<T%d> t%d", pos, pos);
|
||||
dfs(pos + 1, from, to, cur + nextParameter, res);
|
||||
String nextParameter = String.format("T%d t%d, ", pos, pos);
|
||||
dfs(pos + 1, numParams, cur + nextParameter, res);
|
||||
nextParameter = String.format("RayObject<T%d> t%d, ", pos, pos);
|
||||
dfs(pos + 1, numParams, cur + nextParameter, res);
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
|
||||
@@ -8,7 +8,7 @@ import org.ray.runtime.util.FileUtil;
|
||||
*/
|
||||
public class RayFuncGenerator extends BaseGenerator {
|
||||
|
||||
private String generate(int numParameters) {
|
||||
private String generate(int numParameters, boolean hasReturn) {
|
||||
sb = new StringBuilder();
|
||||
|
||||
String genericTypes = "";
|
||||
@@ -20,6 +20,14 @@ public class RayFuncGenerator extends BaseGenerator {
|
||||
}
|
||||
paramList += String.format("T%d t%d", i, i);
|
||||
}
|
||||
if (hasReturn) {
|
||||
genericTypes += "R, ";
|
||||
}
|
||||
if (!genericTypes.isEmpty()) {
|
||||
// Remove trailing ", ".
|
||||
genericTypes = genericTypes.substring(0, genericTypes.length() - 2);
|
||||
genericTypes = "<" + genericTypes + ">";
|
||||
}
|
||||
|
||||
newLine("// generated automatically, do not modify.");
|
||||
newLine("");
|
||||
@@ -32,10 +40,12 @@ public class RayFuncGenerator extends BaseGenerator {
|
||||
newLine(comment);
|
||||
newLine(" */");
|
||||
newLine("@FunctionalInterface");
|
||||
newLine(String.format("public interface RayFunc%d<%sR> extends RayFunc {",
|
||||
numParameters, genericTypes));
|
||||
String className = "RayFunc" + (hasReturn ? "" : "Void") + numParameters;
|
||||
newLine(String.format("public interface %s%s extends %s {",
|
||||
className, genericTypes, hasReturn ? "RayFunc" : "RayFuncVoid"));
|
||||
newLine("");
|
||||
indents(1);
|
||||
newLine(String.format("R apply(%s);", paramList));
|
||||
newLine(String.format("%s apply(%s) throws Exception;", hasReturn ? "R" : "void", paramList));
|
||||
newLine("}");
|
||||
|
||||
return sb.toString();
|
||||
@@ -46,8 +56,12 @@ public class RayFuncGenerator extends BaseGenerator {
|
||||
+ "/api/src/main/java/org/ray/api/function/";
|
||||
RayFuncGenerator generator = new RayFuncGenerator();
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
String content = generator.generate(i);
|
||||
// Functions that have return.
|
||||
String content = generator.generate(i, true);
|
||||
FileUtil.overrideFile(root + "RayFunc" + i + ".java", content);
|
||||
// Functions that don't have return.
|
||||
content = generator.generate(i, false);
|
||||
FileUtil.overrideFile(root + "RayFuncVoid" + i + ".java", content);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.function.RayFunc1;
|
||||
import org.ray.api.test.BaseTest;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -64,7 +65,7 @@ public abstract class RayBenchmarkTest<T> extends BaseTest implements Serializab
|
||||
long endTime = remoteResult.getFinishTime();
|
||||
long costTime = endTime - temp.getStartTime();
|
||||
counterList.add(costTime / 1000);
|
||||
LOGGER.warn("{}_cost_time:{}ns",logPrefix, costTime);
|
||||
LOGGER.warn("{}_cost_time:{}ns", logPrefix, costTime);
|
||||
Assert.assertTrue(rayBenchmarkTest.checkResult(remoteResult.getResult()));
|
||||
}
|
||||
return counterList;
|
||||
@@ -130,7 +131,13 @@ public abstract class RayBenchmarkTest<T> extends BaseTest implements Serializab
|
||||
RayObject<List<Long>>[] rayObjects = new RayObject[clientNum];
|
||||
|
||||
for (int i = 0; i < clientNum; i++) {
|
||||
rayObjects[i] = Ray.call(RayBenchmarkTest::singleClient, pressureTestParameter);
|
||||
// Java compiler can't automatically infer the type of
|
||||
// `RayBenchmarkTest::singleClient`, because `RayBenchmarkTest` is a generic class.
|
||||
// It will match both `RayFunc1` and `RayFuncVoid1`. This looks like a bug or
|
||||
// defect of the Java compiler.
|
||||
// TODO(hchen): Figure out how to avoid manually declaring `RayFunc` type in this case.
|
||||
RayFunc1<PressureTestParameter, List<Long>> func = RayBenchmarkTest::singleClient;
|
||||
rayObjects[i] = Ray.call(func, pressureTestParameter);
|
||||
}
|
||||
for (int i = 0; i < clientNum; i++) {
|
||||
List<Long> subCounterList = rayObjects[i].get();
|
||||
|
||||
@@ -32,7 +32,11 @@ public class ActorTest extends BaseTest {
|
||||
return value;
|
||||
}
|
||||
|
||||
public int increase(int delta) {
|
||||
public void increase(int delta) {
|
||||
value += delta;
|
||||
}
|
||||
|
||||
public int increaseAndGet(int delta) {
|
||||
value += delta;
|
||||
return value;
|
||||
}
|
||||
@@ -45,7 +49,8 @@ public class ActorTest extends BaseTest {
|
||||
Assert.assertNotEquals(actor.getId(), UniqueId.NIL);
|
||||
// Test calling an actor
|
||||
Assert.assertEquals(Integer.valueOf(1), Ray.call(Counter::getValue, actor).get());
|
||||
Assert.assertEquals(Integer.valueOf(11), Ray.call(Counter::increase, actor, 10).get());
|
||||
Ray.call(Counter::increase, actor, 1);
|
||||
Assert.assertEquals(Integer.valueOf(3), Ray.call(Counter::increaseAndGet, actor, 1).get());
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
@@ -64,19 +69,19 @@ public class ActorTest extends BaseTest {
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increase, actor, delta);
|
||||
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increase, actor, delta);
|
||||
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increase, actor.get(0), delta);
|
||||
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor.get(0), delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@@ -96,9 +101,9 @@ public class ActorTest extends BaseTest {
|
||||
public void testForkingActorHandle() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
|
||||
Assert.assertEquals(Integer.valueOf(101), Ray.call(Counter::increase, counter, 1).get());
|
||||
Assert.assertEquals(Integer.valueOf(101), Ray.call(Counter::increaseAndGet, counter, 1).get());
|
||||
RayActor<Counter> counter2 = ((NativeRayActor) counter).fork();
|
||||
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increase, counter2, 2).get());
|
||||
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increaseAndGet, counter2, 2).get());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -18,14 +18,10 @@ public class PlasmaStoreTest extends BaseTest {
|
||||
ObjectId objectId = ObjectId.fromRandom();
|
||||
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
|
||||
ObjectStore objectStore = runtime.getObjectStore();
|
||||
objectStore.putRaw(new NativeRayObject(new byte[]{1}, null), objectId);
|
||||
Assert.assertEquals(
|
||||
objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0],
|
||||
(byte) 1);
|
||||
objectStore.putRaw(new NativeRayObject(new byte[]{2}, null), objectId);
|
||||
// Putting 2 objects with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(
|
||||
objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0],
|
||||
(byte) 1);
|
||||
objectStore.put("1", objectId);
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
objectStore.put("2", objectId);
|
||||
// Putting the second object with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -66,6 +68,7 @@ public class RayCallTest extends BaseTest {
|
||||
}
|
||||
|
||||
public static class LargeObject implements Serializable {
|
||||
|
||||
private byte[] data = new byte[1024 * 1024];
|
||||
}
|
||||
|
||||
@@ -74,6 +77,12 @@ public class RayCallTest extends BaseTest {
|
||||
return largeObject;
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
private static void testNoReturn(ObjectId objectId) {
|
||||
// Put an object in object store to inform driver that this function is executing.
|
||||
((AbstractRayRuntime) Ray.internal()).getObjectStore().put(1, objectId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Test calling and returning different types.
|
||||
*/
|
||||
@@ -93,6 +102,10 @@ public class RayCallTest extends BaseTest {
|
||||
Assert.assertEquals(map, Ray.call(RayCallTest::testMap, map).get());
|
||||
LargeObject largeObject = new LargeObject();
|
||||
Assert.assertNotNull(Ray.call(RayCallTest::testLargeObject, largeObject).get());
|
||||
|
||||
ObjectId randomObjectId = ObjectId.fromRandom();
|
||||
Ray.call(RayCallTest::testNoReturn, randomObjectId);
|
||||
Assert.assertEquals(((int) Ray.get(randomObjectId)), 1);
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
|
||||
Reference in New Issue
Block a user