[Java] Support calling functions returning void (#5494)

This commit is contained in:
Hao Chen
2019-08-23 21:10:15 +08:00
committed by GitHub
parent 7812dd5636
commit 239c177fe8
27 changed files with 1603 additions and 112 deletions
File diff suppressed because it is too large Load Diff
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc0<R> extends RayFunc {
R apply();
R apply() throws Exception;
}
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc1<T0, R> extends RayFunc {
R apply(T0 t0);
R apply(T0 t0) throws Exception;
}
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc2<T0, T1, R> extends RayFunc {
R apply(T0 t0, T1 t1);
R apply(T0 t0, T1 t1) throws Exception;
}
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc3<T0, T1, T2, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2);
R apply(T0 t0, T1 t1, T2 t2) throws Exception;
}
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc4<T0, T1, T2, T3, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2, T3 t3);
R apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Exception;
}
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc5<T0, T1, T2, T3, T4, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4);
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
}
@@ -7,5 +7,6 @@ package org.ray.api.function;
*/
@FunctionalInterface
public interface RayFunc6<T0, T1, T2, T3, T4, T5, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5);
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception;
}
@@ -0,0 +1,8 @@
package org.ray.api.function;
/**
* Interface of all `RayFuncVoidX` classes.
*/
public interface RayFuncVoid extends RayFunc {
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 0 parameter.
*/
@FunctionalInterface
public interface RayFuncVoid0 extends RayFuncVoid {
void apply() throws Exception;
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 1 parameter.
*/
@FunctionalInterface
public interface RayFuncVoid1<T0> extends RayFuncVoid {
void apply(T0 t0) throws Exception;
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 2 parameters.
*/
@FunctionalInterface
public interface RayFuncVoid2<T0, T1> extends RayFuncVoid {
void apply(T0 t0, T1 t1) throws Exception;
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 3 parameters.
*/
@FunctionalInterface
public interface RayFuncVoid3<T0, T1, T2> extends RayFuncVoid {
void apply(T0 t0, T1 t1, T2 t2) throws Exception;
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 4 parameters.
*/
@FunctionalInterface
public interface RayFuncVoid4<T0, T1, T2, T3> extends RayFuncVoid {
void apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Exception;
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 5 parameters.
*/
@FunctionalInterface
public interface RayFuncVoid5<T0, T1, T2, T3, T4> extends RayFuncVoid {
void apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
}
@@ -0,0 +1,12 @@
// generated automatically, do not modify.
package org.ray.api.function;
/**
* Functional interface for a remote function that has 6 parameters.
*/
@FunctionalInterface
public interface RayFuncVoid6<T0, T1, T2, T3, T4, T5> extends RayFuncVoid {
void apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception;
}
@@ -9,6 +9,7 @@ import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFuncVoid;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
@@ -107,7 +108,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), func)
.functionDescriptor;
return callNormalFunction(functionDescriptor, args, options);
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
return callNormalFunction(functionDescriptor, args, numReturns, options);
}
@Override
@@ -115,7 +117,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), func)
.functionDescriptor;
return callActorFunction(actor, functionDescriptor, args);
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
return callActorFunction(actor, functionDescriptor, args, numReturns);
}
@Override
@@ -143,7 +146,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, "",
functionName);
return callNormalFunction(functionDescriptor, args, options);
// Python functions always have a return value, even if it's `None`.
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
}
@Override
@@ -151,7 +155,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), functionName);
return callActorFunction(pyActor, functionDescriptor, args);
// Python functions always have a return value, even if it's `None`.
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
}
@Override
@@ -164,21 +169,31 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
Object[] args, CallOptions options) {
Object[] args, int numReturns, CallOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
functionArgs, 1, options);
return new RayObjectImpl(returnIds.get(0));
functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
if (returnIds.isEmpty()) {
return null;
} else {
return new RayObjectImpl(returnIds.get(0));
}
}
private RayObject callActorFunction(RayActor rayActor,
FunctionDescriptor functionDescriptor, Object[] args) {
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
functionDescriptor, functionArgs, 1, null);
return new RayObjectImpl(returnIds.get(0));
functionDescriptor, functionArgs, numReturns, null);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
if (returnIds.isEmpty()) {
return null;
} else {
return new RayObjectImpl(returnIds.get(0));
}
}
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
@@ -72,6 +72,17 @@ public class RayFunction {
return rayRemote;
}
/**
* @return Whether this function has a return value.
*/
public boolean hasReturn() {
if (isConstructor()) {
return true;
} else {
return !getMethod().getReturnType().equals(void.class);
}
}
@Override
public String toString() {
return executable.toString();
@@ -68,6 +68,18 @@ public abstract class ObjectStore {
return putRaw(serialize(object));
}
/**
* Serialize and put an object to the object store, with the given object id.
*
* This method is only used for testing.
*
* @param object The object to put.
* @param objectId Object id.
*/
public void put(Object object, ObjectId objectId) {
putRaw(serialize(object), objectId);
}
/**
* Get a list of raw objects from the object store.
*
@@ -156,7 +168,8 @@ public abstract class ObjectStore {
* Delete a list of objects from the object store.
*
* @param objectIds IDs of the objects to delete.
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
* @param localOnly Whether only delete the objects in local node, or all nodes in the
* cluster.
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
*/
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
@@ -1,6 +1,7 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
@@ -160,7 +161,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
Preconditions.checkState(numReturns == 1);
Preconditions.checkState(numReturns <= 1);
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.NORMAL_TASK, functionDescriptor, args)
.setNumReturns(numReturns)
.build();
@@ -185,7 +186,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(numReturns == 1);
Preconditions.checkState(numReturns <= 1);
TaskSpec.Builder builder = getTaskSpecBuilder(TaskType.ACTOR_TASK, functionDescriptor, args);
List<ObjectId> returnIds = getReturnIds(
TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
@@ -200,7 +201,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.build())
.build();
submitTaskSpec(taskSpec);
return Collections.singletonList(returnIds.get(0));
if (numReturns == 0) {
return ImmutableList.of();
} else {
return ImmutableList.of(returnIds.get(0));
}
}
public static ActorId getActorId(TaskSpec taskSpec) {
@@ -70,10 +70,11 @@ public final class TaskExecutor {
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
// Find the executable object.
RayFunction rayFunction = runtime.getFunctionManager()
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
Preconditions.checkNotNull(rayFunction);
try {
// Get method
RayFunction rayFunction = runtime.getFunctionManager()
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
@@ -100,7 +101,9 @@ public final class TaskExecutor {
// TODO (kfstorm): handle checkpoint in core worker.
maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId());
}
returnObjects.add(runtime.getObjectStore().serialize(result));
if (rayFunction.hasReturn()) {
returnObjects.add(runtime.getObjectStore().serialize(result));
}
} else {
// TODO (kfstorm): handle checkpoint in core worker.
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
@@ -110,8 +113,10 @@ public final class TaskExecutor {
} catch (Exception e) {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
returnObjects.add(runtime.getObjectStore()
.serialize(new RayTaskException("Error executing task " + taskId, e)));
if(rayFunction.hasReturn()) {
returnObjects.add(runtime.getObjectStore()
.serialize(new RayTaskException("Error executing task " + taskId, e)));
}
} else {
actorCreationException = e;
}
@@ -6,8 +6,8 @@ import java.util.List;
import org.ray.runtime.util.FileUtil;
/**
* A util class that generates `RayCall.java`,
* which provides type-safe interfaces for `Ray.call` and `Ray.createActor`.
* A util class that generates `RayCall.java`, which provides type-safe interfaces for `Ray.call`
* and `Ray.createActor`.
*/
public class RayCallGenerator extends BaseGenerator {
@@ -21,13 +21,12 @@ public class RayCallGenerator extends BaseGenerator {
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.function.RayFunc0;");
newLine("import org.ray.api.function.RayFunc1;");
newLine("import org.ray.api.function.RayFunc2;");
newLine("import org.ray.api.function.RayFunc3;");
newLine("import org.ray.api.function.RayFunc4;");
newLine("import org.ray.api.function.RayFunc5;");
newLine("import org.ray.api.function.RayFunc6;");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
newLine("import org.ray.api.function.RayFunc" + i + ";");
}
for (int i = 0; i <= MAX_PARAMETERS; i++) {
newLine("import org.ray.api.function.RayFuncVoid" + i + ";");
}
newLine("import org.ray.api.options.ActorCreationOptions;");
newLine("import org.ray.api.options.CallOptions;");
newLine("");
@@ -41,22 +40,25 @@ public class RayCallGenerator extends BaseGenerator {
newLine(1, "// Methods for remote function invocation.");
newLine(1, "// =======================================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildCalls(i, false, false, false);
buildCalls(i, false, false, true);
buildCalls(i, false, false, true, false);
buildCalls(i, false, false, true, true);
buildCalls(i, false, false, false, false);
buildCalls(i, false, false, false, true);
}
newLine(1, "// ===========================================");
newLine(1, "// Methods for remote actor method invocation.");
newLine(1, "// ===========================================");
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildCalls(i, true, false, false);
buildCalls(i, true, false, true, false);
buildCalls(i, true, false, false, false);
}
newLine(1, "// ===========================");
newLine(1, "// Methods for actor creation.");
newLine(1, "// ===========================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildCalls(i, false, true, false);
buildCalls(i, false, true, true);
buildCalls(i, false, true, true, false);
buildCalls(i, false, true, true, true);
}
newLine(1, "// ===========================");
@@ -71,7 +73,7 @@ public class RayCallGenerator extends BaseGenerator {
}
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, true, false);
buildPyCalls(i,false, true, true);
buildPyCalls(i, false, true, true);
}
newLine("}");
return sb.toString();
@@ -79,72 +81,106 @@ public class RayCallGenerator extends BaseGenerator {
/**
* Build the `Ray.call` or `Ray.createActor` methods with the given number of parameters.
*
* @param numParameters the number of parameters
* @param forActor build actor api when true, otherwise build task api.
* @param hasReturn if true, build api for functions with return.
* @param forActorCreation build `Ray.createActor` when true, otherwise build `Ray.call`.
*/
private void buildCalls(int numParameters, boolean forActor,
boolean forActorCreation, boolean hasOptionsParam) {
boolean forActorCreation, boolean hasReturn, boolean hasOptionsParam) {
// Template of the generated function:
// public static [genericTypes] [returnType] [callFunc]([argsDeclaration]) {
// Objects[] args = new Object[]{[args]};
// return Ray.internal().[callFunc](f[, actor], args[, options]);
// }
// 1) Construct the `genericTypes` part, e.g. `<T0, T1, T2, R>`.
String genericTypes = "";
String argList = "";
for (int i = 0; i < numParameters; i++) {
genericTypes += "T" + i + ", ";
argList += "t" + i + ", ";
}
if (forActor) {
// Actor generic type.
genericTypes = "A, " + genericTypes;
}
genericTypes += forActorCreation ? "A" : "R";
if (argList.endsWith(", ")) {
argList = argList.substring(0, argList.length() - 2);
// Return generic type.
if (forActorCreation) {
genericTypes += "A, ";
} else {
if (hasReturn) {
genericTypes += "R, ";
}
}
if (!genericTypes.isEmpty()) {
// Trim trailing ", ";
genericTypes = genericTypes.substring(0, genericTypes.length() - 2);
genericTypes = "<" + genericTypes + ">";
}
String paramPrefix = String.format("RayFunc%d<%s> f",
// 2) Construct the `returnType` part.
String returnType;
if (forActorCreation) {
returnType = "RayActor<A>";
} else {
returnType = hasReturn ? "RayObject<R>" : "void";
}
// 3) Construct the `argsDeclaration` part.
String argsDeclarationPrefix = String.format("RayFunc%s%d%s f, ",
hasReturn ? "" : "Void",
!forActor ? numParameters : numParameters + 1,
genericTypes);
if (forActor) {
paramPrefix += ", RayActor<A> actor";
}
if (numParameters > 0) {
paramPrefix += ", ";
argsDeclarationPrefix += "RayActor<A> actor, ";
}
String optionsParam;
if (hasOptionsParam) {
optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options";
} else {
optionsParam = "";
}
String callFunc = forActorCreation ? "createActor" : "call";
String optionsArg;
if (forActor) {
optionsArg = "";
} else {
// Enumerate all combinations of the parameters.
for (String param : generateParameters(numParameters)) {
String argsDeclaration = argsDeclarationPrefix + param;
if (hasOptionsParam) {
optionsArg = ", options";
} else {
optionsArg = ", null";
argsDeclaration +=
forActorCreation ? "ActorCreationOptions options, " : "CallOptions options, ";
}
}
String returnType = !forActorCreation ? "RayObject<R>" : "RayActor<A>";
String funcName = !forActorCreation ? "call" : "createActor";
String funcArgs = !forActor ? "f, args" : "f, actor, args";
for (String param : generateParameters(0, numParameters)) {
// Method signature.
// Trim trailing ", ";
argsDeclaration = argsDeclaration.substring(0, argsDeclaration.length() - 2);
// Print the first line (method signature).
newLine(1, String.format(
"public static <%s> %s %s(%s%s) {",
genericTypes, returnType, funcName, paramPrefix + param, optionsParam
"public static%s %s %s(%s) {",
genericTypes.isEmpty() ? "" : " " + genericTypes, returnType, callFunc, argsDeclaration
));
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
// 4) Construct the `args` part.
String args = "";
for (int i = 0; i < numParameters; i++) {
args += "t" + i + ", ";
}
// Trim trailing ", ";
if (!args.isEmpty()) {
args = args.substring(0, args.length() - 2);
}
// Print the second line (local args declaration).
newLine(2, String.format("Object[] args = new Object[]{%s};", args));
// 5) Construct the third line.
String callFuncArgs = "f, ";
if (forActor) {
callFuncArgs += "actor, ";
}
callFuncArgs += "args, ";
callFuncArgs += forActor ? "" : hasOptionsParam ? "options, " : "null, ";
callFuncArgs = callFuncArgs.substring(0, callFuncArgs.length() - 2);
newLine(2, String.format("%sRay.internal().%s(%s);",
hasReturn ? "return " : "", callFunc, callFuncArgs));
newLine(1, "}");
}
}
/**
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
*
* @param forActor build actor api when true, otherwise build task api.
* @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
*/
@@ -211,24 +247,21 @@ public class RayCallGenerator extends BaseGenerator {
newLine(1, "}");
}
private List<String> generateParameters(int from, int to) {
private List<String> generateParameters(int numParams) {
List<String> res = new ArrayList<>();
dfs(from, from, to, "", res);
dfs(0, numParams, "", res);
return res;
}
private void dfs(int pos, int from, int to, String cur, List<String> res) {
if (pos >= to) {
private void dfs(int pos, int numParams, String cur, List<String> res) {
if (pos >= numParams) {
res.add(cur);
return;
}
if (pos > from) {
cur += ", ";
}
String nextParameter = String.format("T%d t%d", pos, pos);
dfs(pos + 1, from, to, cur + nextParameter, res);
nextParameter = String.format("RayObject<T%d> t%d", pos, pos);
dfs(pos + 1, from, to, cur + nextParameter, res);
String nextParameter = String.format("T%d t%d, ", pos, pos);
dfs(pos + 1, numParams, cur + nextParameter, res);
nextParameter = String.format("RayObject<T%d> t%d, ", pos, pos);
dfs(pos + 1, numParams, cur + nextParameter, res);
}
public static void main(String[] args) throws IOException {
@@ -8,7 +8,7 @@ import org.ray.runtime.util.FileUtil;
*/
public class RayFuncGenerator extends BaseGenerator {
private String generate(int numParameters) {
private String generate(int numParameters, boolean hasReturn) {
sb = new StringBuilder();
String genericTypes = "";
@@ -20,6 +20,14 @@ public class RayFuncGenerator extends BaseGenerator {
}
paramList += String.format("T%d t%d", i, i);
}
if (hasReturn) {
genericTypes += "R, ";
}
if (!genericTypes.isEmpty()) {
// Remove trailing ", ".
genericTypes = genericTypes.substring(0, genericTypes.length() - 2);
genericTypes = "<" + genericTypes + ">";
}
newLine("// generated automatically, do not modify.");
newLine("");
@@ -32,10 +40,12 @@ public class RayFuncGenerator extends BaseGenerator {
newLine(comment);
newLine(" */");
newLine("@FunctionalInterface");
newLine(String.format("public interface RayFunc%d<%sR> extends RayFunc {",
numParameters, genericTypes));
String className = "RayFunc" + (hasReturn ? "" : "Void") + numParameters;
newLine(String.format("public interface %s%s extends %s {",
className, genericTypes, hasReturn ? "RayFunc" : "RayFuncVoid"));
newLine("");
indents(1);
newLine(String.format("R apply(%s);", paramList));
newLine(String.format("%s apply(%s) throws Exception;", hasReturn ? "R" : "void", paramList));
newLine("}");
return sb.toString();
@@ -46,8 +56,12 @@ public class RayFuncGenerator extends BaseGenerator {
+ "/api/src/main/java/org/ray/api/function/";
RayFuncGenerator generator = new RayFuncGenerator();
for (int i = 0; i <= MAX_PARAMETERS; i++) {
String content = generator.generate(i);
// Functions that have return.
String content = generator.generate(i, true);
FileUtil.overrideFile(root + "RayFunc" + i + ".java", content);
// Functions that don't have return.
content = generator.generate(i, false);
FileUtil.overrideFile(root + "RayFuncVoid" + i + ".java", content);
}
}
@@ -10,6 +10,7 @@ import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc1;
import org.ray.api.test.BaseTest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -64,7 +65,7 @@ public abstract class RayBenchmarkTest<T> extends BaseTest implements Serializab
long endTime = remoteResult.getFinishTime();
long costTime = endTime - temp.getStartTime();
counterList.add(costTime / 1000);
LOGGER.warn("{}_cost_time:{}ns",logPrefix, costTime);
LOGGER.warn("{}_cost_time:{}ns", logPrefix, costTime);
Assert.assertTrue(rayBenchmarkTest.checkResult(remoteResult.getResult()));
}
return counterList;
@@ -130,7 +131,13 @@ public abstract class RayBenchmarkTest<T> extends BaseTest implements Serializab
RayObject<List<Long>>[] rayObjects = new RayObject[clientNum];
for (int i = 0; i < clientNum; i++) {
rayObjects[i] = Ray.call(RayBenchmarkTest::singleClient, pressureTestParameter);
// Java compiler can't automatically infer the type of
// `RayBenchmarkTest::singleClient`, because `RayBenchmarkTest` is a generic class.
// It will match both `RayFunc1` and `RayFuncVoid1`. This looks like a bug or
// defect of the Java compiler.
// TODO(hchen): Figure out how to avoid manually declaring `RayFunc` type in this case.
RayFunc1<PressureTestParameter, List<Long>> func = RayBenchmarkTest::singleClient;
rayObjects[i] = Ray.call(func, pressureTestParameter);
}
for (int i = 0; i < clientNum; i++) {
List<Long> subCounterList = rayObjects[i].get();
@@ -32,7 +32,11 @@ public class ActorTest extends BaseTest {
return value;
}
public int increase(int delta) {
public void increase(int delta) {
value += delta;
}
public int increaseAndGet(int delta) {
value += delta;
return value;
}
@@ -45,7 +49,8 @@ public class ActorTest extends BaseTest {
Assert.assertNotEquals(actor.getId(), UniqueId.NIL);
// Test calling an actor
Assert.assertEquals(Integer.valueOf(1), Ray.call(Counter::getValue, actor).get());
Assert.assertEquals(Integer.valueOf(11), Ray.call(Counter::increase, actor, 10).get());
Ray.call(Counter::increase, actor, 1);
Assert.assertEquals(Integer.valueOf(3), Ray.call(Counter::increaseAndGet, actor, 1).get());
}
@RayRemote
@@ -64,19 +69,19 @@ public class ActorTest extends BaseTest {
@RayRemote
public static int testActorAsFirstParameter(RayActor<Counter> actor, int delta) {
RayObject<Integer> res = Ray.call(Counter::increase, actor, delta);
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
return res.get();
}
@RayRemote
public static int testActorAsSecondParameter(int delta, RayActor<Counter> actor) {
RayObject<Integer> res = Ray.call(Counter::increase, actor, delta);
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor, delta);
return res.get();
}
@RayRemote
public static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
RayObject<Integer> res = Ray.call(Counter::increase, actor.get(0), delta);
RayObject<Integer> res = Ray.call(Counter::increaseAndGet, actor.get(0), delta);
return res.get();
}
@@ -96,9 +101,9 @@ public class ActorTest extends BaseTest {
public void testForkingActorHandle() {
TestUtils.skipTestUnderSingleProcess();
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
Assert.assertEquals(Integer.valueOf(101), Ray.call(Counter::increase, counter, 1).get());
Assert.assertEquals(Integer.valueOf(101), Ray.call(Counter::increaseAndGet, counter, 1).get());
RayActor<Counter> counter2 = ((NativeRayActor) counter).fork();
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increase, counter2, 2).get());
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increaseAndGet, counter2, 2).get());
}
@Test
@@ -18,14 +18,10 @@ public class PlasmaStoreTest extends BaseTest {
ObjectId objectId = ObjectId.fromRandom();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
ObjectStore objectStore = runtime.getObjectStore();
objectStore.putRaw(new NativeRayObject(new byte[]{1}, null), objectId);
Assert.assertEquals(
objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0],
(byte) 1);
objectStore.putRaw(new NativeRayObject(new byte[]{2}, null), objectId);
// Putting 2 objects with duplicate ID should fail but ignored.
Assert.assertEquals(
objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0],
(byte) 1);
objectStore.put("1", objectId);
Assert.assertEquals(Ray.get(objectId), "1");
objectStore.put("2", objectId);
// Putting the second object with duplicate ID should fail but ignored.
Assert.assertEquals(Ray.get(objectId), "1");
}
}
@@ -7,6 +7,8 @@ import java.util.List;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -66,6 +68,7 @@ public class RayCallTest extends BaseTest {
}
public static class LargeObject implements Serializable {
private byte[] data = new byte[1024 * 1024];
}
@@ -74,6 +77,12 @@ public class RayCallTest extends BaseTest {
return largeObject;
}
@RayRemote
private static void testNoReturn(ObjectId objectId) {
// Put an object in object store to inform driver that this function is executing.
((AbstractRayRuntime) Ray.internal()).getObjectStore().put(1, objectId);
}
/**
* Test calling and returning different types.
*/
@@ -93,6 +102,10 @@ public class RayCallTest extends BaseTest {
Assert.assertEquals(map, Ray.call(RayCallTest::testMap, map).get());
LargeObject largeObject = new LargeObject();
Assert.assertNotNull(Ray.call(RayCallTest::testLargeObject, largeObject).get());
ObjectId randomObjectId = ObjectId.fromRandom();
Ray.call(RayCallTest::testNoReturn, randomObjectId);
Assert.assertEquals(((int) Ray.get(randomObjectId)), 1);
}
@RayRemote