Cross-language invocation Part 1: Java calling Python functions and actors (#4166)

This commit is contained in:
Hao Chen
2019-03-21 13:34:21 +08:00
committed by GitHub
parent 828dc08ac8
commit d03999d01e
28 changed files with 872 additions and 228 deletions
+139 -2
View File
@@ -2,7 +2,6 @@
package org.ray.api;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFunc0;
import org.ray.api.function.RayFunc1;
import org.ray.api.function.RayFunc2;
@@ -11,7 +10,6 @@ import org.ray.api.function.RayFunc4;
import org.ray.api.function.RayFunc5;
import org.ray.api.function.RayFunc6;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.BaseTaskOptions;
import org.ray.api.options.CallOptions;
/**
@@ -2312,4 +2310,143 @@ class RayCall {
Object[] args = new Object[]{t0, t1, t2, t3, t4, t5};
return Ray.internal().createActor(f, args, options);
}
// ===========================
// Cross-language methods.
// ===========================
public static RayObject callPy(String moduleName, String functionName) {
Object[] args = new Object[]{};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, CallOptions options) {
Object[] args = new Object[]{};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, CallOptions options) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, CallOptions options) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().callPy(moduleName, functionName, args, null);
}
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, CallOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().callPy(moduleName, functionName, args, options);
}
public static RayObject callPy(RayPyActor pyActor, String functionName) {
Object[] args = new Object[]{};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().callPy(pyActor, functionName, args);
}
public static RayPyActor createPyActor(String moduleName, String className) {
Object[] args = new Object[]{};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, ActorCreationOptions options) {
Object[] args = new Object[]{};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0) {
Object[] args = new Object[]{obj0};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, ActorCreationOptions options) {
Object[] args = new Object[]{obj0};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().createPyActor(moduleName, className, args, null);
}
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, ActorCreationOptions options) {
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
return Ray.internal().createPyActor(moduleName, className, args, options);
}
}
@@ -0,0 +1,18 @@
package org.ray.api;
/**
* Handle of a Python actor.
*/
public interface RayPyActor extends RayActor {
/**
* @return Module name of the Python actor class.
*/
String getModuleName();
/**
* @return Name of the Python actor class.
*/
String getClassName();
}
@@ -3,6 +3,7 @@ package org.ray.api.runtime;
import java.util.List;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.RuntimeContext;
import org.ray.api.WaitResult;
import org.ray.api.function.RayFunc;
@@ -45,8 +46,8 @@ public interface RayRuntime {
<T> List<T> get(List<UniqueId> objectIds);
/**
* Wait for a list of RayObjects to be locally available,
* until specified number of objects are ready, or specified timeout has passed.
* Wait for a list of RayObjects to be locally available, until specified number of objects are
* ready, or specified timeout has passed.
*
* @param waitList A list of RayObject to wait for.
* @param numReturns The number of objects that should be returned.
@@ -96,4 +97,37 @@ public interface RayRuntime {
ActorCreationOptions options);
RuntimeContext getRuntimeContext();
/**
* Invoke a remote Python function.
*
* @param moduleName Module name of the Python function.
* @param functionName Name of the Python function.
* @param args Arguments of the function.
* @param options The options for this call.
* @return The result object.
*/
RayObject callPy(String moduleName, String functionName, Object[] args, CallOptions options);
/**
* Invoke a remote Python function on an actor.
*
* @param pyActor A handle to the actor.
* @param functionName Name of the actor method.
* @param args Arguments of the function.
* @return The result object.
*/
RayObject callPy(RayPyActor pyActor, String functionName, Object[] args);
/**
* Create a Python actor on a remote node.
*
* @param moduleName Module name of the Python actor class.
* @param className Name of the Python actor class.
* @param args Arguments of the actor constructor.
* @param options The options for creating actor.
* @return A handle to the actor.
*/
RayPyActor createPyActor(String moduleName, String className, Object[] args,
ActorCreationOptions options);
}
-1
View File
@@ -235,7 +235,6 @@
<failsOnError>true</failsOnError>
<failOnViolation>true</failOnViolation>
<violationSeverity>warning</violationSeverity>
<format>xml</format>
<outputFile>${project.build.directory}/checkstyle-errors.xml</outputFile>
<linkXRef>false</linkXRef>
</configuration>
@@ -1,5 +1,6 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
@@ -10,6 +11,7 @@ import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.RuntimeContext;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
@@ -20,12 +22,14 @@ import org.ray.api.options.BaseTaskOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.UniqueIdUtil;
@@ -69,7 +73,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
functionManager = new FunctionManager(rayConfig.driverResourcePath);
worker = new Worker(this);
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.driverId, rayConfig.runMode);
rayConfig.driverId, rayConfig.runMode);
runtimeContext = new RuntimeContextImpl(this);
}
@@ -229,7 +233,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options);
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, options);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@@ -242,7 +246,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
RayActorImpl<?> actorImpl = (RayActorImpl) actor;
TaskSpec spec;
synchronized (actor) {
spec = createTaskSpec(func, actorImpl, args, false, null);
spec = createTaskSpec(func, null, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
actorImpl.clearNewActorHandles();
@@ -255,7 +259,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@SuppressWarnings("unchecked")
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
Object[] args, ActorCreationOptions options) {
TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL,
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
args, true, options);
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
actor.increaseTaskCounter();
@@ -264,17 +268,71 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return (RayActor<T>) actor;
}
private void checkPyArguments(Object[] args) {
for (Object arg : args) {
Preconditions.checkArgument(
(arg instanceof RayPyActor) || (arg instanceof byte[]),
"Python argument can only be a RayPyActor or a byte array, not {}.",
arg.getClass().getName());
}
}
@Override
public RayObject callPy(String moduleName, String functionName, Object[] args,
CallOptions options) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName);
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, options);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), functionName);
RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor;
TaskSpec spec;
synchronized (pyActor) {
spec = createTaskSpec(null, desc, actorImpl, args, false, null);
spec.getExecutionDependencies().add(actorImpl.getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
actorImpl.clearNewActorHandles();
}
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
ActorCreationOptions options) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__");
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, options);
RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className);
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
rayletClient.submitTask(spec);
return actor;
}
/**
* Create the task specification.
*
* @param func The target remote function.
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
* Python task.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.
* @return A TaskSpec object.
*/
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl<?> actor, Object[] args,
private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor,
RayActorImpl<?> actor, Object[] args,
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
int numReturns = actor.getId().isNil() ? 1 : 2;
@@ -302,7 +360,16 @@ public abstract class AbstractRayRuntime implements RayRuntime {
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
}
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func);
TaskLanguage language;
FunctionDescriptor functionDescriptor;
if (func != null) {
language = TaskLanguage.JAVA;
functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func)
.getFunctionDescriptor();
} else {
language = TaskLanguage.PYTHON;
functionDescriptor = pyFunctionDescriptor;
}
return new TaskSpec(
workerContext.getCurrentDriverId(),
@@ -315,10 +382,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
actor.getHandleId(),
actor.increaseTaskCounter(),
actor.getNewActorHandles().toArray(new UniqueId[0]),
ArgumentsBuilder.wrap(args),
ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON),
returnIds,
resources,
rayFunction.getFunctionDescriptor()
language,
functionDescriptor
);
}
@@ -10,26 +10,32 @@ import org.ray.api.RayActor;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.Sha1Digestor;
public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
public class RayActorImpl<T> implements RayActor<T>, Externalizable {
public static final RayActorImpl NIL = new RayActorImpl();
private UniqueId id;
private UniqueId handleId;
/**
* Id of this actor.
*/
protected UniqueId id;
/**
* Handle id of this actor.
*/
protected UniqueId handleId;
/**
* The number of tasks that have been invoked on this actor.
*/
private int taskCounter;
protected int taskCounter;
/**
* The unique id of the last return of the last task.
* It's used as a dependency for the next task.
*/
private UniqueId taskCursor;
protected UniqueId taskCursor;
/**
* The number of times that this actor handle has been forked.
* It's used to make sure ids of actor handles are unique.
*/
private int numForks;
protected int numForks;
/**
* The new actor handles that were created from this handle
@@ -37,7 +43,7 @@ public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
* used to garbage-collect dummy objects that are no longer
* necessary in the backend.
*/
private List<UniqueId> newActorHandles;
protected List<UniqueId> newActorHandles;
public RayActorImpl() {
this(UniqueId.NIL, UniqueId.NIL);
@@ -0,0 +1,69 @@
package org.ray.runtime;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import org.ray.api.RayPyActor;
import org.ray.api.id.UniqueId;
public class RayPyActorImpl extends RayActorImpl implements RayPyActor {
public static final RayPyActorImpl NIL = new RayPyActorImpl(UniqueId.NIL, null, null);
/**
* Module name of the Python actor class.
*/
private String moduleName;
/**
* Name of the Python actor class.
*/
private String className;
private RayPyActorImpl() {}
public RayPyActorImpl(UniqueId id, String moduleName, String className) {
super(id);
this.moduleName = moduleName;
this.className = className;
}
@Override
public String getModuleName() {
return moduleName;
}
@Override
public String getClassName() {
return className;
}
public RayPyActorImpl fork() {
RayPyActorImpl ret = new RayPyActorImpl();
ret.id = this.id;
ret.taskCounter = 0;
ret.numForks = 0;
ret.taskCursor = this.taskCursor;
ret.moduleName = this.moduleName;
ret.className = this.className;
ret.handleId = this.computeNextActorHandleId();
newActorHandles.add(ret.handleId);
return ret;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
super.writeExternal(out);
out.writeObject(this.moduleName);
out.writeObject(this.className);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
super.readExternal(in);
this.moduleName = (String) in.readObject();
this.className = (String) in.readObject();
}
}
@@ -85,7 +85,7 @@ public class Worker {
try {
// Get method
RayFunction rayFunction = runtime.getFunctionManager()
.getFunction(spec.driverId, spec.functionDescriptor);
.getFunction(spec.driverId, spec.getJavaFunctionDescriptor());
// Set context
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
@@ -1,52 +1,11 @@
package org.ray.runtime.functionmanager;
import com.google.common.base.Objects;
/**
* Represents the function's metadata.
* Base interface of a Ray task's function descriptor.
*
* A function descriptor is a list of strings that can uniquely describe a function. It's used to
* load a function in workers.
*/
public final class FunctionDescriptor {
public interface FunctionDescriptor {
/**
* Function's class name.
*/
public final String className;
/**
* Function's name.
*/
public final String name;
/**
* Function's type descriptor.
*/
public final String typeDescriptor;
public FunctionDescriptor(String className, String name, String typeDescriptor) {
this.className = className;
this.name = name;
this.typeDescriptor = typeDescriptor;
}
@Override
public String toString() {
return className + "." + name;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionDescriptor that = (FunctionDescriptor) o;
return Objects.equal(className, that.className) &&
Objects.equal(name, that.name) &&
Objects.equal(typeDescriptor, that.typeDescriptor);
}
@Override
public int hashCode() {
return Objects.hashCode(className, name, typeDescriptor);
}
}
@@ -30,10 +30,10 @@ public class FunctionManager {
static final String CONSTRUCTOR_NAME = "<init>";
/**
* Cache from a RayFunc object to its corresponding FunctionDescriptor. Because
* Cache from a RayFunc object to its corresponding JavaFunctionDescriptor. Because
* `LambdaUtils.getSerializedLambda` is expensive.
*/
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, FunctionDescriptor>>
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
/**
@@ -64,13 +64,13 @@ public class FunctionManager {
* @return A RayFunction object.
*/
public RayFunction getFunction(UniqueId driverId, RayFunc func) {
FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
if (functionDescriptor == null) {
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
final String className = serializedLambda.getImplClass().replace('/', '.');
final String methodName = serializedLambda.getImplMethodName();
final String typeDescriptor = serializedLambda.getImplMethodSignature();
functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor);
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
RAY_FUNC_CACHE.get().put(func.getClass(),functionDescriptor);
}
return getFunction(driverId, functionDescriptor);
@@ -83,7 +83,7 @@ public class FunctionManager {
* @param functionDescriptor The function descriptor.
* @return A RayFunction object.
*/
public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) {
public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) {
DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
if (driverFunctionTable == null) {
String resourcePath = driverResourcePath + "/" + driverId.toString() + "/";
@@ -122,7 +122,7 @@ public class FunctionManager {
this.functions = new HashMap<>();
}
RayFunction getFunction(FunctionDescriptor descriptor) {
RayFunction getFunction(JavaFunctionDescriptor descriptor) {
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
if (classFunctions == null) {
classFunctions = loadFunctionsForClass(descriptor.className);
@@ -150,7 +150,7 @@ public class FunctionManager {
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
final String typeDescriptor = type.getDescriptor();
RayFunction rayFunction = new RayFunction(e, classLoader,
new FunctionDescriptor(className, methodName, typeDescriptor));
new JavaFunctionDescriptor(className, methodName, typeDescriptor));
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
}
} catch (Exception e) {
@@ -0,0 +1,52 @@
package org.ray.runtime.functionmanager;
import com.google.common.base.Objects;
/**
* Represents metadata of Java function.
*/
public final class JavaFunctionDescriptor implements FunctionDescriptor {
/**
* Function's class name.
*/
public final String className;
/**
* Function's name.
*/
public final String name;
/**
* Function's type descriptor.
*/
public final String typeDescriptor;
public JavaFunctionDescriptor(String className, String name, String typeDescriptor) {
this.className = className;
this.name = name;
this.typeDescriptor = typeDescriptor;
}
@Override
public String toString() {
return className + "." + name;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
return Objects.equal(className, that.className) &&
Objects.equal(name, that.name) &&
Objects.equal(typeDescriptor, that.typeDescriptor);
}
@Override
public int hashCode() {
return Objects.hashCode(className, name, typeDescriptor);
}
}
@@ -0,0 +1,25 @@
package org.ray.runtime.functionmanager;
/**
* Represents metadata of a Python function.
*/
public class PyFunctionDescriptor implements FunctionDescriptor {
public String moduleName;
public String className;
public String functionName;
public PyFunctionDescriptor(String moduleName, String className, String functionName) {
this.moduleName = moduleName;
this.className = className;
this.functionName = functionName;
}
@Override
public String toString() {
return moduleName + "." + className + "." + functionName;
}
}
@@ -23,10 +23,10 @@ public class RayFunction {
/**
* Function's metadata.
*/
public final FunctionDescriptor functionDescriptor;
public final JavaFunctionDescriptor functionDescriptor;
public RayFunction(Executable executable, ClassLoader classLoader,
FunctionDescriptor functionDescriptor) {
JavaFunctionDescriptor functionDescriptor) {
this.executable = executable;
this.classLoader = classLoader;
this.functionDescriptor = functionDescriptor;
@@ -53,7 +53,7 @@ public class RayFunction {
return (Method) executable;
}
public FunctionDescriptor getFunctionDescriptor() {
public JavaFunctionDescriptor getFunctionDescriptor() {
return functionDescriptor;
}
@@ -36,6 +36,8 @@ public class ObjectStoreProxy {
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
private final AbstractRayRuntime runtime;
private static ThreadLocal<ObjectStoreLink> objectStore;
@@ -83,9 +85,8 @@ public class ObjectStoreProxy {
GetResult<T> result;
if (meta != null) {
// If meta is not null, deserialize the exception.
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
result = new GetResult<>(true, null, exception);
// If meta is not null, deserialize the object from meta.
result = deserializeFromMeta(meta, data, ids.get(i));
} else if (data != null) {
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
@@ -112,13 +113,16 @@ public class ObjectStoreProxy {
return results;
}
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return RayWorkerException.INSTANCE;
@SuppressWarnings("unchecked")
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
return (GetResult<T>) new GetResult<>(true, data, null);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return new GetResult<>(true, null, RayWorkerException.INSTANCE);
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
return RayActorException.INSTANCE;
return new GetResult<>(true, null, RayActorException.INSTANCE);
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
return new GetResult<>(true, null, new UnreconstructableException(objectId));
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}
@@ -131,7 +135,13 @@ public class ObjectStoreProxy {
*/
public void put(UniqueId id, Object object) {
try {
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
} else {
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
}
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
@@ -12,12 +12,13 @@ import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Arg;
import org.ray.runtime.generated.Language;
import org.ray.runtime.generated.ResourcePair;
import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
@@ -183,13 +184,14 @@ public class RayletClientImpl implements RayletClient {
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
}
// Deserialize function descriptor
Preconditions.checkArgument(info.language() == Language.JAVA);
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
FunctionDescriptor functionDescriptor = new FunctionDescriptor(
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
);
return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId,
maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles,
args, returnIds, resources, functionDescriptor);
args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor);
}
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
@@ -250,12 +252,29 @@ public class RayletClientImpl implements RayletClient {
int requiredPlacementResourcesOffset =
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.functionDescriptor.className),
fbb.createString(task.functionDescriptor.name),
fbb.createString(task.functionDescriptor.typeDescriptor)
};
int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
int language;
int functionDescriptorOffset;
if (task.language == TaskLanguage.JAVA) {
// This is a Java task.
language = Language.JAVA;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getJavaFunctionDescriptor().className),
fbb.createString(task.getJavaFunctionDescriptor().name),
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
} else {
// This is a Python task.
language = Language.PYTHON;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getPyFunctionDescriptor().moduleName),
fbb.createString(task.getPyFunctionDescriptor().className),
fbb.createString(task.getPyFunctionDescriptor().functionName),
fbb.createString("")
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
}
int root = TaskInfo.createTaskInfo(
fbb,
@@ -274,7 +293,7 @@ public class RayletClientImpl implements RayletClient {
returnsOffset,
requiredResourcesOffset,
requiredPlacementResourcesOffset,
Language.JAVA,
language,
functionDescriptorOffset);
fbb.finish(root);
ByteBuffer buffer = fbb.dataBuffer();
@@ -12,16 +12,15 @@ import org.ray.runtime.util.Serializer;
public class ArgumentsBuilder {
/**
* If the the size of an argument's serialized data is smaller than this number,
* the argument will be passed by value. Otherwise it'll be passed by reference.
* If the the size of an argument's serialized data is smaller than this number, the argument will
* be passed by value. Otherwise it'll be passed by reference.
*/
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
/**
* Convert real function arguments to task spec arguments.
*/
public static FunctionArg[] wrap(Object[] args) {
public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) {
FunctionArg[] ret = new FunctionArg[args.length];
for (int i = 0; i < ret.length; i++) {
Object arg = args[i];
@@ -33,10 +32,15 @@ public class ArgumentsBuilder {
data = Serializer.encode(arg);
} else if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (arg instanceof byte[] && crossLanguage) {
// If the argument is a byte array and will be used by a different language,
// do not inline this argument. Because the other language doesn't know how
// to deserialize it.
id = Ray.put(arg).getId();
} else {
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId();
} else {
data = serialized;
}
@@ -0,0 +1,11 @@
package org.ray.runtime.task;
/**
* Language of a Ray task.
*/
public enum TaskLanguage {
JAVA,
PYTHON,
}
@@ -1,12 +1,14 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
/**
* Represents necessary information of a task for scheduling and executing.
@@ -52,9 +54,13 @@ public class TaskSpec {
// The task's resource demands.
public final Map<String, Double> resources;
// Function descriptor is a list of strings that can uniquely identify a function.
// It will be sent to worker and used to load the target callable function.
public final FunctionDescriptor functionDescriptor;
// Language of this task.
public final TaskLanguage language;
// Descriptor of the remote function.
// Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language
// is Python, the type is PyFunctionDescriptor.
private final FunctionDescriptor functionDescriptor;
private List<UniqueId> executionDependencies;
@@ -66,10 +72,22 @@ public class TaskSpec {
return !actorCreationId.isNil();
}
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
UniqueId actorCreationId, int maxActorReconstructions, UniqueId actorId,
UniqueId actorHandleId, int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args,
UniqueId[] returnIds, Map<String, Double> resources, FunctionDescriptor functionDescriptor) {
public TaskSpec(
UniqueId driverId,
UniqueId taskId,
UniqueId parentTaskId,
int parentCounter,
UniqueId actorCreationId,
int maxActorReconstructions,
UniqueId actorId,
UniqueId actorHandleId,
int actorCounter,
UniqueId[] newActorHandles,
FunctionArg[] args,
UniqueId[] returnIds,
Map<String, Double> resources,
TaskLanguage language,
FunctionDescriptor functionDescriptor) {
this.driverId = driverId;
this.taskId = taskId;
this.parentTaskId = parentTaskId;
@@ -83,10 +101,30 @@ public class TaskSpec {
this.args = args;
this.returnIds = returnIds;
this.resources = resources;
this.language = language;
if (language == TaskLanguage.JAVA) {
Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor,
"Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
} else if (language == TaskLanguage.PYTHON) {
Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor,
"Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
} else {
Preconditions.checkArgument(false, "Unknown task language: {}.", language);
}
this.functionDescriptor = functionDescriptor;
this.executionDependencies = new ArrayList<>();
}
public JavaFunctionDescriptor getJavaFunctionDescriptor() {
Preconditions.checkState(language == TaskLanguage.JAVA);
return (JavaFunctionDescriptor) functionDescriptor;
}
public PyFunctionDescriptor getPyFunctionDescriptor() {
Preconditions.checkState(language == TaskLanguage.PYTHON);
return (PyFunctionDescriptor) functionDescriptor;
}
public List<UniqueId> getExecutionDependencies() {
return executionDependencies;
}
@@ -99,13 +137,17 @@ public class TaskSpec {
", parentTaskId=" + parentTaskId +
", parentCounter=" + parentCounter +
", actorCreationId=" + actorCreationId +
", maxActorReconstructions=" + maxActorReconstructions +
", actorId=" + actorId +
", actorHandleId=" + actorHandleId +
", actorCounter=" + actorCounter +
", newActorHandles=" + Arrays.toString(newActorHandles) +
", args=" + Arrays.toString(args) +
", returnIds=" + Arrays.toString(returnIds) +
", resources=" + ResourceUtil.getResourcesStringFromMap(resources) +
", resources=" + resources +
", language=" + language +
", functionDescriptor=" + functionDescriptor +
", executionDependencies=" + executionDependencies +
'}';
}
}
@@ -21,7 +21,6 @@ public class RayCallGenerator extends BaseGenerator {
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.function.RayFunc;");
newLine("import org.ray.api.function.RayFunc0;");
newLine("import org.ray.api.function.RayFunc1;");
newLine("import org.ray.api.function.RayFunc2;");
@@ -30,7 +29,6 @@ public class RayCallGenerator extends BaseGenerator {
newLine("import org.ray.api.function.RayFunc5;");
newLine("import org.ray.api.function.RayFunc6;");
newLine("import org.ray.api.options.ActorCreationOptions;");
newLine("import org.ray.api.options.BaseTaskOptions;");
newLine("import org.ray.api.options.CallOptions;");
newLine("");
@@ -46,6 +44,7 @@ public class RayCallGenerator extends BaseGenerator {
buildCalls(i, false, false, false);
buildCalls(i, false, false, true);
}
newLine(1, "// ===========================================");
newLine(1, "// Methods for remote actor method invocation.");
newLine(1, "// ===========================================");
@@ -59,6 +58,21 @@ public class RayCallGenerator extends BaseGenerator {
buildCalls(i, false, true, false);
buildCalls(i, false, true, true);
}
newLine(1, "// ===========================");
newLine(1, "// Cross-language methods.");
newLine(1, "// ===========================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, false, false);
buildPyCalls(i, false, false, true);
}
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildPyCalls(i, true, false, false);
}
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, true, false);
buildPyCalls(i,false, true, true);
}
newLine("}");
return sb.toString();
}
@@ -117,18 +131,86 @@ public class RayCallGenerator extends BaseGenerator {
String funcName = !forActorCreation ? "call" : "createActor";
String funcArgs = !forActor ? "f, args" : "f, actor, args";
for (String param : generateParameters(0, numParameters)) {
// method signature
// Method signature.
newLine(1, String.format(
"public static <%s> %s %s(%s%s) {",
genericTypes, returnType, funcName, paramPrefix + param, optionsParam
));
// method body
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
newLine(1, "}");
}
}
/**
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
* @param forActor build actor api when true, otherwise build task api.
* @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
*/
private void buildPyCalls(int numParameters, boolean forActor,
boolean forActorCreation, boolean hasOptionsParam) {
String argList = "";
String paramList = "";
for (int i = 0; i < numParameters; i++) {
paramList += "Object obj" + i + ", ";
argList += "obj" + i + ", ";
}
if (argList.endsWith(", ")) {
argList = argList.substring(0, argList.length() - 2);
}
if (paramList.endsWith(", ")) {
paramList = paramList.substring(0, paramList.length() - 2);
}
String paramPrefix = "";
String funcArgs = "";
if (forActorCreation) {
paramPrefix += "String moduleName, String className";
funcArgs += "moduleName, className";
} else if (forActor) {
paramPrefix += "RayPyActor pyActor, String functionName";
funcArgs += "pyActor, functionName";
} else {
paramPrefix += "String moduleName, String functionName";
funcArgs += "moduleName, functionName";
}
if (numParameters > 0) {
paramPrefix += ", ";
}
String optionsParam;
if (hasOptionsParam) {
optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options";
} else {
optionsParam = "";
}
String optionsArg;
if (forActor) {
optionsArg = "";
} else {
if (hasOptionsParam) {
optionsArg = ", options";
} else {
optionsArg = ", null";
}
}
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
String funcName = !forActorCreation ? "callPy" : "createPyActor";
funcArgs += ", args";
// Method signature.
newLine(1, String.format(
"public static %s %s(%s%s) {",
returnType, funcName, paramPrefix + paramList, optionsParam
));
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
newLine(1, "}");
}
private List<String> generateParameters(int from, int to) {
List<String> res = new ArrayList<>();
dfs(from, from, to, "", res);
@@ -155,3 +237,4 @@ public class RayCallGenerator extends BaseGenerator {
FileUtil.overrideFile(path, new RayCallGenerator().build());
}
}
@@ -40,20 +40,20 @@ public class FunctionManagerTest {
private static RayFunc0<Object> fooFunc;
private static RayFunc1<Bar, Object> barFunc;
private static RayFunc0<Bar> barConstructor;
private static FunctionDescriptor fooDescriptor;
private static FunctionDescriptor barDescriptor;
private static FunctionDescriptor barConstructorDescriptor;
private static JavaFunctionDescriptor fooDescriptor;
private static JavaFunctionDescriptor barDescriptor;
private static JavaFunctionDescriptor barConstructorDescriptor;
@BeforeClass
public static void beforeClass() {
fooFunc = FunctionManagerTest::foo;
barConstructor = Bar::new;
barFunc = Bar::bar;
fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
"()Ljava/lang/Object;");
barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar",
barDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), "bar",
"()Ljava/lang/Object;");
barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(),
barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(),
FunctionManager.CONSTRUCTOR_NAME,
"()V");
}
@@ -132,7 +132,7 @@ public class FunctionManagerTest {
Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING);
final FunctionManager functionManager = new FunctionManager(resourcePath);
FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02",
JavaFunctionDescriptor sayHelloDescriptor = new JavaFunctionDescriptor("org.ray.exercise.Exercise02",
"sayHello", "()Ljava/lang/String;");
RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor);
Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor);
+4 -1
View File
@@ -49,9 +49,12 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.21.0</version>
<version>3.0.0-M3</version>
<configuration>
<useFile>false</useFile>
<trimStackTrace>false</trimStackTrace>
<testSourceDirectory>${basedir}/src/main/java/</testSourceDirectory>
<testResourcesDirectory>${basedir}/src/main/resources/</testResourcesDirectory>
<testClassesDirectory>${project.build.directory}/classes/</testClassesDirectory>
</configuration>
</plugin>
@@ -0,0 +1,112 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.lang.ProcessBuilder.Redirect;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.SkipException;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
public abstract class BaseMultiLanguageTest {
private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class);
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
/**
* Execute an external command.
*
* @return Whether the command succeeded.
*/
private boolean executeCommand(List<String> command, int waitTimeoutSeconds,
Map<String, String> env) {
try {
LOGGER.info("Executing command: {}", String.join(" ", command));
ProcessBuilder processBuilder = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT)
.redirectError(Redirect.INHERIT);
for (Entry<String, String> entry : env.entrySet()) {
processBuilder.environment().put(entry.getKey(), entry.getValue());
}
Process process = processBuilder.start();
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
return process.exitValue() == 0;
} catch (Exception e) {
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
}
}
@BeforeClass
public void setUp() {
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
LOGGER.info("Skip Multi-language tests because environment variable "
+ "ENABLE_MULTI_LANGUAGE_TESTS isn't set");
throw new SkipException("Skip test.");
}
// Delete existing socket files.
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
File file = new File(socket);
if (file.exists()) {
file.delete();
}
}
// Start ray cluster.
String workerOptions =
" -classpath " + System.getProperty("java.class.path");
final List<String> startCommand = ImmutableList.of(
"ray",
"start",
"--head",
"--redis-port=6379",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
"--load-code-from-local",
"--include-java",
"--java-worker-options=" + workerOptions
);
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
throw new RuntimeException("Couldn't start ray cluster.");
}
// Connect to the cluster.
System.setProperty("ray.redis.address", "127.0.0.1:6379");
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
Ray.init();
}
/**
* @return The environment variables needed for the `ray start` command.
*/
protected Map<String, String> getRayStartEnv() {
return ImmutableMap.of();
}
@AfterClass
public void tearDown() {
// Disconnect to the cluster.
Ray.shutdown();
System.clearProperty("ray.redis.address");
System.clearProperty("ray.object-store.socket-name");
System.clearProperty("ray.raylet.socket-name");
// Stop ray cluster.
final List<String> stopCommand = ImmutableList.of(
"ray",
"stop"
);
if (!executeCommand(stopCommand, 10, ImmutableMap.of())) {
throw new RuntimeException("Couldn't stop ray cluster");
}
}
}
@@ -0,0 +1,54 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.testng.Assert;
import org.testng.annotations.Test;
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
private static final String PYTHON_MODULE = "test_cross_language_invocation";
@Override
protected Map<String, String> getRayStartEnv() {
// Delete and re-create the temp dir.
File tempDir = new File(
System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test");
FileUtils.deleteQuietly(tempDir);
tempDir.mkdirs();
tempDir.deleteOnExit();
// Write the test Python file to the temp dir.
InputStream in = CrossLanguageInvocationTest.class
.getResourceAsStream("/" + PYTHON_MODULE + ".py");
File pythonFile = new File(
tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py");
try {
FileUtils.copyInputStreamToFile(in, pythonFile);
} catch (IOException e) {
throw new RuntimeException(e);
}
return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath());
}
@Test
public void testCallingPythonFunction() {
RayObject res = Ray.callPy(PYTHON_MODULE, "py_func", "hello".getBytes());
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
}
@Test
public void testCallingPythonActor() {
RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes());
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
Assert.assertEquals(res.get(), "2".getBytes());
}
}
@@ -1,113 +1,18 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.lang.reflect.Method;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
/**
* Test starting a ray cluster with multi-language support.
*/
public class MultiLanguageClusterTest {
private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class);
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
public class MultiLanguageClusterTest extends BaseMultiLanguageTest {
@RayRemote
public static String echo(String word) {
return word;
}
/**
* Execute an external command.
*
* @return Whether the command succeeded.
*/
private boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
try {
LOGGER.info("Executing command: {}", String.join(" ", command));
Process process = new ProcessBuilder(command).inheritIO().start();
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
return process.exitValue() == 0;
} catch (Exception e) {
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
}
}
@BeforeMethod
public void setUp(Method method) {
String testName = method.getName();
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
LOGGER.info("Skip " + testName +
" because env variable ENABLE_MULTI_LANGUAGE_TESTS isn't set");
throw new SkipException("Skip test.");
}
// Delete existing socket files.
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
File file = new File(socket);
if (file.exists()) {
file.delete();
}
}
// Start ray cluster.
String testDir = System.getProperty("user.dir");
String workerOptions =
" -classpath " + String.format("%s/../../build/java/*:%s/target/*", testDir, testDir);
final List<String> startCommand = ImmutableList.of(
"ray",
"start",
"--head",
"--redis-port=6379",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
"--load-code-from-local",
"--include-java",
"--java-worker-options=" + workerOptions
);
if (!executeCommand(startCommand, 10)) {
throw new RuntimeException("Couldn't start ray cluster.");
}
// Connect to the cluster.
System.setProperty("ray.redis.address", "127.0.0.1:6379");
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
Ray.init();
}
@AfterMethod
public void tearDown() {
// Disconnect to the cluster.
Ray.shutdown();
System.clearProperty("ray.redis.address");
System.clearProperty("ray.object-store.socket-name");
System.clearProperty("ray.raylet.socket-name");
// Stop ray cluster.
final List<String> stopCommand = ImmutableList.of(
"ray",
"stop"
);
if (!executeCommand(stopCommand, 10)) {
throw new RuntimeException("Couldn't stop ray cluster");
}
}
@Test
public void testMultiLanguageCluster() {
RayObject<String> obj = Ray.call(MultiLanguageClusterTest::echo, "hello");
@@ -0,0 +1,20 @@
# This file is used by CrossLanguageInvocationTest.java to test cross-language
# invocation.
import ray
import six
@ray.remote
def py_func(value):
assert isinstance(value, bytes)
return b"Response from Python: " + value
@ray.remote
class Counter(object):
def __init__(self, value):
self.value = int(value)
def increase(self, delta):
self.value += int(delta)
return str(self.value).encode("utf-8") if six.PY3 else str(self.value)