mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 18:31:37 +08:00
Cross-language invocation Part 1: Java calling Python functions and actors (#4166)
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
package org.ray.api;
|
||||
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.ray.api.function.RayFunc1;
|
||||
import org.ray.api.function.RayFunc2;
|
||||
@@ -11,7 +10,6 @@ import org.ray.api.function.RayFunc4;
|
||||
import org.ray.api.function.RayFunc5;
|
||||
import org.ray.api.function.RayFunc6;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.BaseTaskOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
|
||||
/**
|
||||
@@ -2312,4 +2310,143 @@ class RayCall {
|
||||
Object[] args = new Object[]{t0, t1, t2, t3, t4, t5};
|
||||
return Ray.internal().createActor(f, args, options);
|
||||
}
|
||||
// ===========================
|
||||
// Cross-language methods.
|
||||
// ===========================
|
||||
public static RayObject callPy(String moduleName, String functionName) {
|
||||
Object[] args = new Object[]{};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, CallOptions options) {
|
||||
Object[] args = new Object[]{};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0) {
|
||||
Object[] args = new Object[]{obj0};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, CallOptions options) {
|
||||
Object[] args = new Object[]{obj0};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1) {
|
||||
Object[] args = new Object[]{obj0, obj1};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, CallOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, CallOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, CallOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, CallOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, null);
|
||||
}
|
||||
public static RayObject callPy(String moduleName, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, CallOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
|
||||
return Ray.internal().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
public static RayObject callPy(RayPyActor pyActor, String functionName) {
|
||||
Object[] args = new Object[]{};
|
||||
return Ray.internal().callPy(pyActor, functionName, args);
|
||||
}
|
||||
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0) {
|
||||
Object[] args = new Object[]{obj0};
|
||||
return Ray.internal().callPy(pyActor, functionName, args);
|
||||
}
|
||||
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1) {
|
||||
Object[] args = new Object[]{obj0, obj1};
|
||||
return Ray.internal().callPy(pyActor, functionName, args);
|
||||
}
|
||||
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2};
|
||||
return Ray.internal().callPy(pyActor, functionName, args);
|
||||
}
|
||||
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
|
||||
return Ray.internal().callPy(pyActor, functionName, args);
|
||||
}
|
||||
public static RayObject callPy(RayPyActor pyActor, String functionName, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
|
||||
return Ray.internal().callPy(pyActor, functionName, args);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className) {
|
||||
Object[] args = new Object[]{};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0) {
|
||||
Object[] args = new Object[]{obj0};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{obj0};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1) {
|
||||
Object[] args = new Object[]{obj0, obj1};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, null);
|
||||
}
|
||||
public static RayPyActor createPyActor(String moduleName, String className, Object obj0, Object obj1, Object obj2, Object obj3, Object obj4, Object obj5, ActorCreationOptions options) {
|
||||
Object[] args = new Object[]{obj0, obj1, obj2, obj3, obj4, obj5};
|
||||
return Ray.internal().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package org.ray.api;
|
||||
|
||||
/**
|
||||
* Handle of a Python actor.
|
||||
*/
|
||||
public interface RayPyActor extends RayActor {
|
||||
|
||||
/**
|
||||
* @return Module name of the Python actor class.
|
||||
*/
|
||||
String getModuleName();
|
||||
|
||||
/**
|
||||
* @return Name of the Python actor class.
|
||||
*/
|
||||
String getClassName();
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package org.ray.api.runtime;
|
||||
import java.util.List;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.RuntimeContext;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.function.RayFunc;
|
||||
@@ -45,8 +46,8 @@ public interface RayRuntime {
|
||||
<T> List<T> get(List<UniqueId> objectIds);
|
||||
|
||||
/**
|
||||
* Wait for a list of RayObjects to be locally available,
|
||||
* until specified number of objects are ready, or specified timeout has passed.
|
||||
* Wait for a list of RayObjects to be locally available, until specified number of objects are
|
||||
* ready, or specified timeout has passed.
|
||||
*
|
||||
* @param waitList A list of RayObject to wait for.
|
||||
* @param numReturns The number of objects that should be returned.
|
||||
@@ -96,4 +97,37 @@ public interface RayRuntime {
|
||||
ActorCreationOptions options);
|
||||
|
||||
RuntimeContext getRuntimeContext();
|
||||
|
||||
/**
|
||||
* Invoke a remote Python function.
|
||||
*
|
||||
* @param moduleName Module name of the Python function.
|
||||
* @param functionName Name of the Python function.
|
||||
* @param args Arguments of the function.
|
||||
* @param options The options for this call.
|
||||
* @return The result object.
|
||||
*/
|
||||
RayObject callPy(String moduleName, String functionName, Object[] args, CallOptions options);
|
||||
|
||||
/**
|
||||
* Invoke a remote Python function on an actor.
|
||||
*
|
||||
* @param pyActor A handle to the actor.
|
||||
* @param functionName Name of the actor method.
|
||||
* @param args Arguments of the function.
|
||||
* @return The result object.
|
||||
*/
|
||||
RayObject callPy(RayPyActor pyActor, String functionName, Object[] args);
|
||||
|
||||
/**
|
||||
* Create a Python actor on a remote node.
|
||||
*
|
||||
* @param moduleName Module name of the Python actor class.
|
||||
* @param className Name of the Python actor class.
|
||||
* @param args Arguments of the actor constructor.
|
||||
* @param options The options for creating actor.
|
||||
* @return A handle to the actor.
|
||||
*/
|
||||
RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options);
|
||||
}
|
||||
|
||||
@@ -235,7 +235,6 @@
|
||||
<failsOnError>true</failsOnError>
|
||||
<failOnViolation>true</failOnViolation>
|
||||
<violationSeverity>warning</violationSeverity>
|
||||
<format>xml</format>
|
||||
<outputFile>${project.build.directory}/checkstyle-errors.xml</outputFile>
|
||||
<linkXRef>false</linkXRef>
|
||||
</configuration>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -10,6 +11,7 @@ import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.RuntimeContext;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
@@ -20,12 +22,14 @@ import org.ray.api.options.BaseTaskOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
@@ -69,7 +73,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
functionManager = new FunctionManager(rayConfig.driverResourcePath);
|
||||
worker = new Worker(this);
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.driverId, rayConfig.runMode);
|
||||
rayConfig.driverId, rayConfig.runMode);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
@@ -229,7 +233,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options);
|
||||
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, options);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
@@ -242,7 +246,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
RayActorImpl<?> actorImpl = (RayActorImpl) actor;
|
||||
TaskSpec spec;
|
||||
synchronized (actor) {
|
||||
spec = createTaskSpec(func, actorImpl, args, false, null);
|
||||
spec = createTaskSpec(func, null, actorImpl, args, false, null);
|
||||
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
actorImpl.clearNewActorHandles();
|
||||
@@ -255,7 +259,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL,
|
||||
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
|
||||
args, true, options);
|
||||
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
|
||||
actor.increaseTaskCounter();
|
||||
@@ -264,17 +268,71 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return (RayActor<T>) actor;
|
||||
}
|
||||
|
||||
private void checkPyArguments(Object[] args) {
|
||||
for (Object arg : args) {
|
||||
Preconditions.checkArgument(
|
||||
(arg instanceof RayPyActor) || (arg instanceof byte[]),
|
||||
"Python argument can only be a RayPyActor or a byte array, not {}.",
|
||||
arg.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(String moduleName, String functionName, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName);
|
||||
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, options);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), functionName);
|
||||
RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor;
|
||||
TaskSpec spec;
|
||||
synchronized (pyActor) {
|
||||
spec = createTaskSpec(null, desc, actorImpl, args, false, null);
|
||||
spec.getExecutionDependencies().add(actorImpl.getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
actorImpl.clearNewActorHandles();
|
||||
}
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__");
|
||||
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, options);
|
||||
RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className);
|
||||
actor.increaseTaskCounter();
|
||||
actor.setTaskCursor(spec.returnIds[0]);
|
||||
rayletClient.submitTask(spec);
|
||||
return actor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the task specification.
|
||||
*
|
||||
* @param func The target remote function.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
|
||||
* Python task.
|
||||
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
|
||||
* @param args The arguments for the remote function.
|
||||
* @param isActorCreationTask Whether this task is an actor creation task.
|
||||
* @return A TaskSpec object.
|
||||
*/
|
||||
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl<?> actor, Object[] args,
|
||||
private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor,
|
||||
RayActorImpl<?> actor, Object[] args,
|
||||
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
|
||||
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
|
||||
|
||||
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
|
||||
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
|
||||
int numReturns = actor.getId().isNil() ? 1 : 2;
|
||||
@@ -302,7 +360,16 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
|
||||
}
|
||||
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func);
|
||||
TaskLanguage language;
|
||||
FunctionDescriptor functionDescriptor;
|
||||
if (func != null) {
|
||||
language = TaskLanguage.JAVA;
|
||||
functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func)
|
||||
.getFunctionDescriptor();
|
||||
} else {
|
||||
language = TaskLanguage.PYTHON;
|
||||
functionDescriptor = pyFunctionDescriptor;
|
||||
}
|
||||
|
||||
return new TaskSpec(
|
||||
workerContext.getCurrentDriverId(),
|
||||
@@ -315,10 +382,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
actor.getHandleId(),
|
||||
actor.increaseTaskCounter(),
|
||||
actor.getNewActorHandles().toArray(new UniqueId[0]),
|
||||
ArgumentsBuilder.wrap(args),
|
||||
ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON),
|
||||
returnIds,
|
||||
resources,
|
||||
rayFunction.getFunctionDescriptor()
|
||||
language,
|
||||
functionDescriptor
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,26 +10,32 @@ import org.ray.api.RayActor;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.util.Sha1Digestor;
|
||||
|
||||
public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
public class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
|
||||
public static final RayActorImpl NIL = new RayActorImpl();
|
||||
|
||||
private UniqueId id;
|
||||
private UniqueId handleId;
|
||||
/**
|
||||
* Id of this actor.
|
||||
*/
|
||||
protected UniqueId id;
|
||||
/**
|
||||
* Handle id of this actor.
|
||||
*/
|
||||
protected UniqueId handleId;
|
||||
/**
|
||||
* The number of tasks that have been invoked on this actor.
|
||||
*/
|
||||
private int taskCounter;
|
||||
protected int taskCounter;
|
||||
/**
|
||||
* The unique id of the last return of the last task.
|
||||
* It's used as a dependency for the next task.
|
||||
*/
|
||||
private UniqueId taskCursor;
|
||||
protected UniqueId taskCursor;
|
||||
/**
|
||||
* The number of times that this actor handle has been forked.
|
||||
* It's used to make sure ids of actor handles are unique.
|
||||
*/
|
||||
private int numForks;
|
||||
protected int numForks;
|
||||
|
||||
/**
|
||||
* The new actor handles that were created from this handle
|
||||
@@ -37,7 +43,7 @@ public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
* used to garbage-collect dummy objects that are no longer
|
||||
* necessary in the backend.
|
||||
*/
|
||||
private List<UniqueId> newActorHandles;
|
||||
protected List<UniqueId> newActorHandles;
|
||||
|
||||
public RayActorImpl() {
|
||||
this(UniqueId.NIL, UniqueId.NIL);
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
public class RayPyActorImpl extends RayActorImpl implements RayPyActor {
|
||||
|
||||
public static final RayPyActorImpl NIL = new RayPyActorImpl(UniqueId.NIL, null, null);
|
||||
|
||||
/**
|
||||
* Module name of the Python actor class.
|
||||
*/
|
||||
private String moduleName;
|
||||
|
||||
/**
|
||||
* Name of the Python actor class.
|
||||
*/
|
||||
private String className;
|
||||
|
||||
private RayPyActorImpl() {}
|
||||
|
||||
public RayPyActorImpl(UniqueId id, String moduleName, String className) {
|
||||
super(id);
|
||||
this.moduleName = moduleName;
|
||||
this.className = className;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
return moduleName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public RayPyActorImpl fork() {
|
||||
RayPyActorImpl ret = new RayPyActorImpl();
|
||||
ret.id = this.id;
|
||||
ret.taskCounter = 0;
|
||||
ret.numForks = 0;
|
||||
ret.taskCursor = this.taskCursor;
|
||||
ret.moduleName = this.moduleName;
|
||||
ret.className = this.className;
|
||||
ret.handleId = this.computeNextActorHandleId();
|
||||
newActorHandles.add(ret.handleId);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
super.writeExternal(out);
|
||||
out.writeObject(this.moduleName);
|
||||
out.writeObject(this.className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
super.readExternal(in);
|
||||
this.moduleName = (String) in.readObject();
|
||||
this.className = (String) in.readObject();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ public class Worker {
|
||||
try {
|
||||
// Get method
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(spec.driverId, spec.functionDescriptor);
|
||||
.getFunction(spec.driverId, spec.getJavaFunctionDescriptor());
|
||||
// Set context
|
||||
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
|
||||
@@ -1,52 +1,11 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
|
||||
/**
|
||||
* Represents the function's metadata.
|
||||
* Base interface of a Ray task's function descriptor.
|
||||
*
|
||||
* A function descriptor is a list of strings that can uniquely describe a function. It's used to
|
||||
* load a function in workers.
|
||||
*/
|
||||
public final class FunctionDescriptor {
|
||||
public interface FunctionDescriptor {
|
||||
|
||||
/**
|
||||
* Function's class name.
|
||||
*/
|
||||
public final String className;
|
||||
/**
|
||||
* Function's name.
|
||||
*/
|
||||
public final String name;
|
||||
/**
|
||||
* Function's type descriptor.
|
||||
*/
|
||||
public final String typeDescriptor;
|
||||
|
||||
public FunctionDescriptor(String className, String name, String typeDescriptor) {
|
||||
this.className = className;
|
||||
this.name = name;
|
||||
this.typeDescriptor = typeDescriptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return className + "." + name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
FunctionDescriptor that = (FunctionDescriptor) o;
|
||||
return Objects.equal(className, that.className) &&
|
||||
Objects.equal(name, that.name) &&
|
||||
Objects.equal(typeDescriptor, that.typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,10 +30,10 @@ public class FunctionManager {
|
||||
static final String CONSTRUCTOR_NAME = "<init>";
|
||||
|
||||
/**
|
||||
* Cache from a RayFunc object to its corresponding FunctionDescriptor. Because
|
||||
* Cache from a RayFunc object to its corresponding JavaFunctionDescriptor. Because
|
||||
* `LambdaUtils.getSerializedLambda` is expensive.
|
||||
*/
|
||||
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, FunctionDescriptor>>
|
||||
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
|
||||
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
|
||||
|
||||
/**
|
||||
@@ -64,13 +64,13 @@ public class FunctionManager {
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, RayFunc func) {
|
||||
FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
if (functionDescriptor == null) {
|
||||
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
|
||||
final String className = serializedLambda.getImplClass().replace('/', '.');
|
||||
final String methodName = serializedLambda.getImplMethodName();
|
||||
final String typeDescriptor = serializedLambda.getImplMethodSignature();
|
||||
functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor);
|
||||
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
|
||||
RAY_FUNC_CACHE.get().put(func.getClass(),functionDescriptor);
|
||||
}
|
||||
return getFunction(driverId, functionDescriptor);
|
||||
@@ -83,7 +83,7 @@ public class FunctionManager {
|
||||
* @param functionDescriptor The function descriptor.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) {
|
||||
public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) {
|
||||
DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
|
||||
if (driverFunctionTable == null) {
|
||||
String resourcePath = driverResourcePath + "/" + driverId.toString() + "/";
|
||||
@@ -122,7 +122,7 @@ public class FunctionManager {
|
||||
this.functions = new HashMap<>();
|
||||
}
|
||||
|
||||
RayFunction getFunction(FunctionDescriptor descriptor) {
|
||||
RayFunction getFunction(JavaFunctionDescriptor descriptor) {
|
||||
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
|
||||
if (classFunctions == null) {
|
||||
classFunctions = loadFunctionsForClass(descriptor.className);
|
||||
@@ -150,7 +150,7 @@ public class FunctionManager {
|
||||
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
|
||||
final String typeDescriptor = type.getDescriptor();
|
||||
RayFunction rayFunction = new RayFunction(e, classLoader,
|
||||
new FunctionDescriptor(className, methodName, typeDescriptor));
|
||||
new JavaFunctionDescriptor(className, methodName, typeDescriptor));
|
||||
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
|
||||
+52
@@ -0,0 +1,52 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
|
||||
/**
|
||||
* Represents metadata of Java function.
|
||||
*/
|
||||
public final class JavaFunctionDescriptor implements FunctionDescriptor {
|
||||
|
||||
/**
|
||||
* Function's class name.
|
||||
*/
|
||||
public final String className;
|
||||
/**
|
||||
* Function's name.
|
||||
*/
|
||||
public final String name;
|
||||
/**
|
||||
* Function's type descriptor.
|
||||
*/
|
||||
public final String typeDescriptor;
|
||||
|
||||
public JavaFunctionDescriptor(String className, String name, String typeDescriptor) {
|
||||
this.className = className;
|
||||
this.name = name;
|
||||
this.typeDescriptor = typeDescriptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return className + "." + name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
|
||||
return Objects.equal(className, that.className) &&
|
||||
Objects.equal(name, that.name) &&
|
||||
Objects.equal(typeDescriptor, that.typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
/**
|
||||
* Represents metadata of a Python function.
|
||||
*/
|
||||
public class PyFunctionDescriptor implements FunctionDescriptor {
|
||||
|
||||
public String moduleName;
|
||||
|
||||
public String className;
|
||||
|
||||
public String functionName;
|
||||
|
||||
public PyFunctionDescriptor(String moduleName, String className, String functionName) {
|
||||
this.moduleName = moduleName;
|
||||
this.className = className;
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return moduleName + "." + className + "." + functionName;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ public class RayFunction {
|
||||
/**
|
||||
* Function's metadata.
|
||||
*/
|
||||
public final FunctionDescriptor functionDescriptor;
|
||||
public final JavaFunctionDescriptor functionDescriptor;
|
||||
|
||||
public RayFunction(Executable executable, ClassLoader classLoader,
|
||||
FunctionDescriptor functionDescriptor) {
|
||||
JavaFunctionDescriptor functionDescriptor) {
|
||||
this.executable = executable;
|
||||
this.classLoader = classLoader;
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
@@ -53,7 +53,7 @@ public class RayFunction {
|
||||
return (Method) executable;
|
||||
}
|
||||
|
||||
public FunctionDescriptor getFunctionDescriptor() {
|
||||
public JavaFunctionDescriptor getFunctionDescriptor() {
|
||||
return functionDescriptor;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ public class ObjectStoreProxy {
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
|
||||
private static ThreadLocal<ObjectStoreLink> objectStore;
|
||||
@@ -83,9 +85,8 @@ public class ObjectStoreProxy {
|
||||
|
||||
GetResult<T> result;
|
||||
if (meta != null) {
|
||||
// If meta is not null, deserialize the exception.
|
||||
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
|
||||
result = new GetResult<>(true, null, exception);
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
result = deserializeFromMeta(meta, data, ids.get(i));
|
||||
} else if (data != null) {
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
@@ -112,13 +113,16 @@ public class ObjectStoreProxy {
|
||||
return results;
|
||||
}
|
||||
|
||||
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
|
||||
if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
@SuppressWarnings("unchecked")
|
||||
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) {
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return (GetResult<T>) new GetResult<>(true, data, null);
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return new GetResult<>(true, null, RayWorkerException.INSTANCE);
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
return new GetResult<>(true, null, RayActorException.INSTANCE);
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
return new GetResult<>(true, null, new UnreconstructableException(objectId));
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
}
|
||||
@@ -131,7 +135,13 @@ public class ObjectStoreProxy {
|
||||
*/
|
||||
public void put(UniqueId id, Object object) {
|
||||
try {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
|
||||
} else {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
}
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
|
||||
@@ -12,12 +12,13 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.generated.Arg;
|
||||
import org.ray.runtime.generated.Language;
|
||||
import org.ray.runtime.generated.ResourcePair;
|
||||
import org.ray.runtime.generated.TaskInfo;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.slf4j.Logger;
|
||||
@@ -183,13 +184,14 @@ public class RayletClientImpl implements RayletClient {
|
||||
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
|
||||
}
|
||||
// Deserialize function descriptor
|
||||
Preconditions.checkArgument(info.language() == Language.JAVA);
|
||||
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
|
||||
FunctionDescriptor functionDescriptor = new FunctionDescriptor(
|
||||
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
|
||||
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
|
||||
);
|
||||
return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId,
|
||||
maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles,
|
||||
args, returnIds, resources, functionDescriptor);
|
||||
args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor);
|
||||
}
|
||||
|
||||
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
|
||||
@@ -250,12 +252,29 @@ public class RayletClientImpl implements RayletClient {
|
||||
int requiredPlacementResourcesOffset =
|
||||
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
|
||||
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.functionDescriptor.className),
|
||||
fbb.createString(task.functionDescriptor.name),
|
||||
fbb.createString(task.functionDescriptor.typeDescriptor)
|
||||
};
|
||||
int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
int language;
|
||||
int functionDescriptorOffset;
|
||||
|
||||
if (task.language == TaskLanguage.JAVA) {
|
||||
// This is a Java task.
|
||||
language = Language.JAVA;
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.getJavaFunctionDescriptor().className),
|
||||
fbb.createString(task.getJavaFunctionDescriptor().name),
|
||||
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
|
||||
};
|
||||
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
} else {
|
||||
// This is a Python task.
|
||||
language = Language.PYTHON;
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.getPyFunctionDescriptor().moduleName),
|
||||
fbb.createString(task.getPyFunctionDescriptor().className),
|
||||
fbb.createString(task.getPyFunctionDescriptor().functionName),
|
||||
fbb.createString("")
|
||||
};
|
||||
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
}
|
||||
|
||||
int root = TaskInfo.createTaskInfo(
|
||||
fbb,
|
||||
@@ -274,7 +293,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
returnsOffset,
|
||||
requiredResourcesOffset,
|
||||
requiredPlacementResourcesOffset,
|
||||
Language.JAVA,
|
||||
language,
|
||||
functionDescriptorOffset);
|
||||
fbb.finish(root);
|
||||
ByteBuffer buffer = fbb.dataBuffer();
|
||||
|
||||
@@ -12,16 +12,15 @@ import org.ray.runtime.util.Serializer;
|
||||
public class ArgumentsBuilder {
|
||||
|
||||
/**
|
||||
* If the the size of an argument's serialized data is smaller than this number,
|
||||
* the argument will be passed by value. Otherwise it'll be passed by reference.
|
||||
* If the the size of an argument's serialized data is smaller than this number, the argument will
|
||||
* be passed by value. Otherwise it'll be passed by reference.
|
||||
*/
|
||||
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
|
||||
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
*/
|
||||
public static FunctionArg[] wrap(Object[] args) {
|
||||
public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) {
|
||||
FunctionArg[] ret = new FunctionArg[args.length];
|
||||
for (int i = 0; i < ret.length; i++) {
|
||||
Object arg = args[i];
|
||||
@@ -33,10 +32,15 @@ public class ArgumentsBuilder {
|
||||
data = Serializer.encode(arg);
|
||||
} else if (arg instanceof RayObject) {
|
||||
id = ((RayObject) arg).getId();
|
||||
} else if (arg instanceof byte[] && crossLanguage) {
|
||||
// If the argument is a byte array and will be used by a different language,
|
||||
// do not inline this argument. Because the other language doesn't know how
|
||||
// to deserialize it.
|
||||
id = Ray.put(arg).getId();
|
||||
} else {
|
||||
byte[] serialized = Serializer.encode(arg);
|
||||
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
|
||||
id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId();
|
||||
} else {
|
||||
data = serialized;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
/**
|
||||
* Language of a Ray task.
|
||||
*/
|
||||
public enum TaskLanguage {
|
||||
|
||||
JAVA,
|
||||
|
||||
PYTHON,
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
|
||||
/**
|
||||
* Represents necessary information of a task for scheduling and executing.
|
||||
@@ -52,9 +54,13 @@ public class TaskSpec {
|
||||
// The task's resource demands.
|
||||
public final Map<String, Double> resources;
|
||||
|
||||
// Function descriptor is a list of strings that can uniquely identify a function.
|
||||
// It will be sent to worker and used to load the target callable function.
|
||||
public final FunctionDescriptor functionDescriptor;
|
||||
// Language of this task.
|
||||
public final TaskLanguage language;
|
||||
|
||||
// Descriptor of the remote function.
|
||||
// Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language
|
||||
// is Python, the type is PyFunctionDescriptor.
|
||||
private final FunctionDescriptor functionDescriptor;
|
||||
|
||||
private List<UniqueId> executionDependencies;
|
||||
|
||||
@@ -66,10 +72,22 @@ public class TaskSpec {
|
||||
return !actorCreationId.isNil();
|
||||
}
|
||||
|
||||
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
|
||||
UniqueId actorCreationId, int maxActorReconstructions, UniqueId actorId,
|
||||
UniqueId actorHandleId, int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args,
|
||||
UniqueId[] returnIds, Map<String, Double> resources, FunctionDescriptor functionDescriptor) {
|
||||
public TaskSpec(
|
||||
UniqueId driverId,
|
||||
UniqueId taskId,
|
||||
UniqueId parentTaskId,
|
||||
int parentCounter,
|
||||
UniqueId actorCreationId,
|
||||
int maxActorReconstructions,
|
||||
UniqueId actorId,
|
||||
UniqueId actorHandleId,
|
||||
int actorCounter,
|
||||
UniqueId[] newActorHandles,
|
||||
FunctionArg[] args,
|
||||
UniqueId[] returnIds,
|
||||
Map<String, Double> resources,
|
||||
TaskLanguage language,
|
||||
FunctionDescriptor functionDescriptor) {
|
||||
this.driverId = driverId;
|
||||
this.taskId = taskId;
|
||||
this.parentTaskId = parentTaskId;
|
||||
@@ -83,10 +101,30 @@ public class TaskSpec {
|
||||
this.args = args;
|
||||
this.returnIds = returnIds;
|
||||
this.resources = resources;
|
||||
this.language = language;
|
||||
if (language == TaskLanguage.JAVA) {
|
||||
Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor,
|
||||
"Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
|
||||
} else if (language == TaskLanguage.PYTHON) {
|
||||
Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor,
|
||||
"Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
|
||||
} else {
|
||||
Preconditions.checkArgument(false, "Unknown task language: {}.", language);
|
||||
}
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
this.executionDependencies = new ArrayList<>();
|
||||
}
|
||||
|
||||
public JavaFunctionDescriptor getJavaFunctionDescriptor() {
|
||||
Preconditions.checkState(language == TaskLanguage.JAVA);
|
||||
return (JavaFunctionDescriptor) functionDescriptor;
|
||||
}
|
||||
|
||||
public PyFunctionDescriptor getPyFunctionDescriptor() {
|
||||
Preconditions.checkState(language == TaskLanguage.PYTHON);
|
||||
return (PyFunctionDescriptor) functionDescriptor;
|
||||
}
|
||||
|
||||
public List<UniqueId> getExecutionDependencies() {
|
||||
return executionDependencies;
|
||||
}
|
||||
@@ -99,13 +137,17 @@ public class TaskSpec {
|
||||
", parentTaskId=" + parentTaskId +
|
||||
", parentCounter=" + parentCounter +
|
||||
", actorCreationId=" + actorCreationId +
|
||||
", maxActorReconstructions=" + maxActorReconstructions +
|
||||
", actorId=" + actorId +
|
||||
", actorHandleId=" + actorHandleId +
|
||||
", actorCounter=" + actorCounter +
|
||||
", newActorHandles=" + Arrays.toString(newActorHandles) +
|
||||
", args=" + Arrays.toString(args) +
|
||||
", returnIds=" + Arrays.toString(returnIds) +
|
||||
", resources=" + ResourceUtil.getResourcesStringFromMap(resources) +
|
||||
", resources=" + resources +
|
||||
", language=" + language +
|
||||
", functionDescriptor=" + functionDescriptor +
|
||||
", executionDependencies=" + executionDependencies +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("");
|
||||
newLine("package org.ray.api;");
|
||||
newLine("");
|
||||
newLine("import org.ray.api.function.RayFunc;");
|
||||
newLine("import org.ray.api.function.RayFunc0;");
|
||||
newLine("import org.ray.api.function.RayFunc1;");
|
||||
newLine("import org.ray.api.function.RayFunc2;");
|
||||
@@ -30,7 +29,6 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("import org.ray.api.function.RayFunc5;");
|
||||
newLine("import org.ray.api.function.RayFunc6;");
|
||||
newLine("import org.ray.api.options.ActorCreationOptions;");
|
||||
newLine("import org.ray.api.options.BaseTaskOptions;");
|
||||
newLine("import org.ray.api.options.CallOptions;");
|
||||
newLine("");
|
||||
|
||||
@@ -46,6 +44,7 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
buildCalls(i, false, false, false);
|
||||
buildCalls(i, false, false, true);
|
||||
}
|
||||
|
||||
newLine(1, "// ===========================================");
|
||||
newLine(1, "// Methods for remote actor method invocation.");
|
||||
newLine(1, "// ===========================================");
|
||||
@@ -59,6 +58,21 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
buildCalls(i, false, true, false);
|
||||
buildCalls(i, false, true, true);
|
||||
}
|
||||
|
||||
newLine(1, "// ===========================");
|
||||
newLine(1, "// Cross-language methods.");
|
||||
newLine(1, "// ===========================");
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildPyCalls(i, false, false, false);
|
||||
buildPyCalls(i, false, false, true);
|
||||
}
|
||||
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
|
||||
buildPyCalls(i, true, false, false);
|
||||
}
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildPyCalls(i, false, true, false);
|
||||
buildPyCalls(i,false, true, true);
|
||||
}
|
||||
newLine("}");
|
||||
return sb.toString();
|
||||
}
|
||||
@@ -117,18 +131,86 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
String funcName = !forActorCreation ? "call" : "createActor";
|
||||
String funcArgs = !forActor ? "f, args" : "f, actor, args";
|
||||
for (String param : generateParameters(0, numParameters)) {
|
||||
// method signature
|
||||
// Method signature.
|
||||
newLine(1, String.format(
|
||||
"public static <%s> %s %s(%s%s) {",
|
||||
genericTypes, returnType, funcName, paramPrefix + param, optionsParam
|
||||
));
|
||||
// method body
|
||||
// Method body.
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
|
||||
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
|
||||
newLine(1, "}");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
|
||||
* @param forActor build actor api when true, otherwise build task api.
|
||||
* @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
|
||||
*/
|
||||
private void buildPyCalls(int numParameters, boolean forActor,
|
||||
boolean forActorCreation, boolean hasOptionsParam) {
|
||||
String argList = "";
|
||||
String paramList = "";
|
||||
for (int i = 0; i < numParameters; i++) {
|
||||
paramList += "Object obj" + i + ", ";
|
||||
argList += "obj" + i + ", ";
|
||||
}
|
||||
if (argList.endsWith(", ")) {
|
||||
argList = argList.substring(0, argList.length() - 2);
|
||||
}
|
||||
if (paramList.endsWith(", ")) {
|
||||
paramList = paramList.substring(0, paramList.length() - 2);
|
||||
}
|
||||
|
||||
String paramPrefix = "";
|
||||
String funcArgs = "";
|
||||
if (forActorCreation) {
|
||||
paramPrefix += "String moduleName, String className";
|
||||
funcArgs += "moduleName, className";
|
||||
} else if (forActor) {
|
||||
paramPrefix += "RayPyActor pyActor, String functionName";
|
||||
funcArgs += "pyActor, functionName";
|
||||
} else {
|
||||
paramPrefix += "String moduleName, String functionName";
|
||||
funcArgs += "moduleName, functionName";
|
||||
}
|
||||
if (numParameters > 0) {
|
||||
paramPrefix += ", ";
|
||||
}
|
||||
|
||||
String optionsParam;
|
||||
if (hasOptionsParam) {
|
||||
optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options";
|
||||
} else {
|
||||
optionsParam = "";
|
||||
}
|
||||
|
||||
String optionsArg;
|
||||
if (forActor) {
|
||||
optionsArg = "";
|
||||
} else {
|
||||
if (hasOptionsParam) {
|
||||
optionsArg = ", options";
|
||||
} else {
|
||||
optionsArg = ", null";
|
||||
}
|
||||
}
|
||||
|
||||
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
|
||||
String funcName = !forActorCreation ? "callPy" : "createPyActor";
|
||||
funcArgs += ", args";
|
||||
// Method signature.
|
||||
newLine(1, String.format(
|
||||
"public static %s %s(%s%s) {",
|
||||
returnType, funcName, paramPrefix + paramList, optionsParam
|
||||
));
|
||||
// Method body.
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
|
||||
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
|
||||
newLine(1, "}");
|
||||
}
|
||||
|
||||
private List<String> generateParameters(int from, int to) {
|
||||
List<String> res = new ArrayList<>();
|
||||
dfs(from, from, to, "", res);
|
||||
@@ -155,3 +237,4 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
FileUtil.overrideFile(path, new RayCallGenerator().build());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,20 +40,20 @@ public class FunctionManagerTest {
|
||||
private static RayFunc0<Object> fooFunc;
|
||||
private static RayFunc1<Bar, Object> barFunc;
|
||||
private static RayFunc0<Bar> barConstructor;
|
||||
private static FunctionDescriptor fooDescriptor;
|
||||
private static FunctionDescriptor barDescriptor;
|
||||
private static FunctionDescriptor barConstructorDescriptor;
|
||||
private static JavaFunctionDescriptor fooDescriptor;
|
||||
private static JavaFunctionDescriptor barDescriptor;
|
||||
private static JavaFunctionDescriptor barConstructorDescriptor;
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() {
|
||||
fooFunc = FunctionManagerTest::foo;
|
||||
barConstructor = Bar::new;
|
||||
barFunc = Bar::bar;
|
||||
fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
|
||||
fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
|
||||
"()Ljava/lang/Object;");
|
||||
barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar",
|
||||
barDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), "bar",
|
||||
"()Ljava/lang/Object;");
|
||||
barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(),
|
||||
barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(),
|
||||
FunctionManager.CONSTRUCTOR_NAME,
|
||||
"()V");
|
||||
}
|
||||
@@ -132,7 +132,7 @@ public class FunctionManagerTest {
|
||||
Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING);
|
||||
|
||||
final FunctionManager functionManager = new FunctionManager(resourcePath);
|
||||
FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02",
|
||||
JavaFunctionDescriptor sayHelloDescriptor = new JavaFunctionDescriptor("org.ray.exercise.Exercise02",
|
||||
"sayHello", "()Ljava/lang/String;");
|
||||
RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor);
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor);
|
||||
|
||||
+4
-1
@@ -49,9 +49,12 @@
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>2.21.0</version>
|
||||
<version>3.0.0-M3</version>
|
||||
<configuration>
|
||||
<useFile>false</useFile>
|
||||
<trimStackTrace>false</trimStackTrace>
|
||||
<testSourceDirectory>${basedir}/src/main/java/</testSourceDirectory>
|
||||
<testResourcesDirectory>${basedir}/src/main/resources/</testResourcesDirectory>
|
||||
<testClassesDirectory>${project.build.directory}/classes/</testClassesDirectory>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import java.io.File;
|
||||
import java.lang.ProcessBuilder.Redirect;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.ray.api.Ray;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.SkipException;
|
||||
import org.testng.annotations.AfterClass;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
|
||||
public abstract class BaseMultiLanguageTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class);
|
||||
|
||||
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
|
||||
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
|
||||
|
||||
/**
|
||||
* Execute an external command.
|
||||
*
|
||||
* @return Whether the command succeeded.
|
||||
*/
|
||||
private boolean executeCommand(List<String> command, int waitTimeoutSeconds,
|
||||
Map<String, String> env) {
|
||||
try {
|
||||
LOGGER.info("Executing command: {}", String.join(" ", command));
|
||||
ProcessBuilder processBuilder = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT)
|
||||
.redirectError(Redirect.INHERIT);
|
||||
for (Entry<String, String> entry : env.entrySet()) {
|
||||
processBuilder.environment().put(entry.getKey(), entry.getValue());
|
||||
}
|
||||
Process process = processBuilder.start();
|
||||
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
|
||||
return process.exitValue() == 0;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
|
||||
LOGGER.info("Skip Multi-language tests because environment variable "
|
||||
+ "ENABLE_MULTI_LANGUAGE_TESTS isn't set");
|
||||
throw new SkipException("Skip test.");
|
||||
}
|
||||
|
||||
// Delete existing socket files.
|
||||
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
|
||||
File file = new File(socket);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
// Start ray cluster.
|
||||
String workerOptions =
|
||||
" -classpath " + System.getProperty("java.class.path");
|
||||
final List<String> startCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"start",
|
||||
"--head",
|
||||
"--redis-port=6379",
|
||||
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
|
||||
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
|
||||
"--load-code-from-local",
|
||||
"--include-java",
|
||||
"--java-worker-options=" + workerOptions
|
||||
);
|
||||
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
|
||||
throw new RuntimeException("Couldn't start ray cluster.");
|
||||
}
|
||||
|
||||
// Connect to the cluster.
|
||||
System.setProperty("ray.redis.address", "127.0.0.1:6379");
|
||||
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
|
||||
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
|
||||
Ray.init();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The environment variables needed for the `ray start` command.
|
||||
*/
|
||||
protected Map<String, String> getRayStartEnv() {
|
||||
return ImmutableMap.of();
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public void tearDown() {
|
||||
// Disconnect to the cluster.
|
||||
Ray.shutdown();
|
||||
System.clearProperty("ray.redis.address");
|
||||
System.clearProperty("ray.object-store.socket-name");
|
||||
System.clearProperty("ray.raylet.socket-name");
|
||||
|
||||
// Stop ray cluster.
|
||||
final List<String> stopCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"stop"
|
||||
);
|
||||
if (!executeCommand(stopCommand, 10, ImmutableMap.of())) {
|
||||
throw new RuntimeException("Couldn't stop ray cluster");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
||||
|
||||
private static final String PYTHON_MODULE = "test_cross_language_invocation";
|
||||
|
||||
@Override
|
||||
protected Map<String, String> getRayStartEnv() {
|
||||
// Delete and re-create the temp dir.
|
||||
File tempDir = new File(
|
||||
System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test");
|
||||
FileUtils.deleteQuietly(tempDir);
|
||||
tempDir.mkdirs();
|
||||
tempDir.deleteOnExit();
|
||||
|
||||
// Write the test Python file to the temp dir.
|
||||
InputStream in = CrossLanguageInvocationTest.class
|
||||
.getResourceAsStream("/" + PYTHON_MODULE + ".py");
|
||||
File pythonFile = new File(
|
||||
tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py");
|
||||
try {
|
||||
FileUtils.copyInputStreamToFile(in, pythonFile);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCallingPythonFunction() {
|
||||
RayObject res = Ray.callPy(PYTHON_MODULE, "py_func", "hello".getBytes());
|
||||
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCallingPythonActor() {
|
||||
RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes());
|
||||
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
|
||||
Assert.assertEquals(res.get(), "2".getBytes());
|
||||
}
|
||||
}
|
||||
@@ -1,113 +1,18 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.File;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
import org.testng.SkipException;
|
||||
import org.testng.annotations.AfterMethod;
|
||||
import org.testng.annotations.BeforeMethod;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
/**
|
||||
* Test starting a ray cluster with multi-language support.
|
||||
*/
|
||||
public class MultiLanguageClusterTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class);
|
||||
|
||||
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
|
||||
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
|
||||
public class MultiLanguageClusterTest extends BaseMultiLanguageTest {
|
||||
|
||||
@RayRemote
|
||||
public static String echo(String word) {
|
||||
return word;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute an external command.
|
||||
*
|
||||
* @return Whether the command succeeded.
|
||||
*/
|
||||
private boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
|
||||
try {
|
||||
LOGGER.info("Executing command: {}", String.join(" ", command));
|
||||
Process process = new ProcessBuilder(command).inheritIO().start();
|
||||
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
|
||||
return process.exitValue() == 0;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeMethod
|
||||
public void setUp(Method method) {
|
||||
String testName = method.getName();
|
||||
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
|
||||
LOGGER.info("Skip " + testName +
|
||||
" because env variable ENABLE_MULTI_LANGUAGE_TESTS isn't set");
|
||||
throw new SkipException("Skip test.");
|
||||
}
|
||||
|
||||
// Delete existing socket files.
|
||||
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
|
||||
File file = new File(socket);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
// Start ray cluster.
|
||||
String testDir = System.getProperty("user.dir");
|
||||
String workerOptions =
|
||||
" -classpath " + String.format("%s/../../build/java/*:%s/target/*", testDir, testDir);
|
||||
final List<String> startCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"start",
|
||||
"--head",
|
||||
"--redis-port=6379",
|
||||
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
|
||||
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
|
||||
"--load-code-from-local",
|
||||
"--include-java",
|
||||
"--java-worker-options=" + workerOptions
|
||||
);
|
||||
if (!executeCommand(startCommand, 10)) {
|
||||
throw new RuntimeException("Couldn't start ray cluster.");
|
||||
}
|
||||
|
||||
// Connect to the cluster.
|
||||
System.setProperty("ray.redis.address", "127.0.0.1:6379");
|
||||
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
|
||||
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
|
||||
Ray.init();
|
||||
}
|
||||
|
||||
@AfterMethod
|
||||
public void tearDown() {
|
||||
// Disconnect to the cluster.
|
||||
Ray.shutdown();
|
||||
System.clearProperty("ray.redis.address");
|
||||
System.clearProperty("ray.object-store.socket-name");
|
||||
System.clearProperty("ray.raylet.socket-name");
|
||||
|
||||
// Stop ray cluster.
|
||||
final List<String> stopCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"stop"
|
||||
);
|
||||
if (!executeCommand(stopCommand, 10)) {
|
||||
throw new RuntimeException("Couldn't stop ray cluster");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiLanguageCluster() {
|
||||
RayObject<String> obj = Ray.call(MultiLanguageClusterTest::echo, "hello");
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# This file is used by CrossLanguageInvocationTest.java to test cross-language
|
||||
# invocation.
|
||||
import ray
|
||||
import six
|
||||
|
||||
|
||||
@ray.remote
|
||||
def py_func(value):
|
||||
assert isinstance(value, bytes)
|
||||
return b"Response from Python: " + value
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Counter(object):
|
||||
def __init__(self, value):
|
||||
self.value = int(value)
|
||||
|
||||
def increase(self, delta):
|
||||
self.value += int(delta)
|
||||
return str(self.value).encode("utf-8") if six.PY3 else str(self.value)
|
||||
Reference in New Issue
Block a user