Cross-language invocation Part 1: Java calling Python functions and actors (#4166)

This commit is contained in:
Hao Chen
2019-03-21 13:34:21 +08:00
committed by GitHub
parent 828dc08ac8
commit d03999d01e
28 changed files with 872 additions and 228 deletions
@@ -1,5 +1,6 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
@@ -10,6 +11,7 @@ import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.RuntimeContext;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
@@ -20,12 +22,14 @@ import org.ray.api.options.BaseTaskOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.UniqueIdUtil;
@@ -69,7 +73,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
functionManager = new FunctionManager(rayConfig.driverResourcePath);
worker = new Worker(this);
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.driverId, rayConfig.runMode);
rayConfig.driverId, rayConfig.runMode);
runtimeContext = new RuntimeContextImpl(this);
}
@@ -229,7 +233,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options);
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, options);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@@ -242,7 +246,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
RayActorImpl<?> actorImpl = (RayActorImpl) actor;
TaskSpec spec;
synchronized (actor) {
spec = createTaskSpec(func, actorImpl, args, false, null);
spec = createTaskSpec(func, null, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
actorImpl.clearNewActorHandles();
@@ -255,7 +259,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@SuppressWarnings("unchecked")
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
Object[] args, ActorCreationOptions options) {
TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL,
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
args, true, options);
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
actor.increaseTaskCounter();
@@ -264,17 +268,71 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return (RayActor<T>) actor;
}
private void checkPyArguments(Object[] args) {
for (Object arg : args) {
Preconditions.checkArgument(
(arg instanceof RayPyActor) || (arg instanceof byte[]),
"Python argument can only be a RayPyActor or a byte array, not {}.",
arg.getClass().getName());
}
}
@Override
public RayObject callPy(String moduleName, String functionName, Object[] args,
CallOptions options) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName);
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, options);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), functionName);
RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor;
TaskSpec spec;
synchronized (pyActor) {
spec = createTaskSpec(null, desc, actorImpl, args, false, null);
spec.getExecutionDependencies().add(actorImpl.getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
actorImpl.clearNewActorHandles();
}
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
ActorCreationOptions options) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__");
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, options);
RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className);
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
rayletClient.submitTask(spec);
return actor;
}
/**
* Create the task specification.
*
* @param func The target remote function.
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
* Python task.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.
* @return A TaskSpec object.
*/
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl<?> actor, Object[] args,
private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor,
RayActorImpl<?> actor, Object[] args,
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
int numReturns = actor.getId().isNil() ? 1 : 2;
@@ -302,7 +360,16 @@ public abstract class AbstractRayRuntime implements RayRuntime {
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
}
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func);
TaskLanguage language;
FunctionDescriptor functionDescriptor;
if (func != null) {
language = TaskLanguage.JAVA;
functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func)
.getFunctionDescriptor();
} else {
language = TaskLanguage.PYTHON;
functionDescriptor = pyFunctionDescriptor;
}
return new TaskSpec(
workerContext.getCurrentDriverId(),
@@ -315,10 +382,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
actor.getHandleId(),
actor.increaseTaskCounter(),
actor.getNewActorHandles().toArray(new UniqueId[0]),
ArgumentsBuilder.wrap(args),
ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON),
returnIds,
resources,
rayFunction.getFunctionDescriptor()
language,
functionDescriptor
);
}
@@ -10,26 +10,32 @@ import org.ray.api.RayActor;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.Sha1Digestor;
public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
public class RayActorImpl<T> implements RayActor<T>, Externalizable {
public static final RayActorImpl NIL = new RayActorImpl();
private UniqueId id;
private UniqueId handleId;
/**
* Id of this actor.
*/
protected UniqueId id;
/**
* Handle id of this actor.
*/
protected UniqueId handleId;
/**
* The number of tasks that have been invoked on this actor.
*/
private int taskCounter;
protected int taskCounter;
/**
* The unique id of the last return of the last task.
* It's used as a dependency for the next task.
*/
private UniqueId taskCursor;
protected UniqueId taskCursor;
/**
* The number of times that this actor handle has been forked.
* It's used to make sure ids of actor handles are unique.
*/
private int numForks;
protected int numForks;
/**
* The new actor handles that were created from this handle
@@ -37,7 +43,7 @@ public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
* used to garbage-collect dummy objects that are no longer
* necessary in the backend.
*/
private List<UniqueId> newActorHandles;
protected List<UniqueId> newActorHandles;
public RayActorImpl() {
this(UniqueId.NIL, UniqueId.NIL);
@@ -0,0 +1,69 @@
package org.ray.runtime;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import org.ray.api.RayPyActor;
import org.ray.api.id.UniqueId;
public class RayPyActorImpl extends RayActorImpl implements RayPyActor {
public static final RayPyActorImpl NIL = new RayPyActorImpl(UniqueId.NIL, null, null);
/**
* Module name of the Python actor class.
*/
private String moduleName;
/**
* Name of the Python actor class.
*/
private String className;
private RayPyActorImpl() {}
public RayPyActorImpl(UniqueId id, String moduleName, String className) {
super(id);
this.moduleName = moduleName;
this.className = className;
}
@Override
public String getModuleName() {
return moduleName;
}
@Override
public String getClassName() {
return className;
}
public RayPyActorImpl fork() {
RayPyActorImpl ret = new RayPyActorImpl();
ret.id = this.id;
ret.taskCounter = 0;
ret.numForks = 0;
ret.taskCursor = this.taskCursor;
ret.moduleName = this.moduleName;
ret.className = this.className;
ret.handleId = this.computeNextActorHandleId();
newActorHandles.add(ret.handleId);
return ret;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
super.writeExternal(out);
out.writeObject(this.moduleName);
out.writeObject(this.className);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
super.readExternal(in);
this.moduleName = (String) in.readObject();
this.className = (String) in.readObject();
}
}
@@ -85,7 +85,7 @@ public class Worker {
try {
// Get method
RayFunction rayFunction = runtime.getFunctionManager()
.getFunction(spec.driverId, spec.functionDescriptor);
.getFunction(spec.driverId, spec.getJavaFunctionDescriptor());
// Set context
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
@@ -1,52 +1,11 @@
package org.ray.runtime.functionmanager;
import com.google.common.base.Objects;
/**
* Represents the function's metadata.
* Base interface of a Ray task's function descriptor.
*
* A function descriptor is a list of strings that can uniquely describe a function. It's used to
* load a function in workers.
*/
public final class FunctionDescriptor {
public interface FunctionDescriptor {
/**
* Function's class name.
*/
public final String className;
/**
* Function's name.
*/
public final String name;
/**
* Function's type descriptor.
*/
public final String typeDescriptor;
public FunctionDescriptor(String className, String name, String typeDescriptor) {
this.className = className;
this.name = name;
this.typeDescriptor = typeDescriptor;
}
@Override
public String toString() {
return className + "." + name;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionDescriptor that = (FunctionDescriptor) o;
return Objects.equal(className, that.className) &&
Objects.equal(name, that.name) &&
Objects.equal(typeDescriptor, that.typeDescriptor);
}
@Override
public int hashCode() {
return Objects.hashCode(className, name, typeDescriptor);
}
}
@@ -30,10 +30,10 @@ public class FunctionManager {
static final String CONSTRUCTOR_NAME = "<init>";
/**
* Cache from a RayFunc object to its corresponding FunctionDescriptor. Because
* Cache from a RayFunc object to its corresponding JavaFunctionDescriptor. Because
* `LambdaUtils.getSerializedLambda` is expensive.
*/
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, FunctionDescriptor>>
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
/**
@@ -64,13 +64,13 @@ public class FunctionManager {
* @return A RayFunction object.
*/
public RayFunction getFunction(UniqueId driverId, RayFunc func) {
FunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
if (functionDescriptor == null) {
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
final String className = serializedLambda.getImplClass().replace('/', '.');
final String methodName = serializedLambda.getImplMethodName();
final String typeDescriptor = serializedLambda.getImplMethodSignature();
functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor);
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
RAY_FUNC_CACHE.get().put(func.getClass(),functionDescriptor);
}
return getFunction(driverId, functionDescriptor);
@@ -83,7 +83,7 @@ public class FunctionManager {
* @param functionDescriptor The function descriptor.
* @return A RayFunction object.
*/
public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) {
public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) {
DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
if (driverFunctionTable == null) {
String resourcePath = driverResourcePath + "/" + driverId.toString() + "/";
@@ -122,7 +122,7 @@ public class FunctionManager {
this.functions = new HashMap<>();
}
RayFunction getFunction(FunctionDescriptor descriptor) {
RayFunction getFunction(JavaFunctionDescriptor descriptor) {
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
if (classFunctions == null) {
classFunctions = loadFunctionsForClass(descriptor.className);
@@ -150,7 +150,7 @@ public class FunctionManager {
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
final String typeDescriptor = type.getDescriptor();
RayFunction rayFunction = new RayFunction(e, classLoader,
new FunctionDescriptor(className, methodName, typeDescriptor));
new JavaFunctionDescriptor(className, methodName, typeDescriptor));
map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction);
}
} catch (Exception e) {
@@ -0,0 +1,52 @@
package org.ray.runtime.functionmanager;
import com.google.common.base.Objects;
/**
* Represents metadata of Java function.
*/
public final class JavaFunctionDescriptor implements FunctionDescriptor {
/**
* Function's class name.
*/
public final String className;
/**
* Function's name.
*/
public final String name;
/**
* Function's type descriptor.
*/
public final String typeDescriptor;
public JavaFunctionDescriptor(String className, String name, String typeDescriptor) {
this.className = className;
this.name = name;
this.typeDescriptor = typeDescriptor;
}
@Override
public String toString() {
return className + "." + name;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
return Objects.equal(className, that.className) &&
Objects.equal(name, that.name) &&
Objects.equal(typeDescriptor, that.typeDescriptor);
}
@Override
public int hashCode() {
return Objects.hashCode(className, name, typeDescriptor);
}
}
@@ -0,0 +1,25 @@
package org.ray.runtime.functionmanager;
/**
* Represents metadata of a Python function.
*/
public class PyFunctionDescriptor implements FunctionDescriptor {
public String moduleName;
public String className;
public String functionName;
public PyFunctionDescriptor(String moduleName, String className, String functionName) {
this.moduleName = moduleName;
this.className = className;
this.functionName = functionName;
}
@Override
public String toString() {
return moduleName + "." + className + "." + functionName;
}
}
@@ -23,10 +23,10 @@ public class RayFunction {
/**
* Function's metadata.
*/
public final FunctionDescriptor functionDescriptor;
public final JavaFunctionDescriptor functionDescriptor;
public RayFunction(Executable executable, ClassLoader classLoader,
FunctionDescriptor functionDescriptor) {
JavaFunctionDescriptor functionDescriptor) {
this.executable = executable;
this.classLoader = classLoader;
this.functionDescriptor = functionDescriptor;
@@ -53,7 +53,7 @@ public class RayFunction {
return (Method) executable;
}
public FunctionDescriptor getFunctionDescriptor() {
public JavaFunctionDescriptor getFunctionDescriptor() {
return functionDescriptor;
}
@@ -36,6 +36,8 @@ public class ObjectStoreProxy {
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
private final AbstractRayRuntime runtime;
private static ThreadLocal<ObjectStoreLink> objectStore;
@@ -83,9 +85,8 @@ public class ObjectStoreProxy {
GetResult<T> result;
if (meta != null) {
// If meta is not null, deserialize the exception.
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
result = new GetResult<>(true, null, exception);
// If meta is not null, deserialize the object from meta.
result = deserializeFromMeta(meta, data, ids.get(i));
} else if (data != null) {
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
@@ -112,13 +113,16 @@ public class ObjectStoreProxy {
return results;
}
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return RayWorkerException.INSTANCE;
@SuppressWarnings("unchecked")
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
return (GetResult<T>) new GetResult<>(true, data, null);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return new GetResult<>(true, null, RayWorkerException.INSTANCE);
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
return RayActorException.INSTANCE;
return new GetResult<>(true, null, RayActorException.INSTANCE);
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
return new GetResult<>(true, null, new UnreconstructableException(objectId));
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}
@@ -131,7 +135,13 @@ public class ObjectStoreProxy {
*/
public void put(UniqueId id, Object object) {
try {
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
} else {
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
}
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
@@ -12,12 +12,13 @@ import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Arg;
import org.ray.runtime.generated.Language;
import org.ray.runtime.generated.ResourcePair;
import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
@@ -183,13 +184,14 @@ public class RayletClientImpl implements RayletClient {
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
}
// Deserialize function descriptor
Preconditions.checkArgument(info.language() == Language.JAVA);
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
FunctionDescriptor functionDescriptor = new FunctionDescriptor(
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
);
return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId,
maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles,
args, returnIds, resources, functionDescriptor);
args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor);
}
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
@@ -250,12 +252,29 @@ public class RayletClientImpl implements RayletClient {
int requiredPlacementResourcesOffset =
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.functionDescriptor.className),
fbb.createString(task.functionDescriptor.name),
fbb.createString(task.functionDescriptor.typeDescriptor)
};
int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
int language;
int functionDescriptorOffset;
if (task.language == TaskLanguage.JAVA) {
// This is a Java task.
language = Language.JAVA;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getJavaFunctionDescriptor().className),
fbb.createString(task.getJavaFunctionDescriptor().name),
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
} else {
// This is a Python task.
language = Language.PYTHON;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getPyFunctionDescriptor().moduleName),
fbb.createString(task.getPyFunctionDescriptor().className),
fbb.createString(task.getPyFunctionDescriptor().functionName),
fbb.createString("")
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
}
int root = TaskInfo.createTaskInfo(
fbb,
@@ -274,7 +293,7 @@ public class RayletClientImpl implements RayletClient {
returnsOffset,
requiredResourcesOffset,
requiredPlacementResourcesOffset,
Language.JAVA,
language,
functionDescriptorOffset);
fbb.finish(root);
ByteBuffer buffer = fbb.dataBuffer();
@@ -12,16 +12,15 @@ import org.ray.runtime.util.Serializer;
public class ArgumentsBuilder {
/**
* If the the size of an argument's serialized data is smaller than this number,
* the argument will be passed by value. Otherwise it'll be passed by reference.
* If the the size of an argument's serialized data is smaller than this number, the argument will
* be passed by value. Otherwise it'll be passed by reference.
*/
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
/**
* Convert real function arguments to task spec arguments.
*/
public static FunctionArg[] wrap(Object[] args) {
public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) {
FunctionArg[] ret = new FunctionArg[args.length];
for (int i = 0; i < ret.length; i++) {
Object arg = args[i];
@@ -33,10 +32,15 @@ public class ArgumentsBuilder {
data = Serializer.encode(arg);
} else if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (arg instanceof byte[] && crossLanguage) {
// If the argument is a byte array and will be used by a different language,
// do not inline this argument. Because the other language doesn't know how
// to deserialize it.
id = Ray.put(arg).getId();
} else {
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId();
} else {
data = serialized;
}
@@ -0,0 +1,11 @@
package org.ray.runtime.task;
/**
* Language of a Ray task.
*/
public enum TaskLanguage {
JAVA,
PYTHON,
}
@@ -1,12 +1,14 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
/**
* Represents necessary information of a task for scheduling and executing.
@@ -52,9 +54,13 @@ public class TaskSpec {
// The task's resource demands.
public final Map<String, Double> resources;
// Function descriptor is a list of strings that can uniquely identify a function.
// It will be sent to worker and used to load the target callable function.
public final FunctionDescriptor functionDescriptor;
// Language of this task.
public final TaskLanguage language;
// Descriptor of the remote function.
// Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language
// is Python, the type is PyFunctionDescriptor.
private final FunctionDescriptor functionDescriptor;
private List<UniqueId> executionDependencies;
@@ -66,10 +72,22 @@ public class TaskSpec {
return !actorCreationId.isNil();
}
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
UniqueId actorCreationId, int maxActorReconstructions, UniqueId actorId,
UniqueId actorHandleId, int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args,
UniqueId[] returnIds, Map<String, Double> resources, FunctionDescriptor functionDescriptor) {
public TaskSpec(
UniqueId driverId,
UniqueId taskId,
UniqueId parentTaskId,
int parentCounter,
UniqueId actorCreationId,
int maxActorReconstructions,
UniqueId actorId,
UniqueId actorHandleId,
int actorCounter,
UniqueId[] newActorHandles,
FunctionArg[] args,
UniqueId[] returnIds,
Map<String, Double> resources,
TaskLanguage language,
FunctionDescriptor functionDescriptor) {
this.driverId = driverId;
this.taskId = taskId;
this.parentTaskId = parentTaskId;
@@ -83,10 +101,30 @@ public class TaskSpec {
this.args = args;
this.returnIds = returnIds;
this.resources = resources;
this.language = language;
if (language == TaskLanguage.JAVA) {
Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor,
"Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
} else if (language == TaskLanguage.PYTHON) {
Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor,
"Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
} else {
Preconditions.checkArgument(false, "Unknown task language: {}.", language);
}
this.functionDescriptor = functionDescriptor;
this.executionDependencies = new ArrayList<>();
}
public JavaFunctionDescriptor getJavaFunctionDescriptor() {
Preconditions.checkState(language == TaskLanguage.JAVA);
return (JavaFunctionDescriptor) functionDescriptor;
}
public PyFunctionDescriptor getPyFunctionDescriptor() {
Preconditions.checkState(language == TaskLanguage.PYTHON);
return (PyFunctionDescriptor) functionDescriptor;
}
public List<UniqueId> getExecutionDependencies() {
return executionDependencies;
}
@@ -99,13 +137,17 @@ public class TaskSpec {
", parentTaskId=" + parentTaskId +
", parentCounter=" + parentCounter +
", actorCreationId=" + actorCreationId +
", maxActorReconstructions=" + maxActorReconstructions +
", actorId=" + actorId +
", actorHandleId=" + actorHandleId +
", actorCounter=" + actorCounter +
", newActorHandles=" + Arrays.toString(newActorHandles) +
", args=" + Arrays.toString(args) +
", returnIds=" + Arrays.toString(returnIds) +
", resources=" + ResourceUtil.getResourcesStringFromMap(resources) +
", resources=" + resources +
", language=" + language +
", functionDescriptor=" + functionDescriptor +
", executionDependencies=" + executionDependencies +
'}';
}
}
@@ -21,7 +21,6 @@ public class RayCallGenerator extends BaseGenerator {
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.function.RayFunc;");
newLine("import org.ray.api.function.RayFunc0;");
newLine("import org.ray.api.function.RayFunc1;");
newLine("import org.ray.api.function.RayFunc2;");
@@ -30,7 +29,6 @@ public class RayCallGenerator extends BaseGenerator {
newLine("import org.ray.api.function.RayFunc5;");
newLine("import org.ray.api.function.RayFunc6;");
newLine("import org.ray.api.options.ActorCreationOptions;");
newLine("import org.ray.api.options.BaseTaskOptions;");
newLine("import org.ray.api.options.CallOptions;");
newLine("");
@@ -46,6 +44,7 @@ public class RayCallGenerator extends BaseGenerator {
buildCalls(i, false, false, false);
buildCalls(i, false, false, true);
}
newLine(1, "// ===========================================");
newLine(1, "// Methods for remote actor method invocation.");
newLine(1, "// ===========================================");
@@ -59,6 +58,21 @@ public class RayCallGenerator extends BaseGenerator {
buildCalls(i, false, true, false);
buildCalls(i, false, true, true);
}
newLine(1, "// ===========================");
newLine(1, "// Cross-language methods.");
newLine(1, "// ===========================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, false, false);
buildPyCalls(i, false, false, true);
}
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildPyCalls(i, true, false, false);
}
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildPyCalls(i, false, true, false);
buildPyCalls(i,false, true, true);
}
newLine("}");
return sb.toString();
}
@@ -117,18 +131,86 @@ public class RayCallGenerator extends BaseGenerator {
String funcName = !forActorCreation ? "call" : "createActor";
String funcArgs = !forActor ? "f, args" : "f, actor, args";
for (String param : generateParameters(0, numParameters)) {
// method signature
// Method signature.
newLine(1, String.format(
"public static <%s> %s %s(%s%s) {",
genericTypes, returnType, funcName, paramPrefix + param, optionsParam
));
// method body
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
newLine(1, "}");
}
}
/**
* Build the `Ray.callPy` or `Ray.createPyActor` methods.
* @param forActor build actor api when true, otherwise build task api.
* @param forActorCreation build `Ray.createPyActor` when true, otherwise build `Ray.callPy`.
*/
private void buildPyCalls(int numParameters, boolean forActor,
boolean forActorCreation, boolean hasOptionsParam) {
String argList = "";
String paramList = "";
for (int i = 0; i < numParameters; i++) {
paramList += "Object obj" + i + ", ";
argList += "obj" + i + ", ";
}
if (argList.endsWith(", ")) {
argList = argList.substring(0, argList.length() - 2);
}
if (paramList.endsWith(", ")) {
paramList = paramList.substring(0, paramList.length() - 2);
}
String paramPrefix = "";
String funcArgs = "";
if (forActorCreation) {
paramPrefix += "String moduleName, String className";
funcArgs += "moduleName, className";
} else if (forActor) {
paramPrefix += "RayPyActor pyActor, String functionName";
funcArgs += "pyActor, functionName";
} else {
paramPrefix += "String moduleName, String functionName";
funcArgs += "moduleName, functionName";
}
if (numParameters > 0) {
paramPrefix += ", ";
}
String optionsParam;
if (hasOptionsParam) {
optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options";
} else {
optionsParam = "";
}
String optionsArg;
if (forActor) {
optionsArg = "";
} else {
if (hasOptionsParam) {
optionsArg = ", options";
} else {
optionsArg = ", null";
}
}
String returnType = !forActorCreation ? "RayObject" : "RayPyActor";
String funcName = !forActorCreation ? "callPy" : "createPyActor";
funcArgs += ", args";
// Method signature.
newLine(1, String.format(
"public static %s %s(%s%s) {",
returnType, funcName, paramPrefix + paramList, optionsParam
));
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg));
newLine(1, "}");
}
private List<String> generateParameters(int from, int to) {
List<String> res = new ArrayList<>();
dfs(from, from, to, "", res);
@@ -155,3 +237,4 @@ public class RayCallGenerator extends BaseGenerator {
FileUtil.overrideFile(path, new RayCallGenerator().build());
}
}
@@ -40,20 +40,20 @@ public class FunctionManagerTest {
private static RayFunc0<Object> fooFunc;
private static RayFunc1<Bar, Object> barFunc;
private static RayFunc0<Bar> barConstructor;
private static FunctionDescriptor fooDescriptor;
private static FunctionDescriptor barDescriptor;
private static FunctionDescriptor barConstructorDescriptor;
private static JavaFunctionDescriptor fooDescriptor;
private static JavaFunctionDescriptor barDescriptor;
private static JavaFunctionDescriptor barConstructorDescriptor;
@BeforeClass
public static void beforeClass() {
fooFunc = FunctionManagerTest::foo;
barConstructor = Bar::new;
barFunc = Bar::bar;
fooDescriptor = new FunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
"()Ljava/lang/Object;");
barDescriptor = new FunctionDescriptor(Bar.class.getName(), "bar",
barDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), "bar",
"()Ljava/lang/Object;");
barConstructorDescriptor = new FunctionDescriptor(Bar.class.getName(),
barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(),
FunctionManager.CONSTRUCTOR_NAME,
"()V");
}
@@ -132,7 +132,7 @@ public class FunctionManagerTest {
Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING);
final FunctionManager functionManager = new FunctionManager(resourcePath);
FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02",
JavaFunctionDescriptor sayHelloDescriptor = new JavaFunctionDescriptor("org.ray.exercise.Exercise02",
"sayHello", "()Ljava/lang/String;");
RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor);
Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor);