mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 03:37:20 +08:00
Cross-language invocation Part 1: Java calling Python functions and actors (#4166)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -10,6 +11,7 @@ import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.RuntimeContext;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
@@ -20,12 +22,14 @@ import org.ray.api.options.BaseTaskOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
@@ -69,7 +73,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
functionManager = new FunctionManager(rayConfig.driverResourcePath);
|
||||
worker = new Worker(this);
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.driverId, rayConfig.runMode);
|
||||
rayConfig.driverId, rayConfig.runMode);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
@@ -229,7 +233,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options);
|
||||
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, options);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
@@ -242,7 +246,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
RayActorImpl<?> actorImpl = (RayActorImpl) actor;
|
||||
TaskSpec spec;
|
||||
synchronized (actor) {
|
||||
spec = createTaskSpec(func, actorImpl, args, false, null);
|
||||
spec = createTaskSpec(func, null, actorImpl, args, false, null);
|
||||
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
actorImpl.clearNewActorHandles();
|
||||
@@ -255,7 +259,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL,
|
||||
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
|
||||
args, true, options);
|
||||
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
|
||||
actor.increaseTaskCounter();
|
||||
@@ -264,17 +268,71 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return (RayActor<T>) actor;
|
||||
}
|
||||
|
||||
private void checkPyArguments(Object[] args) {
|
||||
for (Object arg : args) {
|
||||
Preconditions.checkArgument(
|
||||
(arg instanceof RayPyActor) || (arg instanceof byte[]),
|
||||
"Python argument can only be a RayPyActor or a byte array, not {}.",
|
||||
arg.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(String moduleName, String functionName, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName);
|
||||
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, options);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), functionName);
|
||||
RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor;
|
||||
TaskSpec spec;
|
||||
synchronized (pyActor) {
|
||||
spec = createTaskSpec(null, desc, actorImpl, args, false, null);
|
||||
spec.getExecutionDependencies().add(actorImpl.getTaskCursor());
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
actorImpl.clearNewActorHandles();
|
||||
}
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__");
|
||||
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, options);
|
||||
RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className);
|
||||
actor.increaseTaskCounter();
|
||||
actor.setTaskCursor(spec.returnIds[0]);
|
||||
rayletClient.submitTask(spec);
|
||||
return actor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the task specification.
|
||||
*
|
||||
* @param func The target remote function.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
|
||||
* Python task.
|
||||
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
|
||||
* @param args The arguments for the remote function.
|
||||
* @param isActorCreationTask Whether this task is an actor creation task.
|
||||
* @return A TaskSpec object.
|
||||
*/
|
||||
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl<?> actor, Object[] args,
|
||||
private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor,
|
||||
RayActorImpl<?> actor, Object[] args,
|
||||
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
|
||||
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
|
||||
|
||||
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
|
||||
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
|
||||
int numReturns = actor.getId().isNil() ? 1 : 2;
|
||||
@@ -302,7 +360,16 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
|
||||
}
|
||||
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func);
|
||||
TaskLanguage language;
|
||||
FunctionDescriptor functionDescriptor;
|
||||
if (func != null) {
|
||||
language = TaskLanguage.JAVA;
|
||||
functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func)
|
||||
.getFunctionDescriptor();
|
||||
} else {
|
||||
language = TaskLanguage.PYTHON;
|
||||
functionDescriptor = pyFunctionDescriptor;
|
||||
}
|
||||
|
||||
return new TaskSpec(
|
||||
workerContext.getCurrentDriverId(),
|
||||
@@ -315,10 +382,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
actor.getHandleId(),
|
||||
actor.increaseTaskCounter(),
|
||||
actor.getNewActorHandles().toArray(new UniqueId[0]),
|
||||
ArgumentsBuilder.wrap(args),
|
||||
ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON),
|
||||
returnIds,
|
||||
resources,
|
||||
rayFunction.getFunctionDescriptor()
|
||||
language,
|
||||
functionDescriptor
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,26 +10,32 @@ import org.ray.api.RayActor;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.util.Sha1Digestor;
|
||||
|
||||
public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
public class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
|
||||
public static final RayActorImpl NIL = new RayActorImpl();
|
||||
|
||||
private UniqueId id;
|
||||
private UniqueId handleId;
|
||||
/**
|
||||
* Id of this actor.
|
||||
*/
|
||||
protected UniqueId id;
|
||||
/**
|
||||
* Handle id of this actor.
|
||||
*/
|
||||
protected UniqueId handleId;
|
||||
/**
|
||||
* The number of tasks that have been invoked on this actor.
|
||||
*/
|
||||
private int taskCounter;
|
||||
protected int taskCounter;
|
||||
/**
|
||||
* The unique id of the last return of the last task.
|
||||
* It's used as a dependency for the next task.
|
||||
*/
|
||||
private UniqueId taskCursor;
|
||||
protected UniqueId taskCursor;
|
||||
/**
|
||||
* The number of times that this actor handle has been forked.
|
||||
* It's used to make sure ids of actor handles are unique.
|
||||
*/
|
||||
private int numForks;
|
||||
protected int numForks;
|
||||
|
||||
/**
|
||||
* The new actor handles that were created from this handle
|
||||
@@ -37,7 +43,7 @@ public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
* used to garbage-collect dummy objects that are no longer
|
||||
* necessary in the backend.
|
||||
*/
|
||||
private List<UniqueId> newActorHandles;
|
||||
protected List<UniqueId> newActorHandles;
|
||||
|
||||
public RayActorImpl() {
|
||||
this(UniqueId.NIL, UniqueId.NIL);
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
public class RayPyActorImpl extends RayActorImpl implements RayPyActor {
|
||||
|
||||
public static final RayPyActorImpl NIL = new RayPyActorImpl(UniqueId.NIL, null, null);
|
||||
|
||||
/**
|
||||
* Module name of the Python actor class.
|
||||
*/
|
||||
private String moduleName;
|
||||
|
||||
/**
|
||||
* Name of the Python actor class.
|
||||
*/
|
||||
private String className;
|
||||
|
||||
private RayPyActorImpl() {}
|
||||
|
||||
public RayPyActorImpl(UniqueId id, String moduleName, String className) {
|
||||
super(id);
|
||||
this.moduleName = moduleName;
|
||||
this.className = className;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
return moduleName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public RayPyActorImpl fork() {
|
||||
RayPyActorImpl ret = new RayPyActorImpl();
|
||||
ret.id = this.id;
|
||||
ret.taskCounter = 0;
|
||||
ret.numForks = 0;
|
||||
ret.taskCursor = this.taskCursor;
|
||||
ret.moduleName = this.moduleName;
|
||||
ret.className = this.className;
|
||||
ret.handleId = this.computeNextActorHandleId();
|
||||
newActorHandles.add(ret.handleId);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
super.writeExternal(out);
|
||||
out.writeObject(this.moduleName);
|
||||
out.writeObject(this.className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
super.readExternal(in);
|
||||
this.moduleName = (String) in.readObject();
|
||||
this.className = (String) in.readObject();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ public class Worker {
|
||||
try {
|
||||
// Get method
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(spec.driverId, spec.functionDescriptor);
|
||||
.getFunction(spec.driverId, spec.getJavaFunctionDescriptor());
|
||||
// Set context
|
||||
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
|
||||
@@ -1,52 +1,11 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
|
||||
/**
|
||||
* Represents the function's metadata.
|
||||
* Base interface of a Ray task's function descriptor.
|
||||
*
|
||||
* A function descriptor is a list of strings that can uniquely describe a function. It's used to
|
||||
* load a function in workers.
|
||||
*/
|
||||
public final class FunctionDescriptor {
|
||||
public interface FunctionDescriptor {
|
||||
|
||||
/**
|
||||
* Function's class name.
|
||||
*/
|
||||
public final String className;
|
||||
/**
|
||||
* Function's name.
|
||||
*/
|
||||
public final String name;
|
||||
/**
|
||||
* Function's type descriptor.
|
||||
*/
|
||||
public final String typeDescriptor;
|
||||
|
||||
public FunctionDescriptor(String className, String name, String typeDescriptor) {
|
||||
this.className = className;
|
||||
this.name = name;
|
||||
this.typeDescriptor = typeDescriptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return className + "." + name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
FunctionDescriptor that = (FunctionDescriptor) o;
|
||||
return Objects.equal(className, that.className) &&
|
||||
Objects.equal(name, that.name) &&
|
||||
Objects.equal(typeDescriptor, that.typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,10 +30,10 @@ public class FunctionManager {
|
||||
static final String CONSTRUCTOR_NAME = "<init>";
|
||||
|
||||
/**
|
||||
* Cache from a RayFunc object to its corresponding FunctionDescriptor. Because
|
||||
* Cache from a RayFunc object to its corresponding JavaFunctionDescriptor. Because
|
||||
* `LambdaUtils.getSerializedLambda` is expensive.
|
||||
*/
|
||||
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, FunctionDescriptor>>
|
||||
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
|
||||
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
|
||||
|
||||
/**
|
||||
@@ -64,13 +64,13 @@ public class FunctionManager {
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, RayFunc func) {
|
||||
FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
if (functionDescriptor == null) {
|
||||
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
|
||||
final String className = serializedLambda.getImplClass().replace('/', '.');
|
||||
final String methodName = serializedLambda.getImplMethodName();
|
||||
final String typeDescriptor = serializedLambda.getImplMethodSignature();
|
||||
functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor);
|
||||
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
|
||||
RAY_FUNC_CACHE.get().put(func.getClass(),functionDescriptor);
|
||||
}
|
||||
return getFunction(driverId, functionDescriptor);
|
||||
@@ -83,7 +83,7 @@ public class FunctionManager {
|
||||
* @param functionDescriptor The function descriptor.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) {
|
||||
public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) {
|
||||
DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
|
||||
if (driverFunctionTable == null) {
|
||||
String resourcePath = driverResourcePath + "/" + driverId.toString() + "/";
|
||||
@@ -122,7 +122,7 @@ public class FunctionManager {
|
||||
this.functions = new HashMap<>();
|
||||
}
|
||||
|
||||
RayFunction getFunction(FunctionDescriptor descriptor) {
|
||||
RayFunction getFunction(JavaFunctionDescriptor descriptor) {
|
||||
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
|
||||
if (classFunctions == null) {
|
||||
classFunctions = loadFunctionsForClass(descriptor.className);
|
||||
@@ -150,7 +150,7 @@ public class FunctionManager {
|
||||
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
|
||||
final String typeDescriptor = type.getDescriptor();
|
||||
RayFunction rayFunction = new RayFunction(e, classLoader,
|
||||
new FunctionDescriptor(className, methodName, typeDescriptor));
|
||||
new JavaFunctionDescriptor(className, methodName, typeDescriptor));
|
||||
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
|
||||
+52
@@ -0,0 +1,52 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
|
||||
/**
|
||||
* Represents metadata of Java function.
|
||||
*/
|
||||
public final class JavaFunctionDescriptor implements FunctionDescriptor {
|
||||
|
||||
/**
|
||||
* Function's class name.
|
||||
*/
|
||||
public final String className;
|
||||
/**
|
||||
* Function's name.
|
||||
*/
|
||||
public final String name;
|
||||
/**
|
||||
* Function's type descriptor.
|
||||
*/
|
||||
public final String typeDescriptor;
|
||||
|
||||
public JavaFunctionDescriptor(String className, String name, String typeDescriptor) {
|
||||
this.className = className;
|
||||
this.name = name;
|
||||
this.typeDescriptor = typeDescriptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return className + "." + name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
|
||||
return Objects.equal(className, that.className) &&
|
||||
Objects.equal(name, that.name) &&
|
||||
Objects.equal(typeDescriptor, that.typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
/**
|
||||
* Represents metadata of a Python function.
|
||||
*/
|
||||
public class PyFunctionDescriptor implements FunctionDescriptor {
|
||||
|
||||
public String moduleName;
|
||||
|
||||
public String className;
|
||||
|
||||
public String functionName;
|
||||
|
||||
public PyFunctionDescriptor(String moduleName, String className, String functionName) {
|
||||
this.moduleName = moduleName;
|
||||
this.className = className;
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return moduleName + "." + className + "." + functionName;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ public class RayFunction {
|
||||
/**
|
||||
* Function's metadata.
|
||||
*/
|
||||
public final FunctionDescriptor functionDescriptor;
|
||||
public final JavaFunctionDescriptor functionDescriptor;
|
||||
|
||||
public RayFunction(Executable executable, ClassLoader classLoader,
|
||||
FunctionDescriptor functionDescriptor) {
|
||||
JavaFunctionDescriptor functionDescriptor) {
|
||||
this.executable = executable;
|
||||
this.classLoader = classLoader;
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
@@ -53,7 +53,7 @@ public class RayFunction {
|
||||
return (Method) executable;
|
||||
}
|
||||
|
||||
public FunctionDescriptor getFunctionDescriptor() {
|
||||
public JavaFunctionDescriptor getFunctionDescriptor() {
|
||||
return functionDescriptor;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ public class ObjectStoreProxy {
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
|
||||
private static ThreadLocal<ObjectStoreLink> objectStore;
|
||||
@@ -83,9 +85,8 @@ public class ObjectStoreProxy {
|
||||
|
||||
GetResult<T> result;
|
||||
if (meta != null) {
|
||||
// If meta is not null, deserialize the exception.
|
||||
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
|
||||
result = new GetResult<>(true, null, exception);
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
result = deserializeFromMeta(meta, data, ids.get(i));
|
||||
} else if (data != null) {
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
@@ -112,13 +113,16 @@ public class ObjectStoreProxy {
|
||||
return results;
|
||||
}
|
||||
|
||||
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
|
||||
if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
@SuppressWarnings("unchecked")
|
||||
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) {
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return (GetResult<T>) new GetResult<>(true, data, null);
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return new GetResult<>(true, null, RayWorkerException.INSTANCE);
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
return new GetResult<>(true, null, RayActorException.INSTANCE);
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
return new GetResult<>(true, null, new UnreconstructableException(objectId));
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
}
|
||||
@@ -131,7 +135,13 @@ public class ObjectStoreProxy {
|
||||
*/
|
||||
public void put(UniqueId id, Object object) {
|
||||
try {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
|
||||
} else {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
}
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
|
||||
@@ -12,12 +12,13 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.generated.Arg;
|
||||
import org.ray.runtime.generated.Language;
|
||||
import org.ray.runtime.generated.ResourcePair;
|
||||
import org.ray.runtime.generated.TaskInfo;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.slf4j.Logger;
|
||||
@@ -183,13 +184,14 @@ public class RayletClientImpl implements RayletClient {
|
||||
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
|
||||
}
|
||||
// Deserialize function descriptor
|
||||
Preconditions.checkArgument(info.language() == Language.JAVA);
|
||||
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
|
||||
FunctionDescriptor functionDescriptor = new FunctionDescriptor(
|
||||
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
|
||||
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
|
||||
);
|
||||
return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId,
|
||||
maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles,
|
||||
args, returnIds, resources, functionDescriptor);
|
||||
args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor);
|
||||
}
|
||||
|
||||
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
|
||||
@@ -250,12 +252,29 @@ public class RayletClientImpl implements RayletClient {
|
||||
int requiredPlacementResourcesOffset =
|
||||
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
|
||||
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.functionDescriptor.className),
|
||||
fbb.createString(task.functionDescriptor.name),
|
||||
fbb.createString(task.functionDescriptor.typeDescriptor)
|
||||
};
|
||||
int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
int language;
|
||||
int functionDescriptorOffset;
|
||||
|
||||
if (task.language == TaskLanguage.JAVA) {
|
||||
// This is a Java task.
|
||||
language = Language.JAVA;
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.getJavaFunctionDescriptor().className),
|
||||
fbb.createString(task.getJavaFunctionDescriptor().name),
|
||||
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
|
||||
};
|
||||
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
} else {
|
||||
// This is a Python task.
|
||||
language = Language.PYTHON;
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.getPyFunctionDescriptor().moduleName),
|
||||
fbb.createString(task.getPyFunctionDescriptor().className),
|
||||
fbb.createString(task.getPyFunctionDescriptor().functionName),
|
||||
fbb.createString("")
|
||||
};
|
||||
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
}
|
||||
|
||||
int root = TaskInfo.createTaskInfo(
|
||||
fbb,
|
||||
@@ -274,7 +293,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
returnsOffset,
|
||||
requiredResourcesOffset,
|
||||
requiredPlacementResourcesOffset,
|
||||
Language.JAVA,
|
||||
language,
|
||||
functionDescriptorOffset);
|
||||
fbb.finish(root);
|
||||
ByteBuffer buffer = fbb.dataBuffer();
|
||||
|
||||
@@ -12,16 +12,15 @@ import org.ray.runtime.util.Serializer;
|
||||
public class ArgumentsBuilder {
|
||||
|
||||
/**
|
||||
* If the the size of an argument's serialized data is smaller than this number,
|
||||
* the argument will be passed by value. Otherwise it'll be passed by reference.
|
||||
* If the the size of an argument's serialized data is smaller than this number, the argument will
|
||||
* be passed by value. Otherwise it'll be passed by reference.
|
||||
*/
|
||||
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
|
||||
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
*/
|
||||
public static FunctionArg[] wrap(Object[] args) {
|
||||
public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) {
|
||||
FunctionArg[] ret = new FunctionArg[args.length];
|
||||
for (int i = 0; i < ret.length; i++) {
|
||||
Object arg = args[i];
|
||||
@@ -33,10 +32,15 @@ public class ArgumentsBuilder {
|
||||
data = Serializer.encode(arg);
|
||||
} else if (arg instanceof RayObject) {
|
||||
id = ((RayObject) arg).getId();
|
||||
} else if (arg instanceof byte[] && crossLanguage) {
|
||||
// If the argument is a byte array and will be used by a different language,
|
||||
// do not inline this argument. Because the other language doesn't know how
|
||||
// to deserialize it.
|
||||
id = Ray.put(arg).getId();
|
||||
} else {
|
||||
byte[] serialized = Serializer.encode(arg);
|
||||
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
|
||||
id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId();
|
||||
} else {
|
||||
data = serialized;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
/**
|
||||
* Language of a Ray task.
|
||||
*/
|
||||
public enum TaskLanguage {
|
||||
|
||||
JAVA,
|
||||
|
||||
PYTHON,
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
|
||||
/**
|
||||
* Represents necessary information of a task for scheduling and executing.
|
||||
@@ -52,9 +54,13 @@ public class TaskSpec {
|
||||
// The task's resource demands.
|
||||
public final Map<String, Double> resources;
|
||||
|
||||
// Function descriptor is a list of strings that can uniquely identify a function.
|
||||
// It will be sent to worker and used to load the target callable function.
|
||||
public final FunctionDescriptor functionDescriptor;
|
||||
// Language of this task.
|
||||
public final TaskLanguage language;
|
||||
|
||||
// Descriptor of the remote function.
|
||||
// Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language
|
||||
// is Python, the type is PyFunctionDescriptor.
|
||||
private final FunctionDescriptor functionDescriptor;
|
||||
|
||||
private List<UniqueId> executionDependencies;
|
||||
|
||||
@@ -66,10 +72,22 @@ public class TaskSpec {
|
||||
return !actorCreationId.isNil();
|
||||
}
|
||||
|
||||
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
|
||||
UniqueId actorCreationId, int maxActorReconstructions, UniqueId actorId,
|
||||
UniqueId actorHandleId, int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args,
|
||||
UniqueId[] returnIds, Map<String, Double> resources, FunctionDescriptor functionDescriptor) {
|
||||
public TaskSpec(
|
||||
UniqueId driverId,
|
||||
UniqueId taskId,
|
||||
UniqueId parentTaskId,
|
||||
int parentCounter,
|
||||
UniqueId actorCreationId,
|
||||
int maxActorReconstructions,
|
||||
UniqueId actorId,
|
||||
UniqueId actorHandleId,
|
||||
int actorCounter,
|
||||
UniqueId[] newActorHandles,
|
||||
FunctionArg[] args,
|
||||
UniqueId[] returnIds,
|
||||
Map<String, Double> resources,
|
||||
TaskLanguage language,
|
||||
FunctionDescriptor functionDescriptor) {
|
||||
this.driverId = driverId;
|
||||
this.taskId = taskId;
|
||||
this.parentTaskId = parentTaskId;
|
||||
@@ -83,10 +101,30 @@ public class TaskSpec {
|
||||
this.args = args;
|
||||
this.returnIds = returnIds;
|
||||
this.resources = resources;
|
||||
this.language = language;
|
||||
if (language == TaskLanguage.JAVA) {
|
||||
Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor,
|
||||
"Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
|
||||
} else if (language == TaskLanguage.PYTHON) {
|
||||
Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor,
|
||||
"Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
|
||||
} else {
|
||||
Preconditions.checkArgument(false, "Unknown task language: {}.", language);
|
||||
}
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
this.executionDependencies = new ArrayList<>();
|
||||
}
|
||||
|
||||
public JavaFunctionDescriptor getJavaFunctionDescriptor() {
|
||||
Preconditions.checkState(language == TaskLanguage.JAVA);
|
||||
return (JavaFunctionDescriptor) functionDescriptor;
|
||||
}
|
||||
|
||||
public PyFunctionDescriptor getPyFunctionDescriptor() {
|
||||
Preconditions.checkState(language == TaskLanguage.PYTHON);
|
||||
return (PyFunctionDescriptor) functionDescriptor;
|
||||
}
|
||||
|
||||
public List<UniqueId> getExecutionDependencies() {
|
||||
return executionDependencies;
|
||||
}
|
||||
@@ -99,13 +137,17 @@ public class TaskSpec {
|
||||
", parentTaskId=" + parentTaskId +
|
||||
", parentCounter=" + parentCounter +
|
||||
", actorCreationId=" + actorCreationId +
|
||||
", maxActorReconstructions=" + maxActorReconstructions +
|
||||
", actorId=" + actorId +
|
||||
", actorHandleId=" + actorHandleId +
|
||||
", actorCounter=" + actorCounter +
|
||||
", newActorHandles=" + Arrays.toString(newActorHandles) +
|
||||
", args=" + Arrays.toString(args) +
|
||||
", returnIds=" + Arrays.toString(returnIds) +
|
||||
", resources=" + ResourceUtil.getResourcesStringFromMap(resources) +
|
||||
", resources=" + resources +
|
||||
", language=" + language +
|
||||
", functionDescriptor=" + functionDescriptor +
|
||||
", executionDependencies=" + executionDependencies +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("");
|
||||
newLine("package org.ray.api;");
|
||||
newLine("");
|
||||
newLine("import org.ray.api.function.RayFunc;");
|
||||
newLine("import org.ray.api.function.RayFunc0;");
|
||||
newLine("import org.ray.api.function.RayFunc1;");
|
||||
newLine("import org.ray.api.function.RayFunc2;");
|
||||
@@ -30,7 +29,6 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
newLine("import org.ray.api.function.RayFunc5;");
|
||||
newLine("import org.ray.api.function.RayFunc6;");
|
||||
newLine("import org.ray.api.options.ActorCreationOptions;");
|
||||
newLine("import org.ray.api.options.BaseTaskOptions;");
|
||||
newLine("import org.ray.api.options.CallOptions;");
|
||||
newLine("");
|
||||
|
||||
@@ -46,6 +44,7 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
buildCalls(i, false, false, false);
|
||||
buildCalls(i, false, false, true);
|
||||
}
|
||||
|
||||
newLine(1, "// ===========================================");
|
||||
newLine(1, "// Methods for remote actor method invocation.");
|
||||
newLine(1, "// ===========================================");
|
||||
@@ -59,6 +58,21 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
buildCalls(i, false, true, false);
|
||||
buildCalls(i, false, true, true);
|
||||
}
|
||||
|
||||
newLine(1, "// ===========================");
|
||||
newLine(1, "// Cross-language methods.");
|
||||
newLine(1, "// ===========================");
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildPyCalls(i, false, false, false);
|
||||
buildPyCalls(i, false, false, true);
|
||||
}
|
||||
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
|
||||
buildPyCalls(i, true, false, false);
|
||||
}
|
||||
for (int i = 0; i <= MAX_PARAMETERS; i++) {
|
||||
buildPyCalls(i, false, true, false);
|
||||
buildPyCalls(i,false, true, true);
|
||||
}
|
||||
newLine("}");
|
||||
return sb.toString();
|
||||
}
|
||||
@@ -117,18 +131,86 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
String funcName = !forActorCreation ? "call" : "createActor";
|
||||
String funcArgs = !forActor ? "f, args" : "f, actor, args";
|
||||
for (String param : generateParameters(0, numParameters)) {
|
||||
// method signature
|
||||
// Method signature.
|
||||
newLine(1, String.format(
|
||||
"public static <%s> %s %s(%s%s) {",
|
||||
genericTypes, returnType, funcName, paramPrefix + param, optionsParam
|
||||
));
|
||||
// method body
|
||||
// Method body.
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
|
||||
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
|
||||
newLine(1, "}");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
|
||||
* @param forActor build actor api when true, otherwise build task api.
|
||||
* @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
|
||||
*/
|
||||
private void buildPyCalls(int numParameters, boolean forActor,
|
||||
boolean forActorCreation, boolean hasOptionsParam) {
|
||||
String argList = "";
|
||||
String paramList = "";
|
||||
for (int i = 0; i < numParameters; i++) {
|
||||
paramList += "Object obj" + i + ", ";
|
||||
argList += "obj" + i + ", ";
|
||||
}
|
||||
if (argList.endsWith(", ")) {
|
||||
argList = argList.substring(0, argList.length() - 2);
|
||||
}
|
||||
if (paramList.endsWith(", ")) {
|
||||
paramList = paramList.substring(0, paramList.length() - 2);
|
||||
}
|
||||
|
||||
String paramPrefix = "";
|
||||
String funcArgs = "";
|
||||
if (forActorCreation) {
|
||||
paramPrefix += "String moduleName, String className";
|
||||
funcArgs += "moduleName, className";
|
||||
} else if (forActor) {
|
||||
paramPrefix += "RayPyActor pyActor, String functionName";
|
||||
funcArgs += "pyActor, functionName";
|
||||
} else {
|
||||
paramPrefix += "String moduleName, String functionName";
|
||||
funcArgs += "moduleName, functionName";
|
||||
}
|
||||
if (numParameters > 0) {
|
||||
paramPrefix += ", ";
|
||||
}
|
||||
|
||||
String optionsParam;
|
||||
if (hasOptionsParam) {
|
||||
optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options";
|
||||
} else {
|
||||
optionsParam = "";
|
||||
}
|
||||
|
||||
String optionsArg;
|
||||
if (forActor) {
|
||||
optionsArg = "";
|
||||
} else {
|
||||
if (hasOptionsParam) {
|
||||
optionsArg = ", options";
|
||||
} else {
|
||||
optionsArg = ", null";
|
||||
}
|
||||
}
|
||||
|
||||
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
|
||||
String funcName = !forActorCreation ? "callPy" : "createPyActor";
|
||||
funcArgs += ", args";
|
||||
// Method signature.
|
||||
newLine(1, String.format(
|
||||
"public static %s %s(%s%s) {",
|
||||
returnType, funcName, paramPrefix + paramList, optionsParam
|
||||
));
|
||||
// Method body.
|
||||
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
|
||||
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
|
||||
newLine(1, "}");
|
||||
}
|
||||
|
||||
private List<String> generateParameters(int from, int to) {
|
||||
List<String> res = new ArrayList<>();
|
||||
dfs(from, from, to, "", res);
|
||||
@@ -155,3 +237,4 @@ public class RayCallGenerator extends BaseGenerator {
|
||||
FileUtil.overrideFile(path, new RayCallGenerator().build());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,20 +40,20 @@ public class FunctionManagerTest {
|
||||
private static RayFunc0<Object> fooFunc;
|
||||
private static RayFunc1<Bar, Object> barFunc;
|
||||
private static RayFunc0<Bar> barConstructor;
|
||||
private static FunctionDescriptor fooDescriptor;
|
||||
private static FunctionDescriptor barDescriptor;
|
||||
private static FunctionDescriptor barConstructorDescriptor;
|
||||
private static JavaFunctionDescriptor fooDescriptor;
|
||||
private static JavaFunctionDescriptor barDescriptor;
|
||||
private static JavaFunctionDescriptor barConstructorDescriptor;
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() {
|
||||
fooFunc = FunctionManagerTest::foo;
|
||||
barConstructor = Bar::new;
|
||||
barFunc = Bar::bar;
|
||||
fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
|
||||
fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
|
||||
"()Ljava/lang/Object;");
|
||||
barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar",
|
||||
barDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), "bar",
|
||||
"()Ljava/lang/Object;");
|
||||
barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(),
|
||||
barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(),
|
||||
FunctionManager.CONSTRUCTOR_NAME,
|
||||
"()V");
|
||||
}
|
||||
@@ -132,7 +132,7 @@ public class FunctionManagerTest {
|
||||
Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING);
|
||||
|
||||
final FunctionManager functionManager = new FunctionManager(resourcePath);
|
||||
FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02",
|
||||
JavaFunctionDescriptor sayHelloDescriptor = new JavaFunctionDescriptor("org.ray.exercise.Exercise02",
|
||||
"sayHello", "()Ljava/lang/String;");
|
||||
RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor);
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor);
|
||||
|
||||
Reference in New Issue
Block a user